diff --git a/.github/workflows/test-linux-cpu.yml b/.github/workflows/test-linux-cpu.yml index 5dc7550d868..68ebc54f2d1 100644 --- a/.github/workflows/test-linux-cpu.yml +++ b/.github/workflows/test-linux-cpu.yml @@ -41,7 +41,7 @@ jobs: # Create Conda Env conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy conda activate /work/ci_env - + # Install PyTorch, Torchvision, and testing libraries set -ex conda install \ @@ -55,3 +55,9 @@ jobs: # Run Tests python3 -m torch.utils.collect_env python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20 + + # Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2" + # We keep them separate to avoid any side effects due to warnings / imports. + # TODO: Remove this and add proper tests (possibly using a sub-process solution as described + # in https://github.com/pytorch/vision/pull/7269). + python3 -m pytest -v test/check_v2_dataset_warnings.py diff --git a/test/check_v2_dataset_warnings.py b/test/check_v2_dataset_warnings.py new file mode 100644 index 00000000000..8bb53ee3434 --- /dev/null +++ b/test/check_v2_dataset_warnings.py @@ -0,0 +1,19 @@ +import pytest + + +def test_warns_if_imported_from_datasets(mocker): + mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", return_value=True) + + import torchvision + + with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING): + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + assert callable(wrap_dataset_for_transforms_v2) + + +@pytest.mark.filterwarnings("error") +def test_no_warns_if_imported_from_datasets(): + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + assert callable(wrap_dataset_for_transforms_v2) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index e8290b55c4b..f4bcdfc42dc 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -584,8 +584,8 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - from torchvision.datapoints import wrap_dataset_for_transforms_v2 from torchvision.datapoints._datapoint import Datapoint + from torchvision.datasets import wrap_dataset_for_transforms_v2 try: with self.create_dataset(config) as (dataset, _): diff --git a/test/test_datasets.py b/test/test_datasets.py index 015f727a17a..605b799e7a9 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -8,6 +8,7 @@ import pathlib import pickle import random +import re import shutil import string import unittest @@ -3309,5 +3310,47 @@ def test_bad_input(self): pass +class TestDatasetWrapper: + def test_unknown_type(self): + unknown_object = object() + with pytest.raises( + TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`") + ): + datasets.wrap_dataset_for_transforms_v2(unknown_object) + + def test_unknown_dataset(self): + class MyVisionDataset(datasets.VisionDataset): + pass + + dataset = MyVisionDataset("root") + + with pytest.raises(TypeError, match="No wrapper exist"): + datasets.wrap_dataset_for_transforms_v2(dataset) + + def test_missing_wrapper(self): + dataset = datasets.FakeData() + + with pytest.raises(TypeError, match="please open an issue"): + datasets.wrap_dataset_for_transforms_v2(dataset) + + def test_subclass(self, mocker): + from torchvision import datapoints + + sentinel = object() + mocker.patch.dict( + datapoints._dataset_wrapper.WRAPPER_FACTORIES, + clear=False, + values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel}, + ) + + class MyFakeData(datasets.FakeData): + pass + + dataset = MyFakeData() + wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset) + + assert wrapped_dataset[0] is sentinel + + if __name__ == "__main__": unittest.main() diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 615fa9f614d..b7aebd4c137 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -1,11 +1,9 @@ -import re - import pytest import torch from PIL import Image -from torchvision import datapoints, datasets +from torchvision import datapoints from torchvision.prototype import datapoints as proto_datapoints @@ -163,43 +161,3 @@ def test_bbox_instance(data, format): if isinstance(format, str): format = datapoints.BoundingBoxFormat.from_str(format.upper()) assert bboxes.format == format - - -class TestDatasetWrapper: - def test_unknown_type(self): - unknown_object = object() - with pytest.raises( - TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`") - ): - datapoints.wrap_dataset_for_transforms_v2(unknown_object) - - def test_unknown_dataset(self): - class MyVisionDataset(datasets.VisionDataset): - pass - - dataset = MyVisionDataset("root") - - with pytest.raises(TypeError, match="No wrapper exist"): - datapoints.wrap_dataset_for_transforms_v2(dataset) - - def test_missing_wrapper(self): - dataset = datasets.FakeData() - - with pytest.raises(TypeError, match="please open an issue"): - datapoints.wrap_dataset_for_transforms_v2(dataset) - - def test_subclass(self, mocker): - sentinel = object() - mocker.patch.dict( - datapoints._dataset_wrapper.WRAPPER_FACTORIES, - clear=False, - values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel}, - ) - - class MyFakeData(datasets.FakeData): - pass - - dataset = MyFakeData() - wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) - - assert wrapped_dataset[0] is sentinel diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py index 4c90d957c59..c9343048a2a 100644 --- a/torchvision/datapoints/__init__.py +++ b/torchvision/datapoints/__init__.py @@ -1,13 +1,11 @@ +from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS + from ._bounding_box import BoundingBox, BoundingBoxFormat from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image from ._mask import Mask from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video -from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip - -from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS - if _WARN_ABOUT_BETA_TRANSFORMS: import warnings diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e18a9a54b16..7d3357e3dc2 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -128,3 +128,18 @@ "InStereo2k", "ETH3DStereo", ) + + +# We override current module's attributes to handle the import: +# from torchvision.datasets import wrap_dataset_for_transforms_v2 +# with beta state v2 warning from torchvision.datapoints +# We also want to avoid raising the warning when importing other attributes +# from torchvision.datasets +# Ref: https://peps.python.org/pep-0562/ +def __getattr__(name): + if name in ("wrap_dataset_for_transforms_v2",): + from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2 + + return wrap_dataset_for_transforms_v2 + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}")