Hello, I've got a working ONNX export thanks to tracing and `_batched_nms_coordinate_trick`. I've looked at previous issues and couldn't find a working ONNX export here or in any other PyTorch implementation, so I thought it would be nice to post this here.

I tested with `efficientdet_d0` and `tf_efficientdet_d0`; both seem to work well. I had to replace `batched_nms` (code below) so that it only uses `_batched_nms_coordinate_trick`, because the torch ONNX exporter doesn't like the `curr_keep_indices` tensor. Exporting torchvision models (RetinaNet, Faster R-CNN, ...) works, however, and I'm not sure why; my best guess is that torch infers `boxes.numel() <= 4000` for them, so only `_batched_nms_coordinate_trick` is used, but it doesn't do that for effdet?

Anyway, the hack of replacing the function is enough for the export. @rwightman, do you have any idea whether there is an equivalent of `torch.jit.is_scripting()` for the export, so that we could call `_batched_nms_coordinate_trick` in that case and support the export natively? (A sketch of what such a guard could look like is below.)
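For what it's worth, torch does expose `torch.onnx.is_in_onnx_export()`, which is True while `torch.onnx.export()` is running. A minimal sketch of a guard built on it, assuming that flag behaves as documented (the `onnx_friendly_batched_nms` name is mine, not torchvision's):

```python
import torch
import torchvision
from torch import Tensor
from torchvision.ops.boxes import _batched_nms_coordinate_trick, _batched_nms_vanilla


def onnx_friendly_batched_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float) -> Tensor:
    # Prefer the coordinate trick while exporting to ONNX, analogous to the
    # tracing check torchvision already performs in batched_nms.
    if torch.onnx.is_in_onnx_export():
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
```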
The second trick is to trace the model: the direct export doesn't fail, but the resulting file cannot be used. I get `[ONNXRuntimeError] : 1 : FAIL : Load model from efficientdet_d0.onnx failed:Invalid tensor data type 0.` when trying to run an inference. `onnx.checker.check_model(onnx.load(file))` doesn't raise any issue. I haven't managed to track the bug down yet, and I don't even know whether it's on the pytorch side or in onnxruntime.
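Since the checker passes, one way to look for the culprit is to walk the graph manually for tensors typed as `UNDEFINED` (enum value 0), which is presumably what the runtime error refers to. A small debugging sketch, under that assumption:

```python
import onnx

model = onnx.load('efficientdet_d0.onnx')
onnx.checker.check_model(model)  # passes, despite the onnxruntime load failure

# Look for initializers or value infos whose element type is UNDEFINED (== 0).
for init in model.graph.initializer:
    if init.data_type == onnx.TensorProto.UNDEFINED:
        print('undefined initializer:', init.name)
for vi in list(model.graph.input) + list(model.graph.value_info) + list(model.graph.output):
    if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
        print('undefined value:', vi.name)
```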
Details for the error with `batched_nms`: torch cannot infer the type (or shape?) of `curr_keep_indices` inside `torchvision.ops.boxes._batched_nms_vanilla`.

```
Traceback (most recent call last):
File "/.../2024/11/export_onnx_effdet.py", line 52, in <module>
main()
File "/.../2024/11/export_onnx_effdet.py", line 28, in main
torch.onnx.export(model, (bchw,), 'efficientdet_d0.onnx', input_names=['input'], output_names=['output'])
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
export(
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
_export(
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
graph, params_dict, torch_out = _model_to_graph(
^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
graph = _optimize_graph(
^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 639, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py", line 6417, in prim_loop
torch._C._jit_pass_onnx_block(
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py", line 6508, in prim_if
torch._C._jit_pass_onnx_block(
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py", line 6417, in prim_loop
torch._C._jit_pass_onnx_block(
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset11.py", line 938, in index
or _type_utils.JitScalarType.from_value(index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/_type_utils.py", line 239, in from_value
raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this '<class 'torch.TensorType'>' instance and a default value was not provided. [Caused by the value 'curr_keep_indices defined in (%curr_keep_indices : Tensor = onnx::If(%3938) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
block0():
%3940 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}]() # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
%3941 : Long(*, device=cpu) = onnx::Squeeze(%3932, %3940) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
-> (%3941)
block1():
%3942 : Long(*, *, device=cpu) = onnx::Identity(%3932) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
-> (%3942)
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::If'.]
(node defined in File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py", line 41
_log_api_usage_once(nms)
_assert_has_ops()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
)
Inputs:
#0: 3938 defined in (%3938 : Bool(1, strides=[1], device=cpu) = onnx::Equal(%3936, %3937) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
) (type 'Tensor')
Outputs:
#0: curr_keep_indices defined in (%curr_keep_indices : Tensor = onnx::If(%3938) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
block0():
%3940 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}]() # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
%3941 : Long(*, device=cpu) = onnx::Squeeze(%3932, %3940) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
-> (%3941)
block1():
%3942 : Long(*, *, device=cpu) = onnx::Identity(%3932) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
-> (%3942)
) (type 'Tensor')
```

Code of `torchvision.ops.boxes`:

```python
def batched_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float) -> Tensor:
    # [...]
    # Benchmarks that drove the following thresholds are at
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    else:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)


@torch.jit._script_if_tracing
def _batched_nms_vanilla(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float) -> Tensor:
    # Based on Detectron2 implementation, just manually call nms() on each class independently
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
        keep_mask[curr_indices[curr_keep_indices]] = True
    keep_indices = torch.where(keep_mask)[0]
    return keep_indices[scores[keep_indices].sort(descending=True)[1]]
```
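For context on why the replacement sidesteps the problem: `_batched_nms_coordinate_trick` has no data-dependent per-class loop at all. Roughly paraphrasing the torchvision implementation (not verbatim), it offsets each class's boxes by more than the largest coordinate so boxes of different classes can never overlap, then runs a single `nms` call:

```python
import torch
from torch import Tensor
from torchvision.ops import nms


def coordinate_trick_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float) -> Tensor:
    # Paraphrase of torchvision's _batched_nms_coordinate_trick.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    # Per-class offset larger than any coordinate: each class lands in a
    # disjoint region of the plane, so one global nms() call suffices.
    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    return nms(boxes_for_nms, scores, iou_threshold)
```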
Here is the code to do the export:
```python
import onnxruntime as ort
import torch
from torchvision.ops.boxes import _batched_nms_coordinate_trick
import torchvision


def main():
    torchvision.ops.boxes.batched_nms = _batched_nms_coordinate_trick
    import effdet  # must be imported after the nms replacement

    bchw = torch.randn(1, 3, 512, 512)
    model = effdet.create_model('efficientdet_d0', bench_task='predict').eval()
    print(model(bchw)[:, :2])

    # Direct export completes but the resulting file cannot be run. Pytorch bug?
    torch.onnx.export(model, (bchw,), 'efficientdet_d0.onnx', input_names=['input'], output_names=['output'])
    try:
        omodel = ort.InferenceSession('efficientdet_d0.onnx')
        print(omodel.run(None, {'input': bchw.numpy()})[0][:, :2])
    except Exception as err:
        print(err)
        # currently fails with:
        # [ONNXRuntimeError] : 1 : FAIL : Load model from efficientdet_d0.onnx failed:Invalid tensor data type 0.

    # Export of the traced model works
    traced = torch.jit.trace(model, (bchw,))
    torch.onnx.export(traced, (bchw,), 'efficientdet_d0_traced.onnx', input_names=['input'], output_names=['output'])
    omodel = ort.InferenceSession('efficientdet_d0_traced.onnx')
    print(omodel.run(None, {'input': bchw.numpy()})[0][:, :2])


if __name__ == '__main__':
    main()
```
Example output on my machine, with `torch==2.5.1+cu121`, `torchvision==0.20.1+cu121`, `onnx==1.17.0`, `onnxruntime==1.20.1`, `effdet==0.4.1`:

```
Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.
tensor([[[4.6644e+02, 2.9536e+02, 5.1953e+02, 3.2238e+02, 1.1469e-02,
4.0000e+01],
[3.8305e+02, 8.1529e+00, 4.2393e+02, 4.7712e+01, 1.1439e-02,
3.6000e+01]]], grad_fn=<SliceBackward0>)
[ONNXRuntimeError] : 1 : FAIL : Load model from efficientdet_d0.onnx failed:Invalid tensor data type 0.
/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py:782: UserWarning: no signature found for builtin <built-in method __call__ of PyCapsule object at 0x7f2b773d5350>, skipping _decide_input_format
warnings.warn(f"{e}, skipping _decide_input_format")
[[[4.6644165e+02 2.9536493e+02 5.1953387e+02 3.2238403e+02 1.1469424e-02
4.0000000e+01]
[3.8305466e+02 8.1528721e+00 4.2393228e+02 4.7711700e+01 1.1438489e-02
3.6000000e+01]]]
```
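To check the traced export beyond eyeballing the printed rows, one could compare the eager and ONNX Runtime outputs numerically. A small standalone sketch, assuming the `efficientdet_d0_traced.onnx` file produced by the script above (the tolerances are guesses and may need loosening):

```python
import numpy as np
import onnxruntime as ort
import torch
import torchvision
from torchvision.ops.boxes import _batched_nms_coordinate_trick

torchvision.ops.boxes.batched_nms = _batched_nms_coordinate_trick
import effdet  # must still come after the nms replacement

bchw = torch.randn(1, 3, 512, 512)
model = effdet.create_model('efficientdet_d0', bench_task='predict').eval()
with torch.no_grad():
    ref = model(bchw).numpy()

sess = ort.InferenceSession('efficientdet_d0_traced.onnx')
out = sess.run(None, {'input': bchw.numpy()})[0]
# Same input and same NMS path, so the detections should line up closely.
np.testing.assert_allclose(ref, out, rtol=1e-3, atol=1e-4)
print('eager and onnxruntime outputs match within tolerance')
```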