From 7c203239b8f74fa5eb28be192f379da8acfaa98a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 28 Sep 2021 19:03:45 +0100 Subject: [PATCH 1/6] Add autouse fixture to save and reset RNG in tests --- test/conftest.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 3cffeeac88f..eb995438269 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -80,3 +80,15 @@ def pytest_sessionfinish(session, exitstatus): # To avoid this, we transform this 5 into a 0 to make testpilot happy. if exitstatus == 5: session.exitstatus = 0 + + +@pytest.fixture(autouse=True) +def prevent_leaking_rng(): + # prevent each test from leaking the rng to all other test when they call torch.manual_seed() + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + yield + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(rng_state) From 910d213aa41b8b4032a4e1b16dd27e925a3cf6ab Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 29 Sep 2021 09:24:10 +0100 Subject: [PATCH 2/6] Add other RNG generators --- test/conftest.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index eb995438269..a84b9f8dd52 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,7 @@ from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG import torch +import numpy as np +import random import pytest @@ -84,11 +86,22 @@ def pytest_sessionfinish(session, exitstatus): @pytest.fixture(autouse=True) def prevent_leaking_rng(): - # prevent each test from leaking the rng to all other test when they call torch.manual_seed() - rng_state = torch.get_rng_state() + # Prevent each test from leaking the rng to all other test when they call + # torch.manual_seed() or random.seed() or np.random.seed(). + # Note: the numpy rngs should never leak anyway, as we never use + # np.random.seed() and instead rely on np.random.RandomState instances (see + # issue #4247). We still do it for extra precaution. + + torch_rng_state = torch.get_rng_state() + builtin_rng_state = random.getstate() + nunmpy_rng_state = np.random.get_state() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() + yield + + torch.set_rng_state(torch_rng_state) + random.setstate(builtin_rng_state) + np.random.set_state(nunmpy_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) - torch.set_rng_state(rng_state) From 97d7c4aab84097c9176f8599d96a0793361c15a9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 29 Sep 2021 09:26:52 +0100 Subject: [PATCH 3/6] delete freeze_rng_state --- test/common_utils.py | 11 ----------- test/test_models.py | 18 +++++++----------- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a9f6703fcd0..a9ad27fecd0 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -82,17 +82,6 @@ def is_iterable(obj): return False -@contextlib.contextmanager -def freeze_rng_state(): - rng_state = torch.get_rng_state() - if torch.cuda.is_available(): - cuda_rng_state = torch.cuda.get_rng_state() - yield - if torch.cuda.is_available(): - torch.cuda.set_rng_state(cuda_rng_state) - torch.set_rng_state(rng_state) - - def cycle_over(objs): for idx, obj1 in enumerate(objs): for obj2 in objs[:idx] + objs[idx + 1:]: diff --git a/test/test_models.py b/test/test_models.py index 9e376bedce5..3d5a198b6a3 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,7 +1,7 @@ import os import io import sys -from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda +from common_utils import map_nested_tensor_object, set_rng_seed, cpu_and_gpu, needs_cuda from _utils_internal import get_relative_path from collections import OrderedDict import functools @@ -101,10 +101,8 @@ def get_export_import_copy(m): return imported m_import = get_export_import_copy(m) - with freeze_rng_state(): - results = m(*args) - with freeze_rng_state(): - results_from_imported = m_import(*args) + results = m(*args) + results_from_imported = m_import(*args) tol = 3e-4 try: torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) @@ -129,13 +127,11 @@ def get_export_import_copy(m): sm = torch.jit.script(nn_module) - with freeze_rng_state(): - eager_out = nn_module(*args) + eager_out = nn_module(*args) - with freeze_rng_state(): - script_out = sm(*args) - if unwrapper: - script_out = unwrapper(script_out) + script_out = sm(*args) + if unwrapper: + script_out = unwrapper(script_out) torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4) assert_export_import_module(sm, args) From 7566dc4cdaca224cc3d12f494b7d481a485e47bd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 29 Sep 2021 12:57:57 +0100 Subject: [PATCH 4/6] Hopefully fix GaussianBlur test --- test/test_transforms_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index f8316c9da05..0b8045845da 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -67,6 +67,7 @@ def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_matc f = method(**meth_kwargs) scripted_fn = torch.jit.script(f) + torch.manual_seed(12) tensor, pil_img = _create_data(26, 34, channels, device=device) # set seed to reproduce the same transformation for tensor and PIL image torch.manual_seed(12) From 92cc80ac0bc5998fbe6ee2cc273037a22c28d661 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 29 Sep 2021 12:58:38 +0100 Subject: [PATCH 5/6] Alternative fix, probably better --- test/test_transforms_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0b8045845da..8aa398bd006 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -67,7 +67,6 @@ def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_matc f = method(**meth_kwargs) scripted_fn = torch.jit.script(f) - torch.manual_seed(12) tensor, pil_img = _create_data(26, 34, channels, device=device) # set seed to reproduce the same transformation for tensor and PIL image torch.manual_seed(12) @@ -715,6 +714,7 @@ def test_random_apply(device): @pytest.mark.parametrize('channels', [1, 3]) def test_gaussian_blur(device, channels, meth_kwargs): tol = 1.0 + 1e-10 + torch.manual_seed(12) _test_class_op( T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels, test_exact_match=False, device=device, agg_method="max", tol=tol From 3d58993024e76a43020750bd81a698b56d9760dc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 29 Sep 2021 13:54:28 +0100 Subject: [PATCH 6/6] revert changes to test_models --- test/common_utils.py | 11 +++++++++++ test/test_models.py | 18 +++++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a9ad27fecd0..a9f6703fcd0 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -82,6 +82,17 @@ def is_iterable(obj): return False +@contextlib.contextmanager +def freeze_rng_state(): + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + yield + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(rng_state) + + def cycle_over(objs): for idx, obj1 in enumerate(objs): for obj2 in objs[:idx] + objs[idx + 1:]: diff --git a/test/test_models.py b/test/test_models.py index 3d5a198b6a3..9e376bedce5 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,7 +1,7 @@ import os import io import sys -from common_utils import map_nested_tensor_object, set_rng_seed, cpu_and_gpu, needs_cuda +from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda from _utils_internal import get_relative_path from collections import OrderedDict import functools @@ -101,8 +101,10 @@ def get_export_import_copy(m): return imported m_import = get_export_import_copy(m) - results = m(*args) - results_from_imported = m_import(*args) + with freeze_rng_state(): + results = m(*args) + with freeze_rng_state(): + results_from_imported = m_import(*args) tol = 3e-4 try: torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) @@ -127,11 +129,13 @@ def get_export_import_copy(m): sm = torch.jit.script(nn_module) - eager_out = nn_module(*args) + with freeze_rng_state(): + eager_out = nn_module(*args) - script_out = sm(*args) - if unwrapper: - script_out = unwrapper(script_out) + with freeze_rng_state(): + script_out = sm(*args) + if unwrapper: + script_out = unwrapper(script_out) torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4) assert_export_import_module(sm, args)