diff --git a/.github/workflows/test-linux-cpu.yml b/.github/workflows/test-linux-cpu.yml index 19521cdd011..456f4dfdd99 100644 --- a/.github/workflows/test-linux-cpu.yml +++ b/.github/workflows/test-linux-cpu.yml @@ -55,3 +55,9 @@ jobs: # Run Tests python3 -m torch.utils.collect_env python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20 + + # Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2" + # We keep them separate to avoid any side effects due to warnings / imports. + # TODO: Remove this and add proper tests (possibly using a sub-process solution as described + # in https://github.com/pytorch/vision/pull/7269). + python3 test/check_v2_dataset_warnings.py diff --git a/test/check_v2_dataset_warnings.py b/test/check_v2_dataset_warnings.py new file mode 100644 index 00000000000..e17f2a0be64 --- /dev/null +++ b/test/check_v2_dataset_warnings.py @@ -0,0 +1,35 @@ +import warnings + +import torchvision + + +def test_warns_if_imported_from_datasets(): + with warnings.catch_warnings(record=True) as w: + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + assert callable(wrap_dataset_for_transforms_v2) + + assert len(w) == 2 + assert "torchvision.transforms.v2" in str(w[-1].message) + + +def test_no_warns_if_imported_from_datasets(): + + torchvision.disable_beta_transforms_warning() + + with warnings.catch_warnings(): + warnings.simplefilter("error") + + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + assert callable(wrap_dataset_for_transforms_v2) + + from torchvision.datasets import cifar + + assert hasattr(cifar, "CIFAR10") + + +if __name__ == "__main__": + # We can't rely on pytest due to various side-effects, e.g. conftest etc + test_warns_if_imported_from_datasets() + test_no_warns_if_imported_from_datasets() diff --git a/test/conftest.py b/test/conftest.py index b3ab70af650..5408c21a549 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,11 +4,11 @@ import pytest import torch import torchvision -from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG - torchvision.disable_beta_transforms_warning() +from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG + def pytest_configure(config): # register an additional marker (see pytest_collection_modifyitems)