
Commit 1d6145d

lara-hdr authored and fmassa committed
Support Exporting RPN to ONNX (#1329)
* Support Exporting RPN to ONNX
* address PR comments
* fix cat
* add flatten
* replace cat by stack
* update test to run only on rpn module
* use tolerate_small_mismatch
1 parent f16b672 commit 1d6145d

4 files changed: +115 -26 lines changed

test/test_onnx.py

Lines changed: 74 additions & 9 deletions
@@ -1,7 +1,10 @@
 import io
 import torch
 from torchvision import ops
+from torchvision.models.detection.image_list import ImageList
 from torchvision.models.detection.transform import GeneralizedRCNNTransform
+from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
+from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

 from collections import OrderedDict

@@ -20,7 +23,7 @@ class ONNXExporterTester(unittest.TestCase):
     def setUpClass(cls):
         torch.manual_seed(123)

-    def run_model(self, model, inputs_list):
+    def run_model(self, model, inputs_list, tolerate_small_mismatch=False):
         model.eval()

         onnx_io = io.BytesIO()
@@ -36,9 +39,9 @@ def run_model(self, model, inputs_list):
         test_ouputs = model(*test_inputs)
         if isinstance(test_ouputs, torch.Tensor):
             test_ouputs = (test_ouputs,)
-        self.ort_validate(onnx_io, test_inputs, test_ouputs)
+        self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)

-    def ort_validate(self, onnx_io, inputs, outputs):
+    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):

         inputs, _ = torch.jit._flatten(inputs)
         outputs, _ = torch.jit._flatten(outputs)
@@ -58,7 +61,13 @@ def to_numpy(tensor):
         ort_outs = ort_session.run(None, ort_inputs)

         for i in range(0, len(outputs)):
-            torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
+            try:
+                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
+            except AssertionError as error:
+                if tolerate_small_mismatch:
+                    assert ("(0.00%)" in str(error)), str(error)
+                else:
+                    assert False, str(error)

     def test_nms(self):
         boxes = torch.rand(5, 4)
@@ -91,11 +100,7 @@ def test_transform_images(self):
         class TransformModule(torch.nn.Module):
             def __init__(self_module):
                 super(TransformModule, self_module).__init__()
-                min_size = 800
-                max_size = 1333
-                image_mean = [0.485, 0.456, 0.406]
-                image_std = [0.229, 0.224, 0.225]
-                self_module.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+                self_module.transform = self._init_test_generalized_rcnn_transform()

             def forward(self_module, images):
                 return self_module.transform(images)[0].tensors
@@ -104,6 +109,66 @@ def forward(self_module, images):
         input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
         self.run_model(TransformModule(), [input, input_test])

+    def _init_test_generalized_rcnn_transform(self):
+        min_size = 800
+        max_size = 1333
+        image_mean = [0.485, 0.456, 0.406]
+        image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        return transform
+
+    def _init_test_rpn(self):
+        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+        out_channels = 256
+        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+        rpn_fg_iou_thresh = 0.7
+        rpn_bg_iou_thresh = 0.3
+        rpn_batch_size_per_image = 256
+        rpn_positive_fraction = 0.5
+        rpn_pre_nms_top_n = dict(training=2000, testing=1000)
+        rpn_post_nms_top_n = dict(training=2000, testing=1000)
+        rpn_nms_thresh = 0.7
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator, rpn_head,
+            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
+            rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+        return rpn
+
+    def test_rpn(self):
+        class RPNModule(torch.nn.Module):
+            def __init__(self_module, images):
+                super(RPNModule, self_module).__init__()
+                self_module.rpn = self._init_test_rpn()
+                self_module.images = ImageList(images, [i.shape[-2:] for i in images])
+
+            def forward(self_module, features):
+                return self_module.rpn(self_module.images, features)
+
+        def get_features(images):
+            s0, s1 = images.shape[-2:]
+            features = [
+                ('0', torch.rand(2, 256, s0 // 4, s1 // 4)),
+                ('1', torch.rand(2, 256, s0 // 8, s1 // 8)),
+                ('2', torch.rand(2, 256, s0 // 16, s1 // 16)),
+                ('3', torch.rand(2, 256, s0 // 32, s1 // 32)),
+                ('4', torch.rand(2, 256, s0 // 64, s1 // 64)),
+            ]
+            features = OrderedDict(features)
+            return features
+
+        images = torch.rand(2, 3, 600, 600)
+        features = get_features(images)
+        test_features = get_features(images)
+
+        model = RPNModule(images)
+        model.eval()
+        model(features)
+        self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
+
     def test_multi_scale_roi_align(self):

         class TransformModule(torch.nn.Module):
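For reference, the round trip that run_model and ort_validate perform looks roughly like the following eager-mode sketch: export to an in-memory buffer, run the graph under ONNX Runtime, and compare against PyTorch outputs. The opset version and tolerances here are illustrative assumptions, not values taken from this diff.

import io

import torch
import onnxruntime

def export_and_validate(model, inputs, rtol=1e-3, atol=1e-5):
    # Export to an in-memory buffer instead of a file on disk.
    model.eval()
    onnx_io = io.BytesIO()
    torch.onnx.export(model, inputs, onnx_io,
                      do_constant_folding=True, opset_version=10)

    # Feed the same inputs to ONNX Runtime.
    ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
    ort_inputs = {ort_in.name: t.cpu().numpy()
                  for ort_in, t in zip(ort_session.get_inputs(), inputs)}
    ort_outs = ort_session.run(None, ort_inputs)

    # Compare against eager PyTorch outputs.
    outputs = model(*inputs)
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    for out, ort_out in zip(outputs, ort_outs):
        torch.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol)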

torchvision/models/detection/_utils.py

Lines changed: 7 additions & 11 deletions
@@ -3,6 +3,7 @@
 import math

 import torch
+import torchvision


 class BalancedPositiveNegativeSampler(object):
@@ -162,7 +163,7 @@ def decode(self, rel_codes, boxes):
         if isinstance(rel_codes, (list, tuple)):
             rel_codes = torch.cat(rel_codes, dim=0)
         assert isinstance(rel_codes, torch.Tensor)
-        boxes_per_image = [len(b) for b in boxes]
+        boxes_per_image = [b.size(0) for b in boxes]
         concat_boxes = torch.cat(boxes, dim=0)
         pred_boxes = self.decode_single(
             rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
@@ -201,16 +202,11 @@ def decode_single(self, rel_codes, boxes):
         pred_w = torch.exp(dw) * widths[:, None]
         pred_h = torch.exp(dh) * heights[:, None]

-        pred_boxes = torch.zeros_like(rel_codes)
-        # x1
-        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
-        # y1
-        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
-        # x2
-        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
-        # y2
-        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h
-
+        pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
+        pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
+        pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
+        pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
         return pred_boxes

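The "fix cat", "add flatten", and "replace cat by stack" items in the commit message refer to the decode_single hunk above: strided in-place writes such as pred_boxes[:, 0::4] = ... do not export cleanly, while stacking along a new trailing dimension and flattening it produces the same interleaved (x1, y1, x2, y2) layout out-of-place. A small self-contained check of that equivalence (shapes here are illustrative):

import torch

N, K = 8, 3  # boxes, and columns per coordinate (e.g. classes)
x1, y1 = torch.rand(N, K), torch.rand(N, K)
x2, y2 = torch.rand(N, K), torch.rand(N, K)

# Old layout: strided in-place assignment into a preallocated tensor.
old = torch.zeros(N, 4 * K)
old[:, 0::4] = x1
old[:, 1::4] = y1
old[:, 2::4] = x2
old[:, 3::4] = y2

# New layout: stack on a new trailing dim, then flatten it away.
new = torch.stack((x1, y1, x2, y2), dim=2).flatten(1)

assert torch.equal(old, new)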

torchvision/models/detection/rpn.py

Lines changed: 33 additions & 4 deletions
@@ -3,11 +3,25 @@
 from torch.nn import functional as F
 from torch import nn

+import torchvision
 from torchvision.ops import boxes as box_ops

 from . import _utils as det_utils


+@torch.jit.unused
+def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
+    from torch.onnx import operators
+    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
+    # TODO : remove cast to IntTensor/num_anchors.dtype when
+    # ONNX Runtime version is updated with ReduceMin int64 support
+    pre_nms_top_n = torch.min(torch.cat(
+        (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
+         num_anchors), 0).to(torch.int32)).to(num_anchors.dtype)
+
+    return num_anchors, pre_nms_top_n
+
+
 class AnchorGenerator(nn.Module):
     """
     Module that generates anchors for a set of feature maps and
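In eager terms the new helper computes min(orig_pre_nms_top_n, num_anchors), but keeps both operands as tensors (via operators.shape_as_tensor) so the clamp stays in the traced graph rather than being baked in as a Python constant. A rough illustration, not part of the commit:

import torch
from torch.onnx import operators

ob = torch.rand(2, 5000)  # (batch, anchors) objectness for one level
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)  # tensor([5000])
orig_pre_nms_top_n = 1000
# The int32 round trip mirrors the ReduceMin workaround in the helper above.
pre_nms_top_n = torch.min(torch.cat(
    (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
     num_anchors), 0).to(torch.int32)).to(num_anchors.dtype)
assert int(pre_nms_top_n) == 1000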
@@ -85,14 +99,24 @@ def grid_anchors(self, grid_sizes, strides):
         ):
             grid_height, grid_width = size
             stride_height, stride_width = stride
+            if torchvision._is_tracing():
+                # required in ONNX export for mult operation with float32
+                stride_width = torch.tensor(stride_width, dtype=torch.float32)
+                stride_height = torch.tensor(stride_height, dtype=torch.float32)
             device = base_anchors.device
             shifts_x = torch.arange(
                 0, grid_width, dtype=torch.float32, device=device
             ) * stride_width
             shifts_y = torch.arange(
                 0, grid_height, dtype=torch.float32, device=device
             ) * stride_height
-            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+            # TODO: remove tracing pass when exporting torch.meshgrid()
+            # is suported in ONNX
+            if torchvision._is_tracing():
+                shift_y = shifts_y.view(-1, 1).expand(grid_height, grid_width)
+                shift_x = shifts_x.view(1, -1).expand(grid_height, grid_width)
+            else:
+                shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
             shift_x = shift_x.reshape(-1)
             shift_y = shift_y.reshape(-1)
             shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
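The view/expand pair above is a drop-in replacement for torch.meshgrid, which had no ONNX symbolic at the time; both produce identical (grid_height, grid_width) shift grids. A quick equivalence check (newer PyTorch may warn unless indexing='ij' is passed explicitly):

import torch

grid_height, grid_width = 3, 4
shifts_y = torch.arange(0, grid_height, dtype=torch.float32)
shifts_x = torch.arange(0, grid_width, dtype=torch.float32)

# Traced path: broadcast via view + expand.
shift_y = shifts_y.view(-1, 1).expand(grid_height, grid_width)
shift_x = shifts_x.view(1, -1).expand(grid_height, grid_width)

# Eager path: meshgrid with 'ij' indexing, the default at the time.
mesh_y, mesh_x = torch.meshgrid(shifts_y, shifts_x)

assert torch.equal(shift_y, mesh_y) and torch.equal(shift_x, mesh_x)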
@@ -114,7 +138,9 @@ def cached_grid_anchors(self, grid_sizes, strides):
     def forward(self, image_list, feature_maps):
         grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
-        strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes)
+        strides = tuple((float(image_size[0]) / float(g[0]),
+                         float(image_size[1]) / float(g[1]))
+                        for g in grid_sizes)
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
@@ -300,8 +326,11 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
-            num_anchors = ob.shape[1]
-            pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
+            if torchvision._is_tracing():
+                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n)
+            else:
+                num_anchors = ob.shape[1]
+                pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
             _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
             r.append(top_n_idx + offset)
             offset += num_anchors
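The torchvision._is_tracing() guard used in these hunks is, in the torchvision source of this era, a thin wrapper over the JIT tracing state, along the lines of:

import torch

def _is_tracing():
    # Truthy only while torch.jit.trace / torch.onnx.export is recording ops.
    return torch._C._get_tracing_state()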

torchvision/ops/misc.py

Lines changed: 1 addition & 2 deletions
@@ -132,7 +132,7 @@ def _output_size(dim):


 # This is not in nn
-class FrozenBatchNorm2d(torch.jit.ScriptModule):
+class FrozenBatchNorm2d(torch.nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters
     are fixed
@@ -145,7 +145,6 @@ def __init__(self, n):
         self.register_buffer("running_mean", torch.zeros(n))
         self.register_buffer("running_var", torch.ones(n))

-    @torch.jit.script_method
     def forward(self, x):
         # move reshapes to the beginning
         # to make it fuser-friendly
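Dropping ScriptModule works here because the forward is plain tensor arithmetic over registered buffers, which the ONNX exporter can trace without pre-scripting. A minimal sketch of such a frozen batch norm (illustrative, not the torchvision source):

import torch

class FrozenBN(torch.nn.Module):
    # Batch statistics and affine parameters live in buffers and are never updated.
    def __init__(self, num_features):
        super(FrozenBN, self).__init__()
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)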
