Skip to content

Commit 1e35ee7

Browse files
committed
reinstate old test
1 parent f339e6c commit 1e35ee7

File tree

2 files changed

+107
-66
lines changed

2 files changed

+107
-66
lines changed

test/datasets_utils.py

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -171,46 +171,6 @@ def wrapper(self):
171171
return wrapper
172172

173173

174-
def _no_collate(batch):
175-
return batch
176-
177-
178-
def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_target_keys=False):
179-
from torch.utils.data import DataLoader
180-
from torchvision import datapoints
181-
from torchvision.datasets import wrap_dataset_for_transforms_v2
182-
183-
def check_wrapped_samples(dataset):
184-
for wrapped_sample in dataset:
185-
assert tree_any(
186-
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
187-
)
188-
189-
target_keyss = [None]
190-
if supports_target_keys:
191-
target_keyss.append("all")
192-
193-
for target_keys in target_keyss:
194-
with dataset_test_case.create_dataset(config) as (dataset, info):
195-
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
196-
197-
assert isinstance(wrapped_dataset, type(dataset))
198-
assert len(wrapped_dataset) == info["num_examples"]
199-
200-
check_wrapped_samples(wrapped_dataset)
201-
202-
# On macOS, forking for multiprocessing is not available and thus spawning is used by default. For this to work,
203-
# the whole pipeline including the dataset needs to be pickleable, which is what we are enforcing here.
204-
if platform.system() == "Darwin":
205-
with dataset_test_case.create_dataset(config) as (dataset, _):
206-
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
207-
dataloader = DataLoader(
208-
wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate
209-
)
210-
211-
check_wrapped_samples(dataloader)
212-
213-
214174
class DatasetTestCase(unittest.TestCase):
215175
"""Abstract base class for all dataset testcases.
216176
@@ -606,6 +566,42 @@ def test_transforms(self, config):
606566

607567
mock.assert_called()
608568

569+
@test_all_configs
570+
def test_transforms_v2_wrapper(self, config):
571+
from torchvision import datapoints
572+
from torchvision.datasets import wrap_dataset_for_transforms_v2
573+
574+
try:
575+
with self.create_dataset(config) as (dataset, info):
576+
for target_keys in [None, "all"]:
577+
if target_keys is not None and self.DATASET_CLASS not in {
578+
torchvision.datasets.CocoDetection,
579+
torchvision.datasets.VOCDetection,
580+
torchvision.datasets.Kitti,
581+
torchvision.datasets.WIDERFace,
582+
}:
583+
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
584+
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
585+
continue
586+
587+
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
588+
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
589+
assert len(wrapped_dataset) == info["num_examples"]
590+
591+
wrapped_sample = wrapped_dataset[0]
592+
assert tree_any(
593+
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
594+
)
595+
except TypeError as error:
596+
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
597+
if str(error).startswith(msg):
598+
pytest.skip(msg)
599+
raise error
600+
except RuntimeError as error:
601+
if "currently not supported by this wrapper" in str(error):
602+
pytest.skip("Config is currently not supported by this wrapper")
603+
raise error
604+
609605

610606
class ImageDatasetTestCase(DatasetTestCase):
611607
"""Abstract base class for image dataset testcases.
@@ -687,6 +683,40 @@ def wrapper(tmpdir, config):
687683

688684
return wrapper
689685

686+
@test_all_configs
687+
def test_transforms_v2_wrapper(self, config):
688+
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
689+
# and otherwise default to the supported `"TCHW"`.
690+
if config.setdefault("output_format", "TCHW") == "THWC":
691+
return
692+
693+
super().test_transforms_v2_wrapper.__wrapped__(self, config)
694+
695+
696+
def _no_collate(batch):
697+
return batch
698+
699+
700+
def check_transforms_v2_wrapper_spawn(dataset):
701+
# On Linux, the DataLoader forks the main process by default. On macOS, forking is not the default start method, so new
702+
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
703+
# we are enforcing here.
704+
if platform.system() != "Darwin":
705+
pytest.skip("Multiprocessing spawning is only checked on macOS.")
706+
707+
from torch.utils.data import DataLoader
708+
from torchvision import datapoints
709+
from torchvision.datasets import wrap_dataset_for_transforms_v2
710+
711+
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
712+
713+
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
714+
715+
for wrapped_sample in dataloader:
716+
assert tree_any(
717+
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
718+
)
719+
690720

