Commit 0a8586c

[WIP] Allow autocast for 1.6 (#2384)

* Fixes Xiao's repro
* Ports nms to use the full dispatcher
* Move HIPGuard to nms_cuda
* clang-format
* Run models in test_models.py on GPU if available
* Address Francisco's comment; also disable cuda model tests to see if CPU alone still passes
* cuda tests now pass locally, although still not comparing to saved numerics
* Add note for thing to ask Francisco
* Allow cuda and cpu tests to share a data file
* Ignore suffix if unneeded
* Skip autocast numerics checks for a few models
* Add roi_align test

Co-authored-by: Michael Carilli <[email protected]>
1 parent 4246abc commit 0a8586c
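
For context on the feature under test: torch.cuda.amp.autocast, new in PyTorch 1.6, runs eligible CUDA ops in float16 within its scope. A minimal sketch of the usage pattern this commit's tests exercise (the model choice below is illustrative, not taken from the commit):

    import torch
    import torchvision.models as models

    model = models.resnet18(num_classes=50).cuda().eval()
    x = torch.rand(1, 3, 224, 224, device="cuda")

    # Eligible ops run in float16 inside the context, so the output
    # dtype may be float16 rather than float32.
    with torch.cuda.amp.autocast():
        out = model(x)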

File tree: 12 files changed, +314 −125 lines changed

test/common_utils.py

Lines changed: 12 additions & 4 deletions
@@ -85,7 +85,7 @@ def is_iterable(obj):
 class TestCase(unittest.TestCase):
     precision = 1e-5

-    def assertExpected(self, output, subname=None, prec=None):
+    def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
         r"""
         Test that a python value matches the recorded contents of a file
         derived from the name of this test and subname. The value must be
@@ -96,16 +96,24 @@ def assertExpected(self, output, subname=None, prec=None):

         If you call this multiple times in a single function, you must
         give a unique subname each time.
+
+        strip_suffix allows different tests that expect similar numerics, e.g.
+        "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
+        test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
+        strip_suffix="_cpu", and they would both use a data file name based on
+        "test_xyz".
         """
-        def remove_prefix(text, prefix):
+        def remove_prefix_suffix(text, prefix, suffix):
             if text.startswith(prefix):
-                return text[len(prefix):]
+                text = text[len(prefix):]
+            if suffix is not None and text.endswith(suffix):
+                text = text[:len(text) - len(suffix)]
             return text
         # NB: we take __file__ from the module that defined the test
         # class, so we place the expect directory where the test script
         # lives, NOT where test/common_utils.py lives.
         module_id = self.__class__.__module__
-        munged_id = remove_prefix(self.id(), module_id + ".")
+        munged_id = remove_prefix_suffix(self.id(), module_id + ".", strip_suffix)
         test_file = os.path.realpath(sys.modules[module_id].__file__)
         expected_file = os.path.join(os.path.dirname(test_file),
                                      "expect",

test/test_models.py

Lines changed: 138 additions & 68 deletions
@@ -74,72 +74,114 @@ def get_available_video_models():
 }


+# The following models exhibit flaky numerics under autocast in _test_*_model harnesses.
+# This may be caused by the harness environment (e.g. num classes, input initialization
+# via torch.rand), and does not prove autocast is unsuitable when training with real data
+# (autocast has been used successfully with real data for some of these models).
+# TODO: investigate why autocast numerics are flaky in the harnesses.
+#
+# For the following models, _test_*_model harnesses skip numerical checks on outputs when
+# trying autocast. However, they still try an autocasted forward pass, so they still ensure
+# autocast coverage suffices to prevent dtype errors in each model.
+autocast_flaky_numerics = (
+    "fasterrcnn_resnet50_fpn",
+    "inception_v3",
+    "keypointrcnn_resnet50_fpn",
+    "maskrcnn_resnet50_fpn",
+    "resnet101",
+    "resnet152",
+    "wide_resnet101_2",
+)
+
+
 class ModelTester(TestCase):
     def checkModule(self, model, name, args):
         if name not in script_test_models:
             return
         unwrapper = script_test_models[name].get('unwrapper', None)
         return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)

-    def _test_classification_model(self, name, input_shape):
+    def _test_classification_model(self, name, input_shape, dev):
         set_rng_seed(0)
         # passing num_class equal to a number other than 1000 helps in making the test
         # more enforcing in nature
         model = models.__dict__[name](num_classes=50)
-        model.eval()
-        x = torch.rand(input_shape)
+        model.eval().to(device=dev)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.assertExpected(out, prec=0.1)
+        self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
         self.assertEqual(out.shape[-1], 50)
         self.checkModule(model, name, (x,))

-    def _test_segmentation_model(self, name):
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                # See autocast_flaky_numerics comment at top of file.
+                if name not in autocast_flaky_numerics:
+                    self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
+                self.assertEqual(out.shape[-1], 50)
+
+    def _test_segmentation_model(self, name, dev):
         # passing num_class equal to a number other than 1000 helps in making the test
         # more enforcing in nature
         model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
-        model.eval()
+        model.eval().to(device=dev)
         input_shape = (1, 3, 300, 300)
-        x = torch.rand(input_shape)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
         self.checkModule(model, name, (x,))

-    def _test_detection_model(self, name):
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
+
+    def _test_detection_model(self, name, dev):
         set_rng_seed(0)
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
-        model.eval()
+        model.eval().to(device=dev)
         input_shape = (3, 300, 300)
-        x = torch.rand(input_shape)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         model_input = [x]
         out = model(model_input)
         self.assertIs(model_input[0], x)
-        self.assertEqual(len(out), 1)

-        def subsample_tensor(tensor):
-            num_elems = tensor.numel()
-            num_samples = 20
-            if num_elems <= num_samples:
-                return tensor
-
-            flat_tensor = tensor.flatten()
-            ith_index = num_elems // num_samples
-            return flat_tensor[ith_index - 1::ith_index]
-
-        def compute_mean_std(tensor):
-            # can't compute mean of integral tensor
-            tensor = tensor.to(torch.double)
-            mean = torch.mean(tensor)
-            std = torch.std(tensor)
-            return {"mean": mean, "std": std}
-
-        # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
-        # compare results with mean and std
-        if name == "maskrcnn_resnet50_fpn":
-            test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
-            # mean values are small, use large prec
-            self.assertExpected(test_value, prec=.01)
-        else:
-            self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), prec=0.01)
+        def check_out(out):
+            self.assertEqual(len(out), 1)
+
+            def subsample_tensor(tensor):
+                num_elems = tensor.numel()
+                num_samples = 20
+                if num_elems <= num_samples:
+                    return tensor
+
+                flat_tensor = tensor.flatten()
+                ith_index = num_elems // num_samples
+                return flat_tensor[ith_index - 1::ith_index]
+
+            def compute_mean_std(tensor):
+                # can't compute mean of integral tensor
+                tensor = tensor.to(torch.double)
+                mean = torch.mean(tensor)
+                std = torch.std(tensor)
+                return {"mean": mean, "std": std}
+
+            # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
+            # compare results with mean and std
+            if name == "maskrcnn_resnet50_fpn":
+                test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
+                # mean values are small, use large prec
+                self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
+            else:
+                self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
+                                    prec=0.01,
+                                    strip_suffix="_" + dev)
+
+        check_out(out)

         scripted_model = torch.jit.script(model)
         scripted_model.eval()
@@ -156,6 +198,13 @@ def compute_mean_std(tensor):
         # self.check_script(model, name)
         self.checkModule(model, name, ([x],))

+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(model_input)
+                # See autocast_flaky_numerics comment at top of file.
+                if name not in autocast_flaky_numerics:
+                    check_out(out)
+
     def _test_detection_model_validation(self, name):
         set_rng_seed(0)
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
@@ -179,18 +228,24 @@ def _test_detection_model_validation(self, name):
         targets = [{'boxes': boxes}]
         self.assertRaises(ValueError, model, x, targets=targets)

-    def _test_video_model(self, name):
+    def _test_video_model(self, name, dev):
         # the default input shape is
         # bs * num_channels * clip_len * h * w
         input_shape = (1, 3, 4, 112, 112)
         # test both basicblock and Bottleneck
         model = models.video.__dict__[name](num_classes=50)
-        model.eval()
-        x = torch.rand(input_shape)
+        model.eval().to(device=dev)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.checkModule(model, name, (x,))
         self.assertEqual(out.shape[-1], 50)

+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                self.assertEqual(out.shape[-1], 50)
+
     def _make_sliced_model(self, model, stop_layer):
         layers = OrderedDict()
         for name, layer in model.named_children():
@@ -272,6 +327,12 @@ def test_googlenet_eval(self):

     @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
     def test_fasterrcnn_switch_devices(self):
+        def checkOut(out):
+            self.assertEqual(len(out), 1)
+            self.assertTrue("boxes" in out[0])
+            self.assertTrue("scores" in out[0])
+            self.assertTrue("labels" in out[0])
+
         model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
         model.cuda()
         model.eval()
@@ -280,17 +341,20 @@ def test_fasterrcnn_switch_devices(self):
         model_input = [x]
         out = model(model_input)
         self.assertIs(model_input[0], x)
-        self.assertEqual(len(out), 1)
-        self.assertTrue("boxes" in out[0])
-        self.assertTrue("scores" in out[0])
-        self.assertTrue("labels" in out[0])
+
+        checkOut(out)
+
+        with torch.cuda.amp.autocast():
+            out = model(model_input)
+
+        checkOut(out)
+
         # now switch to cpu and make sure it works
         model.cpu()
         x = x.cpu()
         out_cpu = model([x])
-        self.assertTrue("boxes" in out_cpu[0])
-        self.assertTrue("scores" in out_cpu[0])
-        self.assertTrue("labels" in out_cpu[0])
+
+        checkOut(out_cpu)

     def test_generalizedrcnn_transform_repr(self):

@@ -312,34 +376,40 @@ def test_generalizedrcnn_transform_repr(self):
         self.assertEqual(t.__repr__(), expected_string)


+_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 for model_name in get_available_classification_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        input_shape = (1, 3, 224, 224)
-        if model_name in ['inception_v3']:
-            input_shape = (1, 3, 299, 299)
-        self._test_classification_model(model_name, input_shape)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            input_shape = (1, 3, 224, 224)
+            if model_name in ['inception_v3']:
+                input_shape = (1, 3, 299, 299)
+            self._test_classification_model(model_name, input_shape, dev)

-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)


 for model_name in get_available_segmentation_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        self._test_segmentation_model(model_name)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_segmentation_model(model_name, dev)

-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)


 for model_name in get_available_detection_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        self._test_detection_model(model_name)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_detection_model(model_name, dev)

-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

     def do_validation_test(self, model_name=model_name):
         self._test_detection_model_validation(model_name)
@@ -348,11 +418,11 @@ def do_validation_test(self, model_name=model_name):


 for model_name in get_available_video_models():
+    for dev in _devs:
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_video_model(model_name, dev)

-    def do_test(self, model_name=model_name):
-        self._test_video_model(model_name)
-
-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

 if __name__ == '__main__':
     unittest.main()
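
The default-argument idiom in the loops above (model_name=model_name, dev=dev) is what pins each generated test to its own loop values: a Python closure captures the variable itself, so without the defaults every test would see the final model name and device. A self-contained sketch of the pattern (class and names are illustrative):

    import unittest

    class Tester(unittest.TestCase):
        pass

    for name in ("alpha", "beta"):
        for dev in ("cpu",):
            # Defaults are evaluated at definition time, freezing the
            # current loop values into each generated test method.
            def do_test(self, name=name, dev=dev):
                self.assertTrue(name in ("alpha", "beta"))
            setattr(Tester, "test_" + name + "_" + dev, do_test)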

test/test_ops.py

Lines changed: 16 additions & 4 deletions
@@ -52,25 +52,30 @@ def _test_backward(self, device, contiguous):


 class RoIOpTester(OpTester):
-    def _test_forward(self, device, contiguous):
+    def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
+        x_dtype = self.dtype if x_dtype is None else x_dtype
+        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS opeartions.
         n_channels = 2 * (pool_size ** 2)
-        x = torch.rand(2, n_channels, 10, 10, dtype=self.dtype, device=device)
+        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
                              [0, 0, 5, 4, 9],
                              [0, 5, 5, 9, 9],
                              [1, 0, 0, 9, 9]],
-                            dtype=self.dtype, device=device)
+                            dtype=rois_dtype, device=device)

         pool_h, pool_w = pool_size, pool_size
         y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
+        # the following should be true whether we're running an autocast test or not.
+        self.assertTrue(y.dtype == x.dtype)
         gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
                                 sampling_ratio=-1, device=device, dtype=self.dtype)

-        self.assertTrue(torch.allclose(gt_y, y))
+        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
+        self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))

     def _test_backward(self, device, contiguous):
         pool_size = 2
@@ -290,6 +295,13 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
     def _test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_align_autocast(self):
+        for x_dtype in (torch.float, torch.half):
+            for rois_dtype in (torch.float, torch.half):
+                with torch.cuda.amp.autocast():
+                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+

 class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
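
The looser 1e-3 tolerance matters because under autocast roi_align sees float16 inputs, whose reduced precision shows up in the output. A minimal sketch of the mixed-dtype call the new test covers (assumes a CUDA device is present; tensor shapes are illustrative, not from the commit):

    import torch
    from torchvision import ops

    x = torch.rand(1, 2, 10, 10, device="cuda")  # float32 feature map
    rois = torch.tensor([[0, 0, 0, 9, 9]], dtype=torch.half, device="cuda")

    with torch.cuda.amp.autocast():
        # With autocast coverage for roi_align in place, mixed
        # float32/float16 inputs run without dtype-mismatch errors.
        y = ops.roi_align(x, rois, output_size=(5, 5))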
