Skip to content

Added wrap_dataset_for_transforms_v2 into datasets and handled beta w… #7279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/test-linux-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
# Create Conda Env
conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy
conda activate /work/ci_env

# Install PyTorch, Torchvision, and testing libraries
set -ex
conda install \
Expand All @@ -55,3 +55,9 @@ jobs:
# Run Tests
python3 -m torch.utils.collect_env
python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20

# Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2"
# We keep them separate to avoid any side effects due to warnings / imports.
# TODO: Remove this and add proper tests (possibly using a sub-process solution as described
# in https://github.com/pytorch/vision/pull/7269).
python3 -m pytest -v test/check_v2_dataset_warnings.py
19 changes: 19 additions & 0 deletions test/check_v2_dataset_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest


def test_warns_if_imported_from_datasets(mocker):
    """Check that the ``torchvision.datasets`` import path emits the beta warning.

    With ``torchvision._WARN_ABOUT_BETA_TRANSFORMS`` forced to ``True``,
    ``from torchvision.datasets import wrap_dataset_for_transforms_v2`` must
    raise a ``UserWarning`` carrying ``torchvision._BETA_TRANSFORMS_WARNING``.

    Args:
        mocker: The ``pytest-mock`` fixture used to patch the module flag.
    """
    # Local import keeps this block self-contained.
    import re

    # Install a real ``True`` as the flag value. ``return_value=True`` would
    # instead replace the flag with a ``MagicMock`` (truthy only by accident)
    # whose *call result* is ``True``.
    mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", new=True)

    import torchvision

    # ``match`` is treated as a regular expression, so escape the message to
    # compare it literally.
    expected_warning = re.escape(torchvision._BETA_TRANSFORMS_WARNING)

    with pytest.warns(UserWarning, match=expected_warning):
        from torchvision.datasets import wrap_dataset_for_transforms_v2

    assert callable(wrap_dataset_for_transforms_v2)


@pytest.mark.filterwarnings("error")
def test_no_warns_if_imported_from_datasets():
    """Import the wrapper from ``torchvision.datasets`` with warnings escalated to errors.

    The ``filterwarnings("error")`` mark turns any warning raised during the
    import into an exception, so the test passes only if none is emitted.
    """
    from torchvision.datasets import wrap_dataset_for_transforms_v2 as imported_wrapper

    assert callable(imported_wrapper)
2 changes: 1 addition & 1 deletion test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,8 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints import wrap_dataset_for_transforms_v2
from torchvision.datapoints._datapoint import Datapoint
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, _):
Expand Down
43 changes: 43 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pathlib
import pickle
import random
import re
import shutil
import string
import unittest
Expand Down Expand Up @@ -3309,5 +3310,47 @@ def test_bad_input(self):
pass


class TestDatasetWrapper:
    """Tests for ``datasets.wrap_dataset_for_transforms_v2``: error paths and subclass lookup."""

    def test_unknown_type(self):
        # An arbitrary object is expected to be rejected with a ``TypeError``.
        not_a_dataset = object()
        expected_message = re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")

        with pytest.raises(TypeError, match=expected_message):
            datasets.wrap_dataset_for_transforms_v2(not_a_dataset)

    def test_unknown_dataset(self):
        # A user-defined ``VisionDataset`` subclass with no registered wrapper.
        class MyVisionDataset(datasets.VisionDataset):
            pass

        custom_dataset = MyVisionDataset("root")

        with pytest.raises(TypeError, match="No wrapper exist"):
            datasets.wrap_dataset_for_transforms_v2(custom_dataset)

    def test_missing_wrapper(self):
        # ``FakeData`` is expected to produce the "please open an issue" error.
        fake_dataset = datasets.FakeData()

        with pytest.raises(TypeError, match="please open an issue"):
            datasets.wrap_dataset_for_transforms_v2(fake_dataset)

    def test_subclass(self, mocker):
        from torchvision import datapoints

        sentinel = object()

        def sentinel_factory(dataset):
            # The produced wrapper ignores its inputs and always yields the sentinel.
            def sentinel_wrapper(idx, sample):
                return sentinel

            return sentinel_wrapper

        # Temporarily register a factory for ``FakeData`` only; the remaining
        # entries of the registry are left untouched (``clear=False``).
        mocker.patch.dict(
            datapoints._dataset_wrapper.WRAPPER_FACTORIES,
            {datasets.FakeData: sentinel_factory},
            clear=False,
        )

        class MyFakeData(datasets.FakeData):
            pass

        subclass_instance = MyFakeData()
        wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(subclass_instance)

        # The subclass must be served by the factory registered for its parent.
        assert wrapped_dataset[0] is sentinel


