3
3
4
4
import torch
5
5
6
- _onnx_opset_version = 11
6
+ _onnx_opset_version_11 = 11
7
+ _onnx_opset_version_16 = 16
8
+ base_onnx_opset_version = _onnx_opset_version_11
7
9
8
10
9
11
def _register_custom_op ():
10
12
from torch .onnx .symbolic_helper import parse_args
11
- from torch .onnx .symbolic_opset11 import select , squeeze , unsqueeze
12
13
from torch .onnx .symbolic_opset9 import _cast_Long
14
+ from torch .onnx .symbolic_opset11 import select , squeeze , unsqueeze
13
15
14
16
@parse_args ("v" , "v" , "f" )
15
17
def symbolic_multi_label_nms (g , boxes , scores , iou_threshold ):
@@ -20,32 +22,56 @@ def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
20
22
nms_out = g .op ("NonMaxSuppression" , boxes , scores , max_output_per_class , iou_threshold )
21
23
return squeeze (g , select (g , nms_out , 1 , g .op ("Constant" , value_t = torch .tensor ([2 ], dtype = torch .long ))), 1 )
22
24
23
- @parse_args ("v" , "v" , "f" , "i" , "i" , "i" , "i" )
24
- def roi_align (g , input , rois , spatial_scale , pooled_height , pooled_width , sampling_ratio , aligned ):
25
- batch_indices = _cast_Long (
25
+ def _process_batch_indices_for_roi_align (g , rois ):
26
+ return _cast_Long (
26
27
g , squeeze (g , select (g , rois , 1 , g .op ("Constant" , value_t = torch .tensor ([0 ], dtype = torch .long ))), 1 ), False
27
28
)
28
- rois = select (g , rois , 1 , g .op ("Constant" , value_t = torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .long )))
29
- # TODO: Remove this warning after ONNX opset 16 is supported.
30
- if aligned :
31
- warnings .warn (
32
- "ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
33
- "The workaround is that the user need apply the patch "
34
- "https://github.com/microsoft/onnxruntime/pull/8564 "
35
- "and build ONNXRuntime from source."
36
- )
37
29
38
- # ONNX doesn't support negative sampling_ratio
30
+ def _process_rois_for_roi_align (g , rois ):
31
+ return select (g , rois , 1 , g .op ("Constant" , value_t = torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .long )))
32
+
33
+ def _process_sampling_ratio_for_roi_align (g , sampling_ratio : int ):
39
34
if sampling_ratio < 0 :
40
35
warnings .warn (
41
- "ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
36
+ "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
37
+ "The model will be exported with a sampling_ratio of 0."
42
38
)
43
39
sampling_ratio = 0
40
+ return sampling_ratio
41
+
42
+ @parse_args ("v" , "v" , "f" , "i" , "i" , "i" , "i" )
43
+ def roi_align_opset11 (g , input , rois , spatial_scale , pooled_height , pooled_width , sampling_ratio , aligned ):
44
+ batch_indices = _process_batch_indices_for_roi_align (g , rois )
45
+ rois = _process_rois_for_roi_align (g , rois )
46
+ if aligned :
47
+ warnings .warn (
48
+ "ROIAlign with aligned=True is not supported in ONNX, but is supported in opset 16. "
49
+ "Please export with opset 16 or higher to use aligned=False."
50
+ )
51
+ sampling_ratio = _process_sampling_ratio_for_roi_align (g , sampling_ratio )
52
+ return g .op (
53
+ "RoiAlign" ,
54
+ input ,
55
+ rois ,
56
+ batch_indices ,
57
+ spatial_scale_f = spatial_scale ,
58
+ output_height_i = pooled_height ,
59
+ output_width_i = pooled_width ,
60
+ sampling_ratio_i = sampling_ratio ,
61
+ )
62
+
63
+ @parse_args ("v" , "v" , "f" , "i" , "i" , "i" , "i" )
64
+ def roi_align_opset16 (g , input , rois , spatial_scale , pooled_height , pooled_width , sampling_ratio , aligned ):
65
+ batch_indices = _process_batch_indices_for_roi_align (g , rois )
66
+ rois = _process_rois_for_roi_align (g , rois )
67
+ coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
68
+ sampling_ratio = _process_sampling_ratio_for_roi_align (g , sampling_ratio )
44
69
return g .op (
45
70
"RoiAlign" ,
46
71
input ,
47
72
rois ,
48
73
batch_indices ,
74
+ coordinate_transformation_mode_s = coordinate_transformation_mode ,
49
75
spatial_scale_f = spatial_scale ,
50
76
output_height_i = pooled_height ,
51
77
output_width_i = pooled_width ,
@@ -61,6 +87,7 @@ def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
61
87
62
88
from torch .onnx import register_custom_op_symbolic
63
89
64
- register_custom_op_symbolic ("torchvision::nms" , symbolic_multi_label_nms , _onnx_opset_version )
65
- register_custom_op_symbolic ("torchvision::roi_align" , roi_align , _onnx_opset_version )
66
- register_custom_op_symbolic ("torchvision::roi_pool" , roi_pool , _onnx_opset_version )
90
+ register_custom_op_symbolic ("torchvision::nms" , symbolic_multi_label_nms , _onnx_opset_version_11 )
91
+ register_custom_op_symbolic ("torchvision::roi_align" , roi_align_opset11 , _onnx_opset_version_11 )
92
+ register_custom_op_symbolic ("torchvision::roi_align" , roi_align_opset16 , _onnx_opset_version_16 )
93
+ register_custom_op_symbolic ("torchvision::roi_pool" , roi_pool , _onnx_opset_version_11 )
0 commit comments