Skip to content

Commit 1036f1c

Browse files
Parametrized tests for test_mean_squared_error.py (#2630)
1 parent 81705c4 commit 1036f1c

File tree

1 file changed

+31
-36
lines changed

1 file changed

+31
-36
lines changed

tests/ignite/metrics/test_mean_squared_error.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,45 +17,40 @@ def test_zero_sample():
1717
mse.compute()
1818

1919

20-
def test_compute():
20+
@pytest.fixture(params=[item for item in range(4)])
21+
def test_case(request):
22+
return [
23+
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
24+
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1),
25+
# updated batches
26+
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
27+
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
28+
][request.param]
29+
30+
31+
@pytest.mark.parametrize("n_times", range(5))
32+
def test_compute(n_times, test_case):
2133

2234
mse = MeanSquaredError()
2335

24-
def _test(y_pred, y, batch_size):
25-
mse.reset()
26-
if batch_size > 1:
27-
n_iters = y.shape[0] // batch_size + 1
28-
for i in range(n_iters):
29-
idx = i * batch_size
30-
mse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
31-
else:
32-
mse.update((y_pred, y))
33-
34-
np_y = y.numpy()
35-
np_y_pred = y_pred.numpy()
36-
37-
np_res = np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0]
38-
39-
assert isinstance(mse.compute(), float)
40-
assert mse.compute() == np_res
41-
42-
def get_test_cases():
43-
44-
test_cases = [
45-
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 1),
46-
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 1),
47-
# updated batches
48-
(torch.randint(0, 10, size=(100, 1)), torch.randint(0, 10, size=(100, 1)), 16),
49-
(torch.randint(-20, 20, size=(100, 5)), torch.randint(-20, 20, size=(100, 5)), 16),
50-
]
51-
52-
return test_cases
53-
54-
for _ in range(5):
55-
# check multiple random inputs as random exact occurencies are rare
56-
test_cases = get_test_cases()
57-
for y_pred, y, batch_size in test_cases:
58-
_test(y_pred, y, batch_size)
36+
y_pred, y, batch_size = test_case
37+
38+
mse.reset()
39+
if batch_size > 1:
40+
n_iters = y.shape[0] // batch_size + 1
41+
for i in range(n_iters):
42+
idx = i * batch_size
43+
mse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
44+
else:
45+
mse.update((y_pred, y))
46+
47+
np_y = y.numpy()
48+
np_y_pred = y_pred.numpy()
49+
50+
np_res = np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0]
51+
52+
assert isinstance(mse.compute(), float)
53+
assert mse.compute() == np_res
5954

6055

6156
def _test_distrib_integration(device, tol=1e-6):

0 commit comments

Comments
 (0)