if __name__ == "__main__":
unittest.main()
44 changes: 1 addition & 43 deletions test/test_prototype_datapoints.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import re

import pytest
import torch

from PIL import Image

from torchvision import datapoints, datasets
from torchvision import datapoints
from torchvision.prototype import datapoints as proto_datapoints


Expand Down Expand Up @@ -163,43 +161,3 @@ def test_bbox_instance(data, format):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper())
assert bboxes.format == format


class TestDatasetWrapper:
    """Tests for ``datapoints.wrap_dataset_for_transforms_v2``: error paths and subclass lookup."""

    def test_unknown_type(self):
        unknown_object = object()
        message_pattern = re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")

        # A plain ``object`` is expected to be rejected with a ``TypeError``.
        with pytest.raises(TypeError, match=message_pattern):
            datapoints.wrap_dataset_for_transforms_v2(unknown_object)

    def test_unknown_dataset(self):
        class MyVisionDataset(datasets.VisionDataset):
            pass

        unregistered = MyVisionDataset("root")

        # A ``VisionDataset`` subclass with no registered wrapper.
        with pytest.raises(TypeError, match="No wrapper exist"):
            datapoints.wrap_dataset_for_transforms_v2(unregistered)

    def test_missing_wrapper(self):
        fake_data = datasets.FakeData()

        # ``FakeData`` is expected to produce the "please open an issue" error.
        with pytest.raises(TypeError, match="please open an issue"):
            datapoints.wrap_dataset_for_transforms_v2(fake_data)

    def test_subclass(self, mocker):
        sentinel = object()

        def make_wrapper(dataset):
            # Every sample lookup on the wrapped dataset resolves to the sentinel.
            def wrapper(idx, sample):
                return sentinel

            return wrapper

        registry = datapoints._dataset_wrapper.WRAPPER_FACTORIES
        # Register the factory for ``FakeData`` only; other entries stay in place.
        mocker.patch.dict(registry, {datasets.FakeData: make_wrapper}, clear=False)

        class MyFakeData(datasets.FakeData):
            pass

        wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(MyFakeData())

        # The subclass must be served by the factory registered for its parent.
        assert wrapped_dataset[0] is sentinel
6 changes: 2 additions & 4 deletions torchvision/datapoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video

from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip

from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

if _WARN_ABOUT_BETA_TRANSFORMS:
import warnings

Expand Down
15 changes: 15 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,18 @@
"InStereo2k",
"ETH3DStereo",
)


# We define a module-level `__getattr__` so that the import
#   from torchvision.datasets import wrap_dataset_for_transforms_v2
# is resolved lazily and emits the beta-state v2 warning from torchvision.datapoints.
# Importing any other attribute from torchvision.datasets must not raise
# that warning.
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
    """Resolve lazily provided module attributes (PEP 562).

    Args:
        name: The attribute name that was not found on the module.

    Returns:
        ``wrap_dataset_for_transforms_v2`` from
        ``torchvision.datapoints._dataset_wrapper`` when that name is requested.

    Raises:
        AttributeError: For any other name.
    """
    # Guard clause: only the names listed here are provided lazily.
    if name not in ("wrap_dataset_for_transforms_v2",):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported on demand, so it only runs when this attribute is requested.
    from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2 as lazy_wrapper

    return lazy_wrapper