@@ -88,24 +88,7 @@ def is_iterable(obj):
88
88
class TestCase (unittest .TestCase ):
89
89
precision = 1e-5
90
90
91
- def assertExpected (self , output , subname = None , prec = None , strip_suffix = None ):
92
- r"""
93
- Test that a python value matches the recorded contents of a file
94
- derived from the name of this test and subname. The value must be
95
- pickable with `torch.save`. This file
96
- is placed in the 'expect' directory in the same directory
97
- as the test script. You can automatically update the recorded test
98
- output using --accept.
99
-
100
- If you call this multiple times in a single function, you must
101
- give a unique subname each time.
102
-
103
- strip_suffix allows different tests that expect similar numerics, e.g.
104
- "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
105
- test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
106
- strip_suffix="_cpu", and they would both use a data file name based on
107
- "test_xyz".
108
- """
91
+ def _get_expected_file (self , subname = None , strip_suffix = None ):
109
92
def remove_prefix_suffix (text , prefix , suffix ):
110
93
if text .startswith (prefix ):
111
94
text = text [len (prefix ):]
@@ -128,33 +111,41 @@ def remove_prefix_suffix(text, prefix, suffix):
128
111
subname_output = " ({})" .format (subname )
129
112
expected_file += "_expect.pkl"
130
113
131
- def accept_output (update_type ):
132
- print ("Accepting {} for {}{}:\n \n {}" .format (update_type , munged_id , subname_output , output ))
114
+ if not ACCEPT and not os .path .exists (expected_file ):
115
+ raise RuntimeError (
116
+ ("No expect file exists for {}{}; to accept the current output, run:\n "
117
+ "python {} {} --accept" ).format (munged_id , subname_output , __main__ .__file__ , munged_id ))
118
+
119
+ return expected_file
120
+
121
+ def assertExpected (self , output , subname = None , prec = None , strip_suffix = None ):
122
+ r"""
123
+ Test that a python value matches the recorded contents of a file
124
+ derived from the name of this test and subname. The value must be
125
+ pickable with `torch.save`. This file
126
+ is placed in the 'expect' directory in the same directory
127
+ as the test script. You can automatically update the recorded test
128
+ output using --accept.
129
+
130
+ If you call this multiple times in a single function, you must
131
+ give a unique subname each time.
132
+
133
+ strip_suffix allows different tests that expect similar numerics, e.g.
134
+ "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
135
+ test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
136
+ strip_suffix="_cpu", and they would both use a data file name based on
137
+ "test_xyz".
138
+ """
139
+ expected_file = self ._get_expected_file (subname , strip_suffix )
140
+
141
+ if ACCEPT :
142
+ print ("Accepting updated output for {}:\n \n {}" .format (os .path .basename (expected_file ), output ))
133
143
torch .save (output , expected_file )
134
144
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
135
145
binary_size = os .path .getsize (expected_file )
136
146
self .assertTrue (binary_size <= MAX_PICKLE_SIZE )
137
-
138
- try :
139
- expected = torch .load (expected_file )
140
- except IOError as e :
141
- if e .errno != errno .ENOENT :
142
- raise
143
- elif ACCEPT :
144
- accept_output ("output" )
145
- return
146
- else :
147
- raise RuntimeError (
148
- ("I got this output for {}{}:\n \n {}\n \n "
149
- "No expect file exists; to accept the current output, run:\n "
150
- "python {} {} --accept" ).format (munged_id , subname_output , output , __main__ .__file__ , munged_id ))
151
-
152
- if ACCEPT :
153
- try :
154
- self .assertEqual (output , expected , prec = prec )
155
- except Exception :
156
- accept_output ("updated output" )
157
147
else :
148
+ expected = torch .load (expected_file )
158
149
self .assertEqual (output , expected , prec = prec )
159
150
160
151
def assertEqual (self , x , y , prec = None , message = '' , allow_inf = False ):
0 commit comments