Skip to content

Commit 4738267

Browse files
authored
Simplify the assertExpected method (#2965)
* Simplify the ACCEPT=True logic in assertExpected(). * Separate the expected filename estimation from assertExpected
1 parent 32e5700 commit 4738267

File tree

1 file changed

+31
-40
lines changed

1 file changed

+31
-40
lines changed

test/common_utils.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,7 @@ def is_iterable(obj):
8888
class TestCase(unittest.TestCase):
8989
precision = 1e-5
9090

91-
def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
92-
r"""
93-
Test that a python value matches the recorded contents of a file
94-
derived from the name of this test and subname. The value must be
95-
pickable with `torch.save`. This file
96-
is placed in the 'expect' directory in the same directory
97-
as the test script. You can automatically update the recorded test
98-
output using --accept.
99-
100-
If you call this multiple times in a single function, you must
101-
give a unique subname each time.
102-
103-
strip_suffix allows different tests that expect similar numerics, e.g.
104-
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
105-
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
106-
strip_suffix="_cpu", and they would both use a data file name based on
107-
"test_xyz".
108-
"""
91+
def _get_expected_file(self, subname=None, strip_suffix=None):
10992
def remove_prefix_suffix(text, prefix, suffix):
11093
if text.startswith(prefix):
11194
text = text[len(prefix):]
@@ -128,33 +111,41 @@ def remove_prefix_suffix(text, prefix, suffix):
128111
subname_output = " ({})".format(subname)
129112
expected_file += "_expect.pkl"
130113

131-
def accept_output(update_type):
132-
print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output))
114+
if not ACCEPT and not os.path.exists(expected_file):
115+
raise RuntimeError(
116+
("No expect file exists for {}{}; to accept the current output, run:\n"
117+
"python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id))
118+
119+
return expected_file
120+
121+
def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
122+
r"""
123+
Test that a python value matches the recorded contents of a file
124+
derived from the name of this test and subname. The value must be
125+
pickable with `torch.save`. This file
126+
is placed in the 'expect' directory in the same directory
127+
as the test script. You can automatically update the recorded test
128+
output using --accept.
129+
130+
If you call this multiple times in a single function, you must
131+
give a unique subname each time.
132+
133+
strip_suffix allows different tests that expect similar numerics, e.g.
134+
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
135+
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
136+
strip_suffix="_cpu", and they would both use a data file name based on
137+
"test_xyz".
138+
"""
139+
expected_file = self._get_expected_file(subname, strip_suffix)
140+
141+
if ACCEPT:
142+
print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output))
133143
torch.save(output, expected_file)
134144
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
135145
binary_size = os.path.getsize(expected_file)
136146
self.assertTrue(binary_size <= MAX_PICKLE_SIZE)
137-
138-
try:
139-
expected = torch.load(expected_file)
140-
except IOError as e:
141-
if e.errno != errno.ENOENT:
142-
raise
143-
elif ACCEPT:
144-
accept_output("output")
145-
return
146-
else:
147-
raise RuntimeError(
148-
("I got this output for {}{}:\n\n{}\n\n"
149-
"No expect file exists; to accept the current output, run:\n"
150-
"python {} {} --accept").format(munged_id, subname_output, output, __main__.__file__, munged_id))
151-
152-
if ACCEPT:
153-
try:
154-
self.assertEqual(output, expected, prec=prec)
155-
except Exception:
156-
accept_output("updated output")
157147
else:
148+
expected = torch.load(expected_file)
158149
self.assertEqual(output, expected, prec=prec)
159150

160151
def assertEqual(self, x, y, prec=None, message='', allow_inf=False):

0 commit comments

Comments
 (0)