Skip to content

Commit d744da9

Browse files
pmeier and NicolasHug authored
allow subclasses in dataset wrappers (#7236)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent b570f2c commit d744da9

File tree

3 files changed

+76
-15
lines changed

3 files changed

+76
-15
lines changed

test/datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -596,7 +596,7 @@ def test_transforms_v2_wrapper(self, config):
596596
wrapped_sample = wrapped_dataset[0]
597597
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
598598
except TypeError as error:
599-
if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"):
599+
if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
600600
return
601601
raise error
602602
except RuntimeError as error:

test/test_prototype_datapoints.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import re
2+
13
import pytest
24
import torch
35

46
from PIL import Image
7+
8+
from torchvision import datasets
59
from torchvision.prototype import datapoints
610

711

@@ -159,3 +163,43 @@ def test_bbox_instance(data, format):
159163
if isinstance(format, str):
160164
format = datapoints.BoundingBoxFormat.from_str(format.upper())
161165
assert bboxes.format == format
166+
167+
168+
class TestDatasetWrapper:
169+
def test_unknown_type(self):
170+
unknown_object = object()
171+
with pytest.raises(
172+
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
173+
):
174+
datapoints.wrap_dataset_for_transforms_v2(unknown_object)
175+
176+
def test_unknown_dataset(self):
177+
class MyVisionDataset(datasets.VisionDataset):
178+
pass
179+
180+
dataset = MyVisionDataset("root")
181+
182+
with pytest.raises(TypeError, match="No wrapper exist"):
183+
datapoints.wrap_dataset_for_transforms_v2(dataset)
184+
185+
def test_missing_wrapper(self):
186+
dataset = datasets.FakeData()
187+
188+
with pytest.raises(TypeError, match="please open an issue"):
189+
datapoints.wrap_dataset_for_transforms_v2(dataset)
190+
191+
def test_subclass(self, mocker):
192+
sentinel = object()
193+
mocker.patch.dict(
194+
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
195+
clear=False,
196+
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
197+
)
198+
199+
class MyFakeData(datasets.FakeData):
200+
pass
201+
202+
dataset = MyFakeData()
203+
wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)
204+
205+
assert wrapped_dataset[0] is sentinel

torchvision/prototype/datapoints/_dataset_wrapper.py

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -39,16 +39,26 @@ def decorator(wrapper_factory):
3939
class VisionDatasetDatapointWrapper(Dataset):
4040
def __init__(self, dataset):
4141
dataset_cls = type(dataset)
42-
wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls)
43-
if wrapper_factory is None:
44-
# TODO: If we have documentation on how to do that, put a link in the error message.
45-
msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
46-
if dataset_cls in datasets.__dict__.values():
47-
msg = (
48-
f"{msg} If an automated wrapper for this dataset would be useful for you, "
49-
f"please open an issue at https://github.com/pytorch/vision/issues."
50-
)
51-
raise TypeError(msg)
42+
43+
if not isinstance(dataset, datasets.VisionDataset):
44+
raise TypeError(
45+
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
46+
f"but got a '{dataset_cls.__name__}' instead."
47+
)
48+
49+
for cls in dataset_cls.mro():
50+
if cls in WRAPPER_FACTORIES:
51+
wrapper_factory = WRAPPER_FACTORIES[cls]
52+
break
53+
elif cls is datasets.VisionDataset:
54+
# TODO: If we have documentation on how to do that, put a link in the error message.
55+
msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
56+
if dataset_cls in datasets.__dict__.values():
57+
msg = (
58+
f"{msg} If an automated wrapper for this dataset would be useful for you, "
59+
f"please open an issue at https://github.com/pytorch/vision/issues."
60+
)
61+
raise TypeError(msg)
5262

5363
self._dataset = dataset
5464
self._wrapper = wrapper_factory(dataset)
@@ -98,6 +108,13 @@ def identity(item):
98108
return item
99109

100110

111+
def identity_wrapper_factory(dataset):
112+
def wrapper(idx, sample):
113+
return sample
114+
115+
return wrapper
116+
117+
101118
def pil_image_to_mask(pil_image):
102119
return datapoints.Mask(pil_image)
103120

@@ -125,10 +142,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
125142

126143

127144
def classification_wrapper_factory(dataset):
128-
def wrapper(idx, sample):
129-
return sample
130-
131-
return wrapper
145+
return identity_wrapper_factory(dataset)
132146

133147

134148
for dataset_cls in [
@@ -237,6 +251,9 @@ def wrapper(idx, sample):
237251
return wrapper
238252

239253

254+
WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
255+
256+
240257
VOC_DETECTION_CATEGORIES = [
241258
"__background__",
242259
"aeroplane",

0 commit comments

Comments
 (0)