From 20924653eebd2eaa091818e5a4e03da788e889ae Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 17 Feb 2023 11:38:53 +0100 Subject: [PATCH 1/5] Added wrap_dataset_for_transforms_v2 into datasets and handled beta warning + tests --- test/test_v2_dataset_warnings.py | 22 ++++++++++++++++++++++ torchvision/datasets/__init__.py | 9 +++++++++ 2 files changed, 31 insertions(+) create mode 100644 test/test_v2_dataset_warnings.py diff --git a/test/test_v2_dataset_warnings.py b/test/test_v2_dataset_warnings.py new file mode 100644 index 00000000000..6a787605473 --- /dev/null +++ b/test/test_v2_dataset_warnings.py @@ -0,0 +1,22 @@ +import pytest + + +def test_warns_if_imported_from_datasets(): + import torchvision + + value = torchvision._WARN_ABOUT_BETA_TRANSFORMS + + setattr(torchvision, "_WARN_ABOUT_BETA_TRANSFORMS", True) + + with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING): + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + setattr(torchvision, "_WARN_ABOUT_BETA_TRANSFORMS", value) + + +def test_no_warns_if_imported_from_datasets(): + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error") + from torchvision.datasets import wrap_dataset_for_transforms_v2 diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e18a9a54b16..cfdc365eb42 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -128,3 +128,12 @@ "InStereo2k", "ETH3DStereo", ) + + +def __getattr__(name): + if name in ("wrap_dataset_for_transforms_v2",): + from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2 + + return wrap_dataset_for_transforms_v2 + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From eafeb5b6d90366b39555be09b78d59b52bf0c199 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 17 Feb 2023 14:06:38 +0100 Subject: [PATCH 2/5] Addressed PR comments --- .github/workflows/test-linux-cpu.yml | 5 ++++- test/check_v2_dataset_warnings.py | 19 +++++++++++++++++++ test/test_v2_dataset_warnings.py | 22 ---------------------- torchvision/datapoints/__init__.py | 2 -- torchvision/datasets/__init__.py | 6 ++++++ 5 files changed, 29 insertions(+), 25 deletions(-) create mode 100644 test/check_v2_dataset_warnings.py delete mode 100644 test/test_v2_dataset_warnings.py diff --git a/.github/workflows/test-linux-cpu.yml b/.github/workflows/test-linux-cpu.yml index 5dc7550d868..7b807791e91 100644 --- a/.github/workflows/test-linux-cpu.yml +++ b/.github/workflows/test-linux-cpu.yml @@ -41,7 +41,7 @@ jobs: # Create Conda Env conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy conda activate /work/ci_env - + # Install PyTorch, Torchvision, and testing libraries set -ex conda install \ @@ -55,3 +55,6 @@ jobs: # Run Tests python3 -m torch.utils.collect_env python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20 + + # Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2" + python3 -m pytest -v test/check_v2_dataset_warnings.py diff --git a/test/check_v2_dataset_warnings.py b/test/check_v2_dataset_warnings.py new file mode 100644 index 00000000000..8bb53ee3434 --- /dev/null +++ b/test/check_v2_dataset_warnings.py @@ -0,0 +1,19 @@ +import pytest + + +def test_warns_if_imported_from_datasets(mocker): + mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", return_value=True) + + import torchvision + + with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING): + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + assert callable(wrap_dataset_for_transforms_v2) + + +@pytest.mark.filterwarnings("error") +def test_no_warns_if_imported_from_datasets(): + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + assert callable(wrap_dataset_for_transforms_v2) diff --git a/test/test_v2_dataset_warnings.py b/test/test_v2_dataset_warnings.py deleted file mode 100644 index 6a787605473..00000000000 --- a/test/test_v2_dataset_warnings.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - - -def test_warns_if_imported_from_datasets(): - import torchvision - - value = torchvision._WARN_ABOUT_BETA_TRANSFORMS - - setattr(torchvision, "_WARN_ABOUT_BETA_TRANSFORMS", True) - - with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING): - from torchvision.datasets import wrap_dataset_for_transforms_v2 - - setattr(torchvision, "_WARN_ABOUT_BETA_TRANSFORMS", value) - - -def test_no_warns_if_imported_from_datasets(): - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("error") - from torchvision.datasets import wrap_dataset_for_transforms_v2 diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py index 4c90d957c59..195023d5814 100644 --- a/torchvision/datapoints/__init__.py +++ b/torchvision/datapoints/__init__.py @@ -4,8 +4,6 @@ from ._mask import Mask from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video -from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip - from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS if _WARN_ABOUT_BETA_TRANSFORMS: diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index cfdc365eb42..7d3357e3dc2 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -130,6 +130,12 @@ ) +# We override current module's attributes to handle the import: +# from torchvision.datasets import wrap_dataset_for_transforms_v2 +# with beta state v2 warning from torchvision.datapoints +# We also want to avoid raising the warning when importing other attributes +# from torchvision.datasets +# Ref: https://peps.python.org/pep-0562/ def __getattr__(name): if name in ("wrap_dataset_for_transforms_v2",): from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2 From 24fd7d6381b8078f44988a4cb93e0c0c4e8e4712 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 17 Feb 2023 14:23:52 +0100 Subject: [PATCH 3/5] Fixed wrong imports --- test/datasets_utils.py | 2 +- test/test_datasets.py | 43 ++++++++++++++++++++++++++++++ test/test_prototype_datapoints.py | 41 ---------------------------- torchvision/datapoints/__init__.py | 4 +-- 4 files changed, 46 insertions(+), 44 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index e8290b55c4b..f4bcdfc42dc 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -584,8 +584,8 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - from torchvision.datapoints import wrap_dataset_for_transforms_v2 from torchvision.datapoints._datapoint import Datapoint + from torchvision.datasets import wrap_dataset_for_transforms_v2 try: with self.create_dataset(config) as (dataset, _): diff --git a/test/test_datasets.py b/test/test_datasets.py index 015f727a17a..605b799e7a9 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -8,6 +8,7 @@ import pathlib import pickle import random +import re import shutil import string import unittest @@ -3309,5 +3310,47 @@ def test_bad_input(self): pass +class TestDatasetWrapper: + def test_unknown_type(self): + unknown_object = object() + with pytest.raises( + TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`") + ): + datasets.wrap_dataset_for_transforms_v2(unknown_object) + + def test_unknown_dataset(self): + class MyVisionDataset(datasets.VisionDataset): + pass + + dataset = MyVisionDataset("root") + + with pytest.raises(TypeError, match="No wrapper exist"): + datasets.wrap_dataset_for_transforms_v2(dataset) + + def test_missing_wrapper(self): + dataset = datasets.FakeData() + + with pytest.raises(TypeError, match="please open an issue"): + datasets.wrap_dataset_for_transforms_v2(dataset) + + def test_subclass(self, mocker): + from torchvision import datapoints + + sentinel = object() + mocker.patch.dict( + datapoints._dataset_wrapper.WRAPPER_FACTORIES, + clear=False, + values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel}, + ) + + class MyFakeData(datasets.FakeData): + pass + + dataset = MyFakeData() + wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset) + + assert wrapped_dataset[0] is sentinel + + if __name__ == "__main__": unittest.main() diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 615fa9f614d..436fc19c5a1 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -5,7 +5,6 @@ from PIL import Image -from torchvision import datapoints, datasets from torchvision.prototype import datapoints as proto_datapoints @@ -163,43 +162,3 @@ def test_bbox_instance(data, format): if isinstance(format, str): format = datapoints.BoundingBoxFormat.from_str(format.upper()) assert bboxes.format == format - - -class TestDatasetWrapper: - def test_unknown_type(self): - unknown_object = object() - with pytest.raises( - TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`") - ): - datapoints.wrap_dataset_for_transforms_v2(unknown_object) - - def test_unknown_dataset(self): - class MyVisionDataset(datasets.VisionDataset): - pass - - dataset = MyVisionDataset("root") - - with pytest.raises(TypeError, match="No wrapper exist"): - datapoints.wrap_dataset_for_transforms_v2(dataset) - - def test_missing_wrapper(self): - dataset = datasets.FakeData() - - with pytest.raises(TypeError, match="please open an issue"): - datapoints.wrap_dataset_for_transforms_v2(dataset) - - def test_subclass(self, mocker): - sentinel = object() - mocker.patch.dict( - datapoints._dataset_wrapper.WRAPPER_FACTORIES, - clear=False, - values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel}, - ) - - class MyFakeData(datasets.FakeData): - pass - - dataset = MyFakeData() - wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) - - assert wrapped_dataset[0] is sentinel diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py index 195023d5814..c9343048a2a 100644 --- a/torchvision/datapoints/__init__.py +++ b/torchvision/datapoints/__init__.py @@ -1,11 +1,11 @@ +from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS + from ._bounding_box import BoundingBox, BoundingBoxFormat from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image from ._mask import Mask from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video -from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS - if _WARN_ABOUT_BETA_TRANSFORMS: import warnings From 474b1ccf850077bae8438ef9b1f3fc0b78b0cc3d Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 17 Feb 2023 14:50:01 +0100 Subject: [PATCH 4/5] Update .github/workflows/test-linux-cpu.yml Co-authored-by: Nicolas Hug --- .github/workflows/test-linux-cpu.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test-linux-cpu.yml b/.github/workflows/test-linux-cpu.yml index 7b807791e91..68ebc54f2d1 100644 --- a/.github/workflows/test-linux-cpu.yml +++ b/.github/workflows/test-linux-cpu.yml @@ -57,4 +57,7 @@ jobs: python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20 # Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2" + # We keep them separate to avoid any side effects due to warnings / imports. + # TODO: Remove this and add proper tests (possibly using a sub-process solution as described + # in https://github.com/pytorch/vision/pull/7269). python3 -m pytest -v test/check_v2_dataset_warnings.py From 25a7ef93801708ca91088ca428d3af93cf2176ed Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 17 Feb 2023 14:53:17 +0100 Subject: [PATCH 5/5] Fixed import issue in test_prototype_datapoints.py --- test/test_prototype_datapoints.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 436fc19c5a1..b7aebd4c137 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -1,10 +1,9 @@ -import re - import pytest import torch from PIL import Image +from torchvision import datapoints from torchvision.prototype import datapoints as proto_datapoints