diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 5ee1c738a77..df180600225 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -115,9 +115,14 @@ def test_random(func, method, device, channels, fn_kwargs, match_kwargs): _test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs) +@pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("channels", [1, 3]) class TestColorJitter: + @pytest.fixture(autouse=True) + def set_random_seed(self, seed): + torch.random.manual_seed(seed) + @pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]) def test_color_jitter_brightness(self, brightness, device, channels): tol = 1.0 + 1e-10