Allow opset_version to be set explicitly when exporting
#2615
Conversation
Codecov Report

```
@@            Coverage Diff            @@
##             main    #2615   +/-   ##
=======================================
  Coverage   70.39%   70.39%
=======================================
  Files         222      222
  Lines       26281    26281
  Branches     2629     2629
=======================================
  Hits        18500    18500
  Misses       6861     6861
  Partials      920      920
```
I think this is a reasonable change, thanks.
@gramalingam for a second opinion
I think it would be nice to explicitly set opset_version when exporting, particularly when a custom/particular Opset is being used and the default opset can't be inferred.
Example:
```py
from onnxscript import script
from onnxscript import opset15 as op
from onnxscript.values import Opset
import numpy as np
from onnxscript import STRING
from onnxruntime import InferenceSession
ai_onnx = Opset("ai.onnx.ml", version=2)
@script(ai_onnx, default_opset=op)
def label_encoder(X: STRING["D"]):
    Y = ai_onnx.LabelEncoder(
        X, keys_strings=["a", "b", "c"], values_int64s=[0, 1, 2], default_int64=42
    )
    # Y = Y + 0.0  # to force opset version downgrade
    return Y
print(label_encoder(np.array(["a", "b", "c"])))
session = InferenceSession(label_encoder.to_model_proto(ir_version=10).SerializeToString())
for key, value in {"a": 0, "b": 1, "c": 2}.items():
    assert label_encoder(np.array([key]))[0] == value
    assert session.run(None, {"X": np.array([key])})[0] == value
```
This currently errors with
```sh
Traceback (most recent call last):
File "/Users/XXX/Development/projects/jet/test_onnxscript_label.py", line 25, in <module>
session = InferenceSession(label_encoder.to_model_proto(ir_version=10).SerializeToString())
File "/Users/XXX/Development/projects/jet/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 472, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/Users/XXX/Development/projects/jet/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 552, in _create_inference_session
sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : /Users/runner/work/1/s/onnxruntime/core/graph/model_load_utils.h:56 void onnxruntime::model_load_utils::ValidateOpsetForDomain(const std::unordered_map<std::string, int> &, const logging::Logger &, bool, const std::string &, int) ONNX Runtime only *guarantees* support for models stamped with official released onnx opset versions. Opset 23 is under development and support for this is limited. The operator schemas and or other functionality may change before next ONNX release and in this case ONNX Runtime will not guarantee backward compatibility. Current official support for domain ai.onnx is till opset 22.
```
To force it to work in the current state, one would have to do:
```py
@script(ai_onnx, default_opset=op)
def label_encoder(X: STRING["D"]):
    Y = ai_onnx.LabelEncoder(
        X, keys_strings=["a", "b", "c"], values_int64s=[0, 1, 2], default_int64=42
    )
    Y = Y + 0.0  # dummy op from the default opset, to force the opset downgrade
    return Y
```
That is, one has to add a dummy op from the default opset so that the opset gets downgraded, since otherwise `default_opset` is never actually consulted.
Happy to be challenged if there is a better way. I can imagine something weird/unintended might occur if the user sets `opset_version` at export time to something other than what is defined in `@script(..., default_opset=<op>)`, but that generally shouldn't be a problem?
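The behavior being requested can be sketched in plain Python. This is a minimal illustration of the opset-selection logic, not onnxscript's actual internals: the names `resolve_opset_imports` and `LATEST_ONNX_OPSET` are hypothetical, and the fallback value of 23 is taken from the onnxruntime error above.

```python
# Illustrative sketch of the opset-selection logic this PR asks for.
# resolve_opset_imports and LATEST_ONNX_OPSET are hypothetical names,
# not onnxscript's real API.

LATEST_ONNX_OPSET = 23  # in-development opset the exporter falls back to


def resolve_opset_imports(used_opsets, opset_version=None):
    """Return the {domain: version} map to stamp on the exported model.

    used_opsets: versions for every domain actually referenced by the script.
    opset_version: optional explicit version for the default ("") domain.
    """
    imports = dict(used_opsets)
    if opset_version is not None:
        # An explicit export-time override wins, even when no op from the
        # default domain was used.
        imports[""] = opset_version
    elif "" not in imports:
        # No default-domain op referenced: fall back to the latest opset,
        # which is what triggers the onnxruntime failure above.
        imports[""] = LATEST_ONNX_OPSET
    return imports


# Only ai.onnx.ml ops are used, so without an override we get opset 23:
print(resolve_opset_imports({"ai.onnx.ml": 2}))
# -> {'ai.onnx.ml': 2, '': 23}
# With an explicit opset_version, the model stays on a released opset:
print(resolve_opset_imports({"ai.onnx.ml": 2}, opset_version=15))
# -> {'ai.onnx.ml': 2, '': 15}
```

With such an override available at export time, the `Y = Y + 0.0` workaround becomes unnecessary.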
Thanks! Merging now. @gramalingam feel free to add if you have further comments