Skip to content

Commit c9174e4

Browse files
NicolasHugdatumbox
authored andcommitted
[fbsync] [ONNX] Support exporting RoiAlign align=True to ONNX with opset 16 (#6685)
Summary: * Support exporting RoiAlign align=True to ONNX with opset 16 * lint: ufmt Reviewed By: datumbox Differential Revision: D40138746 fbshipit-source-id: 06dcd122e762c02491fd68a864773894b527b854 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 201cb2a commit c9174e4

File tree

2 files changed

+61
-29
lines changed

2 files changed

+61
-29
lines changed

test/test_onnx.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import io
22
from collections import OrderedDict
3-
from typing import List, Tuple
3+
from typing import List, Optional, Tuple
44

55
import pytest
66
import torch
@@ -11,7 +11,7 @@
1111
from torchvision.models.detection.roi_heads import RoIHeads
1212
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
1313
from torchvision.models.detection.transform import GeneralizedRCNNTransform
14-
from torchvision.ops._register_onnx_ops import _onnx_opset_version
14+
from torchvision.ops import _register_onnx_ops
1515

1616
# In environments without onnxruntime we prefer to
1717
# invoke all tests in the repo and have this one skipped rather than fail.
@@ -32,7 +32,11 @@ def run_model(
3232
dynamic_axes=None,
3333
output_names=None,
3434
input_names=None,
35+
opset_version: Optional[int] = None,
3536
):
37+
if opset_version is None:
38+
opset_version = _register_onnx_ops.base_onnx_opset_version
39+
3640
model.eval()
3741

3842
onnx_io = io.BytesIO()
@@ -46,10 +50,11 @@ def run_model(
4650
torch_onnx_input,
4751
onnx_io,
4852
do_constant_folding=do_constant_folding,
49-
opset_version=_onnx_opset_version,
53+
opset_version=opset_version,
5054
dynamic_axes=dynamic_axes,
5155
input_names=input_names,
5256
output_names=output_names,
57+
verbose=True,
5358
)
5459
# validate the exported model with onnx runtime
5560
for test_inputs in inputs_list:
@@ -140,39 +145,39 @@ def test_roi_align(self):
140145
model = ops.RoIAlign((5, 5), 1, -1)
141146
self.run_model(model, [(x, single_roi)])
142147

143-
@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
144148
def test_roi_align_aligned(self):
149+
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
145150
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
146151
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
147152
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
148-
self.run_model(model, [(x, single_roi)])
153+
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
149154

150155
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
151156
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
152157
model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
153-
self.run_model(model, [(x, single_roi)])
158+
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
154159

155160
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
156161
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
157162
model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
158-
self.run_model(model, [(x, single_roi)])
163+
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
159164

160165
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
161166
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
162167
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
163-
self.run_model(model, [(x, single_roi)])
168+
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
164169

165170
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
166171
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
167172
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
168-
self.run_model(model, [(x, single_roi)])
173+
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
169174

170-
@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
171175
def test_roi_align_malformed_boxes(self):
176+
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
172177
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
173178
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
174179
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
175-
self.run_model(model, [(x, single_roi)])
180+
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
176181

177182
def test_roi_pool(self):
178183
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)

torchvision/ops/_register_onnx_ops.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import torch
55

6-
_onnx_opset_version = 11
6+
_onnx_opset_version_11 = 11
7+
_onnx_opset_version_16 = 16
8+
base_onnx_opset_version = _onnx_opset_version_11
79

810

911
def _register_custom_op():
@@ -20,32 +22,56 @@ def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
2022
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
2123
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
2224

23-
@parse_args("v", "v", "f", "i", "i", "i", "i")
24-
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
25-
batch_indices = _cast_Long(
25+
def _process_batch_indices_for_roi_align(g, rois):
26+
return _cast_Long(
2627
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
2728
)
28-
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
29-
# TODO: Remove this warning after ONNX opset 16 is supported.
30-
if aligned:
31-
warnings.warn(
32-
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
33-
"The workaround is that the user need apply the patch "
34-
"https://github.com/microsoft/onnxruntime/pull/8564 "
35-
"and build ONNXRuntime from source."
36-
)
3729

38-
# ONNX doesn't support negative sampling_ratio
30+
def _process_rois_for_roi_align(g, rois):
31+
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
32+
33+
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
3934
if sampling_ratio < 0:
4035
warnings.warn(
41-
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
36+
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
37+
"The model will be exported with a sampling_ratio of 0."
4238
)
4339
sampling_ratio = 0
40+
return sampling_ratio
41+
42+
@parse_args("v", "v", "f", "i", "i", "i", "i")
43+
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
44+
batch_indices = _process_batch_indices_for_roi_align(g, rois)
45+
rois = _process_rois_for_roi_align(g, rois)
46+
if aligned:
47+
warnings.warn(
48+
"ROIAlign with aligned=True is not supported in ONNX, but is supported in opset 16. "
49+
"Please export with opset 16 or higher to use aligned=False."
50+
)
51+
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
52+
return g.op(
53+
"RoiAlign",
54+
input,
55+
rois,
56+
batch_indices,
57+
spatial_scale_f=spatial_scale,
58+
output_height_i=pooled_height,
59+
output_width_i=pooled_width,
60+
sampling_ratio_i=sampling_ratio,
61+
)
62+
63+
@parse_args("v", "v", "f", "i", "i", "i", "i")
64+
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
65+
batch_indices = _process_batch_indices_for_roi_align(g, rois)
66+
rois = _process_rois_for_roi_align(g, rois)
67+
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
68+
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
4469
return g.op(
4570
"RoiAlign",
4671
input,
4772
rois,
4873
batch_indices,
74+
coordinate_transformation_mode_s=coordinate_transformation_mode,
4975
spatial_scale_f=spatial_scale,
5076
output_height_i=pooled_height,
5177
output_width_i=pooled_width,
@@ -61,6 +87,7 @@ def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
6187

6288
from torch.onnx import register_custom_op_symbolic
6389

64-
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version)
65-
register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version)
66-
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version)
90+
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
91+
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
92+
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
93+
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)

0 commit comments

Comments
 (0)