Skip to content

Commit 7f7ff05

Browse files
authored
Fix flakiness on detection tests (#2966)
* Simplify the ACCEPT=True logic in assertExpected(). * Separate the expected filename estimation from assertExpected * Unflatten expected values. * Assert for duplicate scores if primary check fails. * Remove custom exceptions for algorithms and add a compact function for shrinking large outputs. * Removing unused variables. * Add warning and comments. * Re-enable all autocast unit-tests for detection and marking the tests as skipped in partial validation. * Move test skip at the end. * Changing the warning message.
1 parent 4738267 commit 7f7ff05

6 files changed

+56
-34
lines changed

test/common_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,14 @@ def remove_prefix_suffix(text, prefix, suffix):
105105
"expect",
106106
munged_id)
107107

108-
subname_output = ""
109108
if subname:
110109
expected_file += "_" + subname
111-
subname_output = " ({})".format(subname)
112110
expected_file += "_expect.pkl"
113111

114112
if not ACCEPT and not os.path.exists(expected_file):
115113
raise RuntimeError(
116-
("No expect file exists for {}{}; to accept the current output, run:\n"
117-
"python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id))
114+
("No expect file exists for {}; to accept the current output, run:\n"
115+
"python {} {} --accept").format(os.path.basename(expected_file), __main__.__file__, munged_id))
118116

119117
return expected_file
120118

@@ -139,11 +137,13 @@ def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
139137
expected_file = self._get_expected_file(subname, strip_suffix)
140138

141139
if ACCEPT:
142-
print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output))
140+
filename = {os.path.basename(expected_file)}
141+
print("Accepting updated output for {}:\n\n{}".format(filename, output))
143142
torch.save(output, expected_file)
144143
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
145144
binary_size = os.path.getsize(expected_file)
146-
self.assertTrue(binary_size <= MAX_PICKLE_SIZE)
145+
if binary_size > MAX_PICKLE_SIZE:
146+
raise RuntimeError("The output for {}, is larger than 50kb".format(filename))
147147
else:
148148
expected = torch.load(expected_file)
149149
self.assertEqual(output, expected, prec=prec)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

test/test_models.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
22
from collections import OrderedDict
33
from itertools import product
4+
import functools
5+
import operator
46
import torch
57
import torch.nn as nn
68
import numpy as np
79
from torchvision import models
810
import unittest
911
import random
10-
11-
from torchvision.models.detection._utils import overwrite_eps
12+
import warnings
1213

1314

1415
def set_rng_seed(seed):
@@ -88,14 +89,10 @@ def get_available_video_models():
8889
# trying autocast. However, they still try an autocasted forward pass, so they still ensure
8990
# autocast coverage suffices to prevent dtype errors in each model.
9091
autocast_flaky_numerics = (
91-
"fasterrcnn_resnet50_fpn",
9292
"inception_v3",
93-
"keypointrcnn_resnet50_fpn",
94-
"maskrcnn_resnet50_fpn",
9593
"resnet101",
9694
"resnet152",
9795
"wide_resnet101_2",
98-
"retinanet_resnet50_fpn",
9996
)
10097

10198

@@ -148,10 +145,9 @@ def _test_detection_model(self, name, dev):
148145
set_rng_seed(0)
149146
kwargs = {}
150147
if "retinanet" in name:
151-
kwargs["score_thresh"] = 0.013
148+
# Reduce the default threshold to ensure the returned boxes are not empty.
149+
kwargs["score_thresh"] = 0.01
152150
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
153-
if "keypointrcnn" in name or "retinanet" in name:
154-
overwrite_eps(model, 0.0)
155151
model.eval().to(device=dev)
156152
input_shape = (3, 300, 300)
157153
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
@@ -163,15 +159,22 @@ def _test_detection_model(self, name, dev):
163159
def check_out(out):
164160
self.assertEqual(len(out), 1)
165161

162+
def compact(tensor):
163+
size = tensor.size()
164+
elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
165+
if elements_per_sample > 30:
166+
return compute_mean_std(tensor)
167+
else:
168+
return subsample_tensor(tensor)
169+
166170
def subsample_tensor(tensor):
167-
num_elems = tensor.numel()
171+
num_elems = tensor.size(0)
168172
num_samples = 20
169173
if num_elems <= num_samples:
170174
return tensor
171175

172-
flat_tensor = tensor.flatten()
173176
ith_index = num_elems // num_samples
174-
return flat_tensor[ith_index - 1::ith_index]
177+
return tensor[ith_index - 1::ith_index]
175178

176179
def compute_mean_std(tensor):
177180
# can't compute mean of integral tensor
@@ -180,18 +183,32 @@ def compute_mean_std(tensor):
180183
std = torch.std(tensor)
181184
return {"mean": mean, "std": std}
182185

183-
if name == "maskrcnn_resnet50_fpn":
184-
# maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
185-
# compare results with mean and std
186-
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
187-
# mean values are small, use large prec
188-
self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
189-
else:
190-
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
191-
prec=0.01,
192-
strip_suffix="_" + dev)
193-
194-
check_out(out)
186+
output = map_nested_tensor_object(out, tensor_map_fn=compact)
187+
prec = 0.01
188+
strip_suffix = "_" + dev
189+
try:
190+
# We first try to assert the entire output if possible. This is not
191+
# only the best way to assert results but also handles the cases
192+
# where we need to create a new expected result.
193+
self.assertExpected(output, prec=prec, strip_suffix=strip_suffix)
194+
except AssertionError:
195+
# Unfortunately detection models are flaky due to the unstable sort
196+
# in NMS. If matching across all outputs fails, use the same approach
197+
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
198+
# scores.
199+
expected_file = self._get_expected_file(strip_suffix=strip_suffix)
200+
expected = torch.load(expected_file)
201+
self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec)
202+
203+
# Note: Fmassa proposed turning off NMS by adapting the threshold
204+
# and then using the Hungarian algorithm as in DETR to find the
205+
# best match between output and expected boxes and eliminate some
206+
# of the flakiness. Worth exploring.
207+
return False # Partial validation performed
208+
209+
return True # Full validation performed
210+
211+
full_validation = check_out(out)
195212

196213
scripted_model = torch.jit.script(model)
197214
scripted_model.eval()
@@ -200,9 +217,6 @@ def compute_mean_std(tensor):
200217
self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
201218
# labels currently float in script: need to investigate (though same result)
202219
self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
203-
self.assertTrue("boxes" in out[0])
204-
self.assertTrue("scores" in out[0])
205-
self.assertTrue("labels" in out[0])
206220
# don't check script because we are compiling it here:
207221
# TODO: refactor tests
208222
# self.check_script(model, name)
@@ -213,7 +227,15 @@ def compute_mean_std(tensor):
213227
out = model(model_input)
214228
# See autocast_flaky_numerics comment at top of file.
215229
if name not in autocast_flaky_numerics:
216-
check_out(out)
230+
full_validation &= check_out(out)
231+
232+
if not full_validation:
233+
msg = "The output of {} could only be partially validated. " \
234+
"This is likely due to unit-test flakiness, but you may " \
235+
"want to do additional manual checks if you made " \
236+
"significant changes to the codebase.".format(self._testMethodName)
237+
warnings.warn(msg, RuntimeWarning)
238+
raise unittest.SkipTest(msg)
217239

218240
def _test_detection_model_validation(self, name):
219241
set_rng_seed(0)

0 commit comments

Comments
 (0)