691721
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
692722
r"""Create a random uint8 tensor.

test/test_datasets.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,9 @@ def test_combined_targets(self):
183183
), "Type of the combined target does not match the type of the corresponding individual target: "
184184
f"{actual} is not {expected}",
185185

186-
def test_transforms_v2_wrapper(self):
187-
datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type="category"))
186+
def test_transforms_v2_wrapper_spawn(self):
187+
with self.create_dataset(target_type="category") as (dataset, _):
188+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
188189

189190

190191
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
@@ -261,8 +262,9 @@ def inject_fake_data(self, tmpdir, config):
261262

262263
return split_to_num_examples[config["split"]]
263264

264-
def test_transforms_v2_wrapper(self):
265-
datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
265+
def test_transforms_v2_wrapper_spawn(self):
266+
with self.create_dataset() as (dataset, _):
267+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
266268

267269

268270
class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
@@ -388,9 +390,10 @@ def test_feature_types_target_polygon(self):
388390
assert isinstance(polygon_img, PIL.Image.Image)
389391
(polygon_target, info["expected_polygon_target"])
390392

391-
def test_transforms_v2_wrapper(self):
393+
def test_transforms_v2_wrapper_spawn(self):
392394
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
393-
datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type))
395+
with self.create_dataset(target_type=target_type) as (dataset, _):
396+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
394397

395398

396399
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
@@ -423,8 +426,9 @@ def inject_fake_data(self, tmpdir, config):
423426
torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
424427
return num_examples
425428

426-
def test_transforms_v2_wrapper(self):
427-
datasets_utils.check_transforms_v2_wrapper(self)
429+
def test_transforms_v2_wrapper_spawn(self):
430+
with self.create_dataset() as (dataset, _):
431+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
428432

429433

430434
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -620,9 +624,10 @@ def test_images_names_split(self):
620624

621625
assert merged_imgs_names == all_imgs_names
622626

623-
def test_transforms_v2_wrapper(self):
627+
def test_transforms_v2_wrapper_spawn(self):
624628
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
625-
datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type))
629+
with self.create_dataset(target_type=target_type) as (dataset, _):
630+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
626631

627632

628633
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
@@ -711,8 +716,9 @@ def add_bndbox(obj, bndbox=None):
711716

712717
return data
713718

714-
def test_transforms_v2_wrapper(self):
715-
datasets_utils.check_transforms_v2_wrapper(self)
719+
def test_transforms_v2_wrapper_spawn(self):
720+
with self.create_dataset() as (dataset, _):
721+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
716722

717723

718724
class VOCDetectionTestCase(VOCSegmentationTestCase):
@@ -734,9 +740,9 @@ def test_annotations(self):
734740

735741
assert object == info["annotation"]
736742

737-
def test_transforms_v2_wrapper(self):
738-
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
739-
datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
743+
def test_transforms_v2_wrapper_spawn(self):
744+
with self.create_dataset() as (dataset, _):
745+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
740746

741747

742748
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
@@ -808,8 +814,9 @@ def _create_json(self, root, name, content):
808814
json.dump(content, fh)
809815
return file
810816

811-
def test_transforms_v2_wrapper(self):
812-
datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
817+
def test_transforms_v2_wrapper_spawn(self):
818+
with self.create_dataset() as (dataset, _):
819+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
813820

814821

815822
class CocoCaptionsTestCase(CocoDetectionTestCase):
@@ -827,7 +834,7 @@ def test_captions(self):
827834
_, captions = dataset[0]
828835
assert tuple(captions) == tuple(info["captions"])
829836

830-
def test_transforms_v2_wrapper(self):
837+
def test_transforms_v2_wrapper_spawn(self):
831838
# We need to define this method, because otherwise the test from the super class will
832839
# be run
833840
pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
@@ -998,8 +1005,9 @@ def inject_fake_data(self, tmpdir, config):
9981005
)
9991006
return num_videos_per_class * len(classes)
10001007

1001-
def test_transforms_v2_wrapper(self):
1002-
datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW"))
1008+
def test_transforms_v2_wrapper_spawn(self):
1009+
with self.create_dataset(output_format="TCHW") as (dataset, _):
1010+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
10031011

10041012

10051013
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
@@ -1228,8 +1236,9 @@ def _create_segmentation(self, size):
12281236
def _file_stem(self, idx):
12291237
return f"2008_{idx:06d}"
12301238

1231-
def test_transforms_v2_wrapper(self):
1232-
datasets_utils.check_transforms_v2_wrapper(self, config=dict(mode="segmentation"))
1239+
def test_transforms_v2_wrapper_spawn(self):
1240+
with self.create_dataset(mode="segmentation") as (dataset, _):
1241+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
12331242

12341243

12351244
class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
@@ -1680,8 +1689,9 @@ def inject_fake_data(self, tmpdir, config):
16801689

16811690
return split_to_num_examples[config["train"]]
16821691

1683-
def test_transforms_v2_wrapper(self):
1684-
datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
1692+
def test_transforms_v2_wrapper_spawn(self):
1693+
with self.create_dataset() as (dataset, _):
1694+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
16851695

16861696

16871697
class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
@@ -2557,8 +2567,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
25572567
breed_id = "-1"
25582568
return (image_id, class_id, species, breed_id)
25592569

2560-
def test_transforms_v2_wrapper(self):
2561-
datasets_utils.check_transforms_v2_wrapper(self)
2570+
def test_transforms_v2_wrapper_spawn(self):
2571+
with self.create_dataset() as (dataset, _):
2572+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
25622573

25632574

25642575
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):

0 commit comments

Comments
 (0)