
Working ONNX export with jit.trace + _batched_nms_coordinate_trick. #302

@artemisart


Hello, I've got a working ONNX export thanks to tracing and _batched_nms_coordinate_trick. I looked through previous issues and couldn't find a working ONNX export, either here or in any other PyTorch implementation, so I thought it would be useful to post it here.

I tested with efficientdet_d0 and tf_efficientdet_d0, and both seem to work well. I had to replace batched_nms (code below) so that it only uses _batched_nms_coordinate_trick, because the torch ONNX exporter can't handle the curr_keep_indices tensor. Exporting torchvision models (retinanet, fasterrcnn, ...) does work, however, and I'm not sure why. My best guess is that for them torch infers boxes.numel() <= 4000, so only _batched_nms_coordinate_trick is used, but it doesn't do that for effdet?

Anyway, the hack of replacing the function works well enough for the export. @rwightman, do you have any idea whether there is an equivalent of torch.jit.is_scripting() for ONNX export, so that we could call _batched_nms_coordinate_trick in that case and support the export natively?

The second trick is to trace the model first. Exporting the untraced model directly doesn't fail, but the resulting file cannot be used: loading it for inference fails with [ONNXRuntimeError] : 1 : FAIL : Load model from efficientdet_d0.onnx failed:Invalid tensor data type 0. Yet onnx.checker.check_model(onnx.load(file)) doesn't report any issue. I haven't managed to track the bug down yet, and I don't even know whether it's on the PyTorch side or the onnxruntime side.

Details for the error with batched_nms:

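torch cannot infer the type (or shape?) of curr_keep_indices inside torchvision.ops.boxes._batched_nms_vanilla. In the graph below, the onnx::If that produces it has one branch returning Long(*) (Squeeze) and the other returning Long(*, *) (Identity), so the two branches don't even agree on rank. Stack trace: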

Traceback (most recent call last):
  File "/.../2024/11/export_onnx_effdet.py", line 52, in <module>
    main()
  File "/.../2024/11/export_onnx_effdet.py", line 28, in main
    torch.onnx.export(model, (bchw,), 'efficientdet_d0.onnx', input_names=['input'], output_names=['output'])
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 639, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py", line 6417, in prim_loop
    torch._C._jit_pass_onnx_block(
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py", line 6508, in prim_if
    torch._C._jit_pass_onnx_block(
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset9.py", line 6417, in prim_loop
    torch._C._jit_pass_onnx_block(
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/symbolic_opset11.py", line 938, in index
    or _type_utils.JitScalarType.from_value(index)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/_type_utils.py", line 239, in from_value
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this '<class 'torch.TensorType'>' instance and a default value was not provided.  [Caused by the value 'curr_keep_indices defined in (%curr_keep_indices : Tensor = onnx::If(%3938) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
  block0():
    %3940 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}]() # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
    %3941 : Long(*, device=cpu) = onnx::Squeeze(%3932, %3940) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
    -> (%3941)
  block1():
    %3942 : Long(*, *, device=cpu) = onnx::Identity(%3932) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
    -> (%3942)
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::If'.] 
    (node defined in   File "/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py", line 41
        _log_api_usage_once(nms)
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
)

    Inputs:
        #0: 3938 defined in (%3938 : Bool(1, strides=[1], device=cpu) = onnx::Equal(%3936, %3937) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
    )  (type 'Tensor')
    Outputs:
        #0: curr_keep_indices defined in (%curr_keep_indices : Tensor = onnx::If(%3938) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
      block0():
        %3940 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}]() # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
        %3941 : Long(*, device=cpu) = onnx::Squeeze(%3932, %3940) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
        -> (%3941)
      block1():
        %3942 : Long(*, *, device=cpu) = onnx::Identity(%3932) # /home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torchvision/ops/boxes.py:41:11
        -> (%3942)
    )  (type 'Tensor')

Relevant code from torchvision.ops.boxes:

def batched_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float) -> Tensor:
    # [...]
    # Benchmarks that drove the following thresholds are at
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    else:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

@torch.jit._script_if_tracing
def _batched_nms_vanilla(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float) -> Tensor:
    # Based on Detectron2 implementation, just manually call nms() on each class independently
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
        keep_mask[curr_indices[curr_keep_indices]] = True
    keep_indices = torch.where(keep_mask)[0]
    return keep_indices[scores[keep_indices].sort(descending=True)[1]]
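To see why swapping in _batched_nms_coordinate_trick shouldn't change the detections, here is a pure-Python sketch (no torch) of both strategies. The function names are mine and nms here is a simple greedy re-implementation, not torchvision's kernel. The coordinate trick shifts every box by class_id * (max_coordinate + 1) so boxes of different classes can never overlap (assuming non-negative coordinates, as image boxes normally are), then runs a single NMS. Unlike the vanilla version it has no loop over classes, which is the loop that @torch.jit._script_if_tracing turns into the prim::Loop seen in the stack trace above.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, thr):
    """Greedy NMS: drop boxes with IoU > thr against a kept box; returns indices by decreasing score."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= thr for k in keep):
            keep.append(i)
    return keep


def batched_nms_vanilla(boxes, scores, idxs, thr):
    """One nms() call per class, like _batched_nms_vanilla."""
    keep = []
    for c in sorted(set(idxs)):
        members = [i for i, k in enumerate(idxs) if k == c]
        kept = nms([boxes[i] for i in members], [scores[i] for i in members], thr)
        keep += [members[j] for j in kept]
    return sorted(keep, key=lambda i: scores[i], reverse=True)


def batched_nms_coordinate_trick(boxes, scores, idxs, thr):
    """Offset each class into its own disjoint region, then a single nms() call."""
    if not boxes:
        return []
    offset = max(max(b) for b in boxes) + 1
    shifted = [[v + idxs[i] * offset for v in b] for i, b in enumerate(boxes)]
    return nms(shifted, scores, thr)


boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [0, 0, 10, 10], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7, 0.6]
idxs = [0, 0, 1, 0]  # box 2 duplicates box 0 but belongs to another class, so it survives
print(batched_nms_vanilla(boxes, scores, idxs, 0.5))           # [0, 2, 3]
print(batched_nms_coordinate_trick(boxes, scores, idxs, 0.5))  # [0, 2, 3]
```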

Here is the code to do the export:

import onnxruntime as ort
import torch
from torchvision.ops.boxes import _batched_nms_coordinate_trick
import torchvision


def main():
    torchvision.ops.boxes.batched_nms = _batched_nms_coordinate_trick
    import effdet  # must be imported after the nms replacement

    bchw = torch.randn(1, 3, 512, 512)
    model = effdet.create_model('efficientdet_d0', bench_task='predict').eval()
    print(model(bchw)[:, :2])

    # Direct export completes, but the resulting file cannot be loaded. PyTorch bug?
    torch.onnx.export(model, (bchw,), 'efficientdet_d0.onnx', input_names=['input'], output_names=['output'])
    try:
        omodel = ort.InferenceSession('efficientdet_d0.onnx')
        # run() returns a list of outputs; take the first one before slicing
        print(omodel.run(None, {'input': bchw.numpy()})[0][:, :2])
    except Exception as err:
        print(err)
        # currently fails with:
        # [ONNXRuntimeError] : 1 : FAIL : Load model from efficientdet_d0.onnx failed:Invalid tensor data type 0.

    # Export of the traced model works
    traced = torch.jit.trace(model, (bchw,))
    torch.onnx.export(traced, (bchw,), 'efficientdet_d0_traced.onnx', input_names=['input'], output_names=['output'])

    omodel = ort.InferenceSession('efficientdet_d0_traced.onnx')
    print(omodel.run(None, {'input': bchw.numpy()})[0][:, :2])


if __name__ == '__main__':
    main()
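About the "must be imported after the nms replacement" comment in main(): I haven't double-checked effdet's imports, but the usual reason for such an ordering constraint is a from-import (from torchvision.ops.boxes import batched_nms), which copies the function reference into the importing module when it is first imported. The toy example below reproduces this with hypothetical stand-in modules (fake_boxes, fake_lib) instead of torchvision and effdet.

```python
import sys
import types

# "fake_boxes" is a hypothetical stand-in for torchvision.ops.boxes.
fake_boxes = types.ModuleType("fake_boxes")
fake_boxes.batched_nms = lambda: "vanilla"
sys.modules["fake_boxes"] = fake_boxes


def import_fake_lib() -> types.ModuleType:
    """Build a hypothetical library module whose body does a from-import,
    which binds the *current* function object into its own namespace."""
    lib = types.ModuleType("fake_lib")
    exec("from fake_boxes import batched_nms", lib.__dict__)
    return lib


lib_imported_before = import_fake_lib()
fake_boxes.batched_nms = lambda: "coordinate_trick"  # the monkeypatch
lib_imported_after = import_fake_lib()

print(lib_imported_before.batched_nms())  # vanilla: the patch came too late
print(lib_imported_after.batched_nms())   # coordinate_trick
```

Since Python caches imported modules in sys.modules, a later import effdet would not re-run the module body, so the patch has to happen before the first import of effdet in the process.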

Example output on my machine (torch==2.5.1+cu121, torchvision==0.20.1+cu121, onnx==1.17.0, onnxruntime==1.20.1, effdet==0.4.1):

Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.
tensor([[[4.6644e+02, 2.9536e+02, 5.1953e+02, 3.2238e+02, 1.1469e-02,
          4.0000e+01],
         [3.8305e+02, 8.1529e+00, 4.2393e+02, 4.7712e+01, 1.1439e-02,
          3.6000e+01]]], grad_fn=<SliceBackward0>)
[ONNXRuntimeError] : 1 : FAIL : Load model from efficientdet_d0.onnx failed:Invalid tensor data type 0.
/home/monk/.cache/uv/archive-v0/zTMB4qSl8IEuHk0t4-dyt/lib/python3.11/site-packages/torch/onnx/utils.py:782: UserWarning: no signature found for builtin <built-in method __call__ of PyCapsule object at 0x7f2b773d5350>, skipping _decide_input_format
  warnings.warn(f"{e}, skipping _decide_input_format")
[[[4.6644165e+02 2.9536493e+02 5.1953387e+02 3.2238403e+02 1.1469424e-02
   4.0000000e+01]
  [3.8305466e+02 8.1528721e+00 4.2393228e+02 4.7711700e+01 1.1438489e-02
   3.6000000e+01]]]
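The PyTorch and traced-ONNX rows above agree to the printed precision. To check that automatically rather than by eye, something like the hypothetical helper below could be called at the end of main(). It assumes the (batch, num_dets, 6) layout [x1, y1, x2, y2, score, class] that the printed rows suggest, and the tolerances are arbitrary choices. A tiny numeric difference that reorders two near-tied detections would make it fail even if the model is fine.

```python
import numpy as np


def compare_detections(torch_dets: np.ndarray, ort_dets: np.ndarray,
                       box_atol: float = 1e-3, score_atol: float = 1e-5) -> None:
    """Hypothetical parity check for detection arrays of shape (batch, num_dets, 6),
    rows laid out as [x1, y1, x2, y2, score, class]. Raises AssertionError on mismatch."""
    assert torch_dets.shape == ort_dets.shape, (torch_dets.shape, ort_dets.shape)
    # Box coordinates are in pixels, so use an absolute tolerance
    np.testing.assert_allclose(ort_dets[..., :4], torch_dets[..., :4], rtol=0, atol=box_atol)
    np.testing.assert_allclose(ort_dets[..., 4], torch_dets[..., 4], rtol=0, atol=score_atol)
    # Class ids must match exactly
    np.testing.assert_array_equal(ort_dets[..., 5], torch_dets[..., 5])
```

For example: compare_detections(model(bchw).detach().numpy(), omodel.run(None, {'input': bchw.numpy()})[0]).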

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
