@@ -17,45 +17,40 @@ def test_zero_sample():
1717 mse .compute ()
1818
1919
20- def test_compute ():
20+ @pytest .fixture (params = [item for item in range (4 )])
21+ def test_case (request ):
22+ return [
23+ (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 1 ),
24+ (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 1 ),
25+ # updated batches
26+ (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 16 ),
27+ (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 16 ),
28+ ][request .param ]
29+
30+
31+ @pytest .mark .parametrize ("n_times" , range (5 ))
32+ def test_compute (n_times , test_case ):
2133
2234 mse = MeanSquaredError ()
2335
24- def _test (y_pred , y , batch_size ):
25- mse .reset ()
26- if batch_size > 1 :
27- n_iters = y .shape [0 ] // batch_size + 1
28- for i in range (n_iters ):
29- idx = i * batch_size
30- mse .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
31- else :
32- mse .update ((y_pred , y ))
33-
34- np_y = y .numpy ()
35- np_y_pred = y_pred .numpy ()
36-
37- np_res = np .power ((np_y - np_y_pred ), 2.0 ).sum () / np_y .shape [0 ]
38-
39- assert isinstance (mse .compute (), float )
40- assert mse .compute () == np_res
41-
42- def get_test_cases ():
43-
44- test_cases = [
45- (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 1 ),
46- (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 1 ),
47- # updated batches
48- (torch .randint (0 , 10 , size = (100 , 1 )), torch .randint (0 , 10 , size = (100 , 1 )), 16 ),
49- (torch .randint (- 20 , 20 , size = (100 , 5 )), torch .randint (- 20 , 20 , size = (100 , 5 )), 16 ),
50- ]
51-
52- return test_cases
53-
54- for _ in range (5 ):
55- # check multiple random inputs as random exact occurencies are rare
56- test_cases = get_test_cases ()
57- for y_pred , y , batch_size in test_cases :
58- _test (y_pred , y , batch_size )
36+ y_pred , y , batch_size = test_case
37+
38+ mse .reset ()
39+ if batch_size > 1 :
40+ n_iters = y .shape [0 ] // batch_size + 1
41+ for i in range (n_iters ):
42+ idx = i * batch_size
43+ mse .update ((y_pred [idx : idx + batch_size ], y [idx : idx + batch_size ]))
44+ else :
45+ mse .update ((y_pred , y ))
46+
47+ np_y = y .numpy ()
48+ np_y_pred = y_pred .numpy ()
49+
50+ np_res = np .power ((np_y - np_y_pred ), 2.0 ).sum () / np_y .shape [0 ]
51+
52+ assert isinstance (mse .compute (), float )
53+ assert mse .compute () == np_res
5954
6055
6156def _test_distrib_integration (device , tol = 1e-6 ):
0 commit comments