Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 31 additions & 37 deletions tests/ignite/metrics/test_root_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,40 @@ def test_zero_sample():
rmse.compute()


def test_compute():
@pytest.fixture(params=[0, 1, 2, 3])
def test_data(request):
return [
(torch.empty(10).uniform_(0, 10), torch.empty(10).uniform_(0, 10), 1),
(torch.empty(10, 1).uniform_(-10, 10), torch.empty(10, 1).uniform_(-10, 10), 1),
# updated batches
(torch.empty(50).uniform_(0, 10), torch.empty(50).uniform_(0, 10), 16),
(torch.empty(50, 1).uniform_(-10, 10), torch.empty(50, 1).uniform_(-10, 10), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(3))
def test_compute(n_times, test_data):

rmse = RootMeanSquaredError()

def _test(y_pred, y, batch_size):
rmse.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
rmse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
rmse.update((y_pred, y))

np_y = y.numpy().ravel()
np_y_pred = y_pred.numpy().ravel()

np_res = np.sqrt(np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0])
res = rmse.compute()

assert isinstance(res, float)
assert pytest.approx(res) == np_res

def get_test_cases():

test_cases = [
(torch.empty(10).uniform_(0, 10), torch.empty(10).uniform_(0, 10), 1),
(torch.empty(10, 1).uniform_(-10, 10), torch.empty(10, 1).uniform_(-10, 10), 1),
# updated batches
(torch.empty(50).uniform_(0, 10), torch.empty(50).uniform_(0, 10), 16),
(torch.empty(50, 1).uniform_(-10, 10), torch.empty(50, 1).uniform_(-10, 10), 16),
]

return test_cases

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
y_pred, y, batch_size = test_data
rmse.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
rmse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
rmse.update((y_pred, y))

np_y = y.numpy().ravel()
np_y_pred = y_pred.numpy().ravel()

np_res = np.sqrt(np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0])
res = rmse.compute()

assert isinstance(res, float)
assert pytest.approx(res) == np_res


def _test_distrib_integration(device, tol=1e-6):
Expand Down