Skip to content

Commit 033b106

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] enforce pickleability for v2 transforms and wrapped datasets (#7860)
Summary: (Note: this ignores all push blocking failures!) Reviewed By: matteobettini. Differential Revision: D48900370. fbshipit-source-id: be1b23dcab58d2a8b5bca7190f94c0123263d036
1 parent 10e7cb8 commit 033b106

File tree

6 files changed

+102
-10
lines changed

6 files changed

+102
-10
lines changed

test/datasets_utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import itertools
66
import os
77
import pathlib
8+
import platform
89
import random
910
import shutil
1011
import string
@@ -548,7 +549,7 @@ def test_feature_types(self, config):
548549
@test_all_configs
549550
def test_num_examples(self, config):
550551
with self.create_dataset(config) as (dataset, info):
551-
assert len(dataset) == info["num_examples"]
552+
assert len(list(dataset)) == len(dataset) == info["num_examples"]
552553

553554
@test_all_configs
554555
def test_transforms(self, config):
@@ -692,6 +693,31 @@ def test_transforms_v2_wrapper(self, config):
692693
super().test_transforms_v2_wrapper.__wrapped__(self, config)
693694

694695

696+
def _no_collate(batch):
697+
return batch
698+
699+
700+
def check_transforms_v2_wrapper_spawn(dataset):
701+
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
702+
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
703+
# we are enforcing here.
704+
if platform.system() != "Darwin":
705+
pytest.skip("Multiprocessing spawning is only checked on macOS.")
706+
707+
from torch.utils.data import DataLoader
708+
from torchvision import datapoints
709+
from torchvision.datasets import wrap_dataset_for_transforms_v2
710+
711+
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
712+
713+
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
714+
715+
for wrapped_sample in dataloader:
716+
assert tree_any(
717+
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
718+
)
719+
720+
695721
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
696722
r"""Create a random uint8 tensor.
697723

test/test_datasets.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,18 @@ def test_combined_targets(self):
183183
), "Type of the combined target does not match the type of the corresponding individual target: "
184184
f"{actual} is not {expected}",
185185

186+
def test_transforms_v2_wrapper_spawn(self):
187+
with self.create_dataset(target_type="category") as (dataset, _):
188+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
189+
186190

187191
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
188192
DATASET_CLASS = datasets.Caltech256
189193

190194
def inject_fake_data(self, tmpdir, config):
191195
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
192196

193-
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
197+
categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
194198
num_images_per_category = 2
195199

196200
for idx, category in categories:
@@ -258,6 +262,10 @@ def inject_fake_data(self, tmpdir, config):
258262

259263
return split_to_num_examples[config["split"]]
260264

265+
def test_transforms_v2_wrapper_spawn(self):
266+
with self.create_dataset() as (dataset, _):
267+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
268+
261269

262270
class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
263271
DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +390,11 @@ def test_feature_types_target_polygon(self):
382390
assert isinstance(polygon_img, PIL.Image.Image)
383391
(polygon_target, info["expected_polygon_target"])
384392

393+
def test_transforms_v2_wrapper_spawn(self):
394+
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
395+
with self.create_dataset(target_type=target_type) as (dataset, _):
396+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
397+
385398

386399
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
387400
DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,10 @@ def inject_fake_data(self, tmpdir, config):
413426
torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
414427
return num_examples
415428

429+
def test_transforms_v2_wrapper_spawn(self):
430+
with self.create_dataset() as (dataset, _):
431+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
432+
416433

417434
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
418435
DATASET_CLASS = datasets.CIFAR10
@@ -607,6 +624,11 @@ def test_images_names_split(self):
607624

608625
assert merged_imgs_names == all_imgs_names
609626

627+
def test_transforms_v2_wrapper_spawn(self):
628+
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
629+
with self.create_dataset(target_type=target_type) as (dataset, _):
630+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
631+
610632

611633
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
612634
DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +716,10 @@ def add_bndbox(obj, bndbox=None):
694716

695717
return data
696718

719+
def test_transforms_v2_wrapper_spawn(self):
720+
with self.create_dataset() as (dataset, _):
721+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
722+
697723

698724
class VOCDetectionTestCase(VOCSegmentationTestCase):
699725
DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +740,10 @@ def test_annotations(self):
714740

715741
assert object == info["annotation"]
716742

743+
def test_transforms_v2_wrapper_spawn(self):
744+
with self.create_dataset() as (dataset, _):
745+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
746+
717747

718748
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
719749
DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +814,10 @@ def _create_json(self, root, name, content):
784814
json.dump(content, fh)
785815
return file
786816

817+
def test_transforms_v2_wrapper_spawn(self):
818+
with self.create_dataset() as (dataset, _):
819+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
820+
787821

788822
class CocoCaptionsTestCase(CocoDetectionTestCase):
789823
DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +834,11 @@ def test_captions(self):
800834
_, captions = dataset[0]
801835
assert tuple(captions) == tuple(info["captions"])
802836

837+
def test_transforms_v2_wrapper_spawn(self):
838+
# We need to define this method, because otherwise the test from the super class will
839+
# be run
840+
pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
841+
803842

804843
class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
805844
DATASET_CLASS = datasets.UCF101
@@ -966,6 +1005,10 @@ def inject_fake_data(self, tmpdir, config):
9661005
)
9671006
return num_videos_per_class * len(classes)
9681007

1008+
def test_transforms_v2_wrapper_spawn(self):
1009+
with self.create_dataset(output_format="TCHW") as (dataset, _):
1010+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
1011+
9691012

9701013
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
9711014
DATASET_CLASS = datasets.HMDB51
@@ -1193,6 +1236,10 @@ def _create_segmentation(self, size):
11931236
def _file_stem(self, idx):
11941237
return f"2008_{idx:06d}"
11951238

1239+
def test_transforms_v2_wrapper_spawn(self):
1240+
with self.create_dataset(mode="segmentation") as (dataset, _):
1241+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
1242+
11961243

11971244
class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
11981245
DATASET_CLASS = datasets.FakeData
@@ -1642,6 +1689,10 @@ def inject_fake_data(self, tmpdir, config):
16421689

16431690
return split_to_num_examples[config["train"]]
16441691

1692+
def test_transforms_v2_wrapper_spawn(self):
1693+
with self.create_dataset() as (dataset, _):
1694+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
1695+
16451696

16461697
class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
16471698
DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2567,10 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
25162567
breed_id = "-1"
25172568
return (image_id, class_id, species, breed_id)
25182569

2570+
def test_transforms_v2_wrapper_spawn(self):
2571+
with self.create_dataset() as (dataset, _):
2572+
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
2573+
25192574

25202575
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
25212576
DATASET_CLASS = datasets.StanfordCars

test/test_transforms_v2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import pathlib
3+
import pickle
34
import random
45
import warnings
56

@@ -169,8 +170,11 @@ class TestSmoke:
169170
next(make_vanilla_tensor_images()),
170171
],
171172
)
173+
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
172174
@pytest.mark.parametrize("device", cpu_and_cuda())
173-
def test_common(self, transform, adapter, container_type, image_or_video, device):
175+
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
176+
transform = de_serialize(transform)
177+
174178
canvas_size = F.get_size(image_or_video)
175179
input = dict(
176180
image_or_video=image_or_video,

test/test_transforms_v2_refactored.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import decimal
33
import inspect
44
import math
5+
import pickle
56
import re
67
from pathlib import Path
78
from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
247248
def check_transform(transform_cls, input, *args, **kwargs):
248249
transform = transform_cls(*args, **kwargs)
249250

251+
pickle.loads(pickle.dumps(transform))
252+
250253
output = transform(input)
251254
assert isinstance(output, type(input))
252255

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(self, dataset, target_keys):
162162
raise TypeError(msg)
163163

164164
self._dataset = dataset
165+
self._target_keys = target_keys
165166
self._wrapper = wrapper_factory(dataset, target_keys)
166167

167168
# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +198,9 @@ def __getitem__(self, idx):
197198
def __len__(self):
198199
return len(self._dataset)
199200

201+
def __reduce__(self):
202+
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
203+
200204

201205
def raise_not_supported(description):
202206
raise RuntimeError(

torchvision/datasets/widerface.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,13 @@ def parse_train_val_annotations_file(self) -> None:
137137
{
138138
"img_path": img_path,
139139
"annotations": {
140-
"bbox": labels_tensor[:, 0:4], # x, y, width, height
141-
"blur": labels_tensor[:, 4],
142-
"expression": labels_tensor[:, 5],
143-
"illumination": labels_tensor[:, 6],
144-
"occlusion": labels_tensor[:, 7],
145-
"pose": labels_tensor[:, 8],
146-
"invalid": labels_tensor[:, 9],
140+
"bbox": labels_tensor[:, 0:4].clone(), # x, y, width, height
141+
"blur": labels_tensor[:, 4].clone(),
142+
"expression": labels_tensor[:, 5].clone(),
143+
"illumination": labels_tensor[:, 6].clone(),
144+
"occlusion": labels_tensor[:, 7].clone(),
145+
"pose": labels_tensor[:, 8].clone(),
146+
"invalid": labels_tensor[:, 9].clone(),
147147
},
148148
}
149149
)

0 commit comments

Comments (0)