diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 4eb347b2d11f..b195222a0ae9 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -358,22 +358,22 @@ def _test(metric_device): def _test_distrib_integration_multiclass(device): rank = idist.get_rank() - torch.manual_seed(12) def _test(n_epochs, metric_device): metric_device = torch.device(metric_device) n_iters = 80 - s = 16 + batch_size = 16 n_classes = 10 - offset = n_iters * s - y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device) - y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device) + torch.manual_seed(12 + rank) + + y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(n_iters * batch_size, n_classes).to(device) def update(engine, i): return ( - y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :], - y_true[i * s + rank * offset : (i + 1) * s + rank * offset], + y_preds[i * batch_size : (i + 1) * batch_size, :], + y_true[i * batch_size : (i + 1) * batch_size], ) engine = Engine(update) @@ -384,6 +384,9 @@ def update(engine, i): data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) + y_true = idist.all_gather(y_true) + y_preds = idist.all_gather(y_preds) + assert ( acc._num_correct.device == metric_device ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"