1
1
import io
2
2
import torch
3
3
from torchvision import ops
4
+ from torchvision .models .detection .image_list import ImageList
4
5
from torchvision .models .detection .transform import GeneralizedRCNNTransform
6
+ from torchvision .models .detection .rpn import AnchorGenerator , RPNHead , RegionProposalNetwork
7
+ from torchvision .models .detection .backbone_utils import resnet_fpn_backbone
5
8
6
9
from collections import OrderedDict
7
10
@@ -20,7 +23,7 @@ class ONNXExporterTester(unittest.TestCase):
20
23
def setUpClass (cls ):
21
24
torch .manual_seed (123 )
22
25
23
- def run_model (self , model , inputs_list ):
26
+ def run_model (self , model , inputs_list , tolerate_small_mismatch = False ):
24
27
model .eval ()
25
28
26
29
onnx_io = io .BytesIO ()
@@ -36,9 +39,9 @@ def run_model(self, model, inputs_list):
36
39
test_ouputs = model (* test_inputs )
37
40
if isinstance (test_ouputs , torch .Tensor ):
38
41
test_ouputs = (test_ouputs ,)
39
- self .ort_validate (onnx_io , test_inputs , test_ouputs )
42
+ self .ort_validate (onnx_io , test_inputs , test_ouputs , tolerate_small_mismatch )
40
43
41
- def ort_validate (self , onnx_io , inputs , outputs ):
44
+ def ort_validate (self , onnx_io , inputs , outputs , tolerate_small_mismatch = False ):
42
45
43
46
inputs , _ = torch .jit ._flatten (inputs )
44
47
outputs , _ = torch .jit ._flatten (outputs )
@@ -58,7 +61,13 @@ def to_numpy(tensor):
58
61
ort_outs = ort_session .run (None , ort_inputs )
59
62
60
63
for i in range (0 , len (outputs )):
61
- torch .testing .assert_allclose (outputs [i ], ort_outs [i ], rtol = 1e-03 , atol = 1e-05 )
64
+ try :
65
+ torch .testing .assert_allclose (outputs [i ], ort_outs [i ], rtol = 1e-03 , atol = 1e-05 )
66
+ except AssertionError as error :
67
+ if tolerate_small_mismatch :
68
+ assert ("(0.00%)" in str (error )), str (error )
69
+ else :
70
+ assert False , str (error )
62
71
63
72
def test_nms (self ):
64
73
boxes = torch .rand (5 , 4 )
@@ -91,11 +100,7 @@ def test_transform_images(self):
91
100
class TransformModule (torch .nn .Module ):
92
101
def __init__ (self_module ):
93
102
super (TransformModule , self_module ).__init__ ()
94
- min_size = 800
95
- max_size = 1333
96
- image_mean = [0.485 , 0.456 , 0.406 ]
97
- image_std = [0.229 , 0.224 , 0.225 ]
98
- self_module .transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
103
+ self_module .transform = self ._init_test_generalized_rcnn_transform ()
99
104
100
105
def forward (self_module , images ):
101
106
return self_module .transform (images )[0 ].tensors
@@ -104,6 +109,66 @@ def forward(self_module, images):
104
109
input_test = [torch .rand (3 , 800 , 1280 ), torch .rand (3 , 800 , 800 )]
105
110
self .run_model (TransformModule (), [input , input_test ])
106
111
112
+ def _init_test_generalized_rcnn_transform (self ):
113
+ min_size = 800
114
+ max_size = 1333
115
+ image_mean = [0.485 , 0.456 , 0.406 ]
116
+ image_std = [0.229 , 0.224 , 0.225 ]
117
+ transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
118
+ return transform
119
+
120
+ def _init_test_rpn (self ):
121
+ anchor_sizes = ((32 ,), (64 ,), (128 ,), (256 ,), (512 ,))
122
+ aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
123
+ rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios )
124
+ out_channels = 256
125
+ rpn_head = RPNHead (out_channels , rpn_anchor_generator .num_anchors_per_location ()[0 ])
126
+ rpn_fg_iou_thresh = 0.7
127
+ rpn_bg_iou_thresh = 0.3
128
+ rpn_batch_size_per_image = 256
129
+ rpn_positive_fraction = 0.5
130
+ rpn_pre_nms_top_n = dict (training = 2000 , testing = 1000 )
131
+ rpn_post_nms_top_n = dict (training = 2000 , testing = 1000 )
132
+ rpn_nms_thresh = 0.7
133
+
134
+ rpn = RegionProposalNetwork (
135
+ rpn_anchor_generator , rpn_head ,
136
+ rpn_fg_iou_thresh , rpn_bg_iou_thresh ,
137
+ rpn_batch_size_per_image , rpn_positive_fraction ,
138
+ rpn_pre_nms_top_n , rpn_post_nms_top_n , rpn_nms_thresh )
139
+ return rpn
140
+
141
+ def test_rpn (self ):
142
+ class RPNModule (torch .nn .Module ):
143
+ def __init__ (self_module , images ):
144
+ super (RPNModule , self_module ).__init__ ()
145
+ self_module .rpn = self ._init_test_rpn ()
146
+ self_module .images = ImageList (images , [i .shape [- 2 :] for i in images ])
147
+
148
+ def forward (self_module , features ):
149
+ return self_module .rpn (self_module .images , features )
150
+
151
+ def get_features (images ):
152
+ s0 , s1 = images .shape [- 2 :]
153
+ features = [
154
+ ('0' , torch .rand (2 , 256 , s0 // 4 , s1 // 4 )),
155
+ ('1' , torch .rand (2 , 256 , s0 // 8 , s1 // 8 )),
156
+ ('2' , torch .rand (2 , 256 , s0 // 16 , s1 // 16 )),
157
+ ('3' , torch .rand (2 , 256 , s0 // 32 , s1 // 32 )),
158
+ ('4' , torch .rand (2 , 256 , s0 // 64 , s1 // 64 )),
159
+ ]
160
+ features = OrderedDict (features )
161
+ return features
162
+
163
+ images = torch .rand (2 , 3 , 600 , 600 )
164
+ features = get_features (images )
165
+ test_features = get_features (images )
166
+
167
+ model = RPNModule (images )
168
+ model .eval ()
169
+ model (features )
170
+ self .run_model (model , [(features ,), (test_features ,)], tolerate_small_mismatch = True )
171
+
107
172
def test_multi_scale_roi_align (self ):
108
173
109
174
class TransformModule (torch .nn .Module ):
0 commit comments