Skip to content

Commit 436e2a2

Browse files
Vincent Moensfacebook-github-bot
authored andcommitted
[fbsync] Expand tests for prototype datasets (#5187)
Summary: * refactor prototype datasets tests * skip tests with insufficient third party dependencies * cleanup * add tests for SBD prototype dataset * add tests for SEMEION prototype dataset * add tests for VOC prototype dataset * add tests for CelebA prototype dataset * add tests for DTD prototype dataset * add tests for FER2013 prototype dataset * add tests for CLEVR prototype dataset * add tests for oxford-iiit-pet prototype dataset * enforce tests for new datasets * add missing archive generation for oxford-iiit-pet tests * add tests for CUB200 prototype datasets * fix split generation * add capability to mark parametrization and xfail cub200 traverse tests Reviewed By: datumbox, NicolasHug Differential Revision: D33655253 fbshipit-source-id: 186591f2cb89e864c2d143d6a460449cf4991baa
1 parent c4d3768 commit 436e2a2

File tree

8 files changed

+789
-68
lines changed

8 files changed

+789
-68
lines changed

test/builtin_dataset_mocks.py

Lines changed: 702 additions & 33 deletions
Large diffs are not rendered by default.

test/test_prototype_builtin_datasets.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,28 @@
66
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
77
from torch.utils.data.graph import traverse
88
from torchdata.datapipes.iter import IterDataPipe, Shuffler
9-
from torchvision.prototype import transforms
9+
from torchvision.prototype import transforms, datasets
1010
from torchvision.prototype.utils._internal import sequence_to_str
1111

1212

13-
@parametrize_dataset_mocks(DATASET_MOCKS)
13+
def test_coverage():
14+
untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
15+
if untested_datasets:
16+
raise AssertionError(
17+
f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
18+
f"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. "
19+
f"Please add mock data to `test/builtin_dataset_mocks.py`."
20+
)
21+
22+
1423
class TestCommon:
24+
@parametrize_dataset_mocks(DATASET_MOCKS)
1525
def test_smoke(self, dataset_mock, config):
1626
dataset, _ = dataset_mock.load(config)
1727
if not isinstance(dataset, IterDataPipe):
1828
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
1929

30+
@parametrize_dataset_mocks(DATASET_MOCKS)
2031
def test_sample(self, dataset_mock, config):
2132
dataset, _ = dataset_mock.load(config)
2233

@@ -31,6 +42,7 @@ def test_sample(self, dataset_mock, config):
3142
if not sample:
3243
raise AssertionError("Sample dictionary is empty.")
3344

45+
@parametrize_dataset_mocks(DATASET_MOCKS)
3446
def test_num_samples(self, dataset_mock, config):
3547
dataset, mock_info = dataset_mock.load(config)
3648

@@ -40,6 +52,7 @@ def test_num_samples(self, dataset_mock, config):
4052

4153
assert num_samples == mock_info["num_samples"]
4254

55+
@parametrize_dataset_mocks(DATASET_MOCKS)
4356
def test_decoding(self, dataset_mock, config):
4457
dataset, _ = dataset_mock.load(config)
4558

@@ -50,6 +63,7 @@ def test_decoding(self, dataset_mock, config):
5063
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
5164
)
5265

66+
@parametrize_dataset_mocks(DATASET_MOCKS)
5367
def test_no_vanilla_tensors(self, dataset_mock, config):
5468
dataset, _ = dataset_mock.load(config)
5569

@@ -60,16 +74,33 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
6074
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
6175
)
6276

77+
@parametrize_dataset_mocks(DATASET_MOCKS)
6378
def test_transformable(self, dataset_mock, config):
6479
dataset, _ = dataset_mock.load(config)
6580

6681
next(iter(dataset.map(transforms.Identity())))
6782

83+
@parametrize_dataset_mocks(
84+
DATASET_MOCKS,
85+
marks={
86+
"cub200": pytest.mark.xfail(
87+
reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
88+
)
89+
},
90+
)
6891
def test_traversable(self, dataset_mock, config):
6992
dataset, _ = dataset_mock.load(config)
7093

7194
traverse(dataset)
7295

96+
@parametrize_dataset_mocks(
97+
DATASET_MOCKS,
98+
marks={
99+
"cub200": pytest.mark.xfail(
100+
reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
101+
)
102+
},
103+
)
73104
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
74105
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
75106
def scan(graph):
@@ -86,8 +117,8 @@ def scan(graph):
86117
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
87118

88119

120+
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
89121
class TestQMNIST:
90-
@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"])
91122
def test_extra_label(self, dataset_mock, config):
92123
dataset, _ = dataset_mock.load(config)
93124

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
hint_sharding,
2727
hint_shuffling,
2828
)
29+
from torchvision.prototype.features import Feature, Label, BoundingBox
2930

3031
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
3132

@@ -67,6 +68,7 @@ def _make_info(self) -> DatasetInfo:
6768
"celeba",
6869
type=DatasetType.IMAGE,
6970
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
71+
valid_options=dict(split=("train", "val", "test")),
7072
)
7173

7274
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
@@ -104,7 +106,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
104106

105107
_SPLIT_ID_TO_NAME = {
106108
"0": "train",
107-
"1": "valid",
109+
"1": "val",
108110
"2": "test",
109111
}
110112

@@ -117,22 +119,22 @@ def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[s
117119

118120
def _collate_and_decode_sample(
119121
self,
120-
data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]],
122+
data: Tuple[Tuple[str, Tuple[Tuple[str, Dict[str, Any]], Tuple[str, io.IOBase]]], Tuple[str, Dict[str, Any]]],
121123
*,
122124
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
123125
) -> Dict[str, Any]:
124126
split_and_image_data, ann_data = data
125-
_, _, image_data = split_and_image_data
127+
_, (_, image_data) = split_and_image_data
126128
path, buffer = image_data
127129
_, ann = ann_data
128130

129131
image = decoder(buffer) if decoder else buffer
130132

131-
identity = int(ann["identity"]["identity"])
133+
identity = Label(int(ann["identity"]["identity"]))
132134
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
133-
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
135+
bbox = BoundingBox([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
134136
landmarks = {
135-
landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
137+
landmark: Feature((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
136138
for landmark in {key[:-2] for key in ann["landmarks"].keys()}
137139
}
138140

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
105105
path = pathlib.Path(data[0])
106106
return path.with_suffix(".jpg").name
107107

108-
def _2011_decode_ann(
108+
def _2011_load_ann(
109109
self,
110110
data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]],
111111
*,
@@ -126,7 +126,7 @@ def _2010_anns_key(self, data: Tuple[str, io.IOBase]) -> Tuple[str, Tuple[str, i
126126
path = pathlib.Path(data[0])
127127
return path.with_suffix(".jpg").name, data
128128

129-
def _2010_decode_ann(
129+
def _2010_load_ann(
130130
self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
131131
) -> Dict[str, Any]:
132132
_, (path, buffer) = data
@@ -154,7 +154,7 @@ def _collate_and_decode_sample(
154154
label_str, category = dir_name.split(".")
155155

156156
return dict(
157-
(self._2011_decode_ann if year == "2011" else self._2010_decode_ann)(anns_data, decoder=decoder),
157+
(self._2011_load_ann if year == "2011" else self._2010_load_ann)(anns_data, decoder=decoder),
158158
image=decoder(buffer) if decoder else buffer,
159159
label=Label(int(label_str), category=category),
160160
)
@@ -196,7 +196,7 @@ def _make_datapipe(
196196
else: # config.year == "2010"
197197
split_dp, images_dp, anns_dp = resource_dps
198198

199-
split_dp = Filter(split_dp, path_comparator("stem", config.split))
199+
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
200200
split_dp = LineReader(split_dp, decode=True, return_path=False)
201201
split_dp = Mapper(split_dp, self._2010_split_key)
202202

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import enum
2+
import functools
23
import io
34
import pathlib
45
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -126,7 +127,7 @@ def _make_datapipe(
126127
ref_key_fn=self._image_key_fn,
127128
buffer_size=INFINITE_BUFFER_SIZE,
128129
)
129-
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
130+
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
130131

131132
def _filter_images(self, data: Tuple[str, Any]) -> bool:
132133
return self._classify_archive(data) == DTDDemux.IMAGES

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
hint_sharding,
3232
hint_shuffling,
3333
)
34+
from torchvision.prototype.features import Feature
3435

3536

3637
class SBD(Dataset):
@@ -83,11 +84,11 @@ def _decode_ann(
8384

8485
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
8586
boundaries = (
86-
torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
87+
Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
8788
if decode_boundaries
8889
else None
8990
)
90-
segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None
91+
segmentation = Feature(raw_segmentation) if decode_segmentation else None
9192

9293
return boundaries, segmentation
9394

@@ -140,6 +141,7 @@ def _make_datapipe(
140141

141142
if config.split == "train_noval":
142143
split_dp = extra_split_dp
144+
split_dp = Filter(split_dp, path_comparator("stem", config.split))
143145
split_dp = LineReader(split_dp, decode=True)
144146
split_dp = hint_sharding(split_dp)
145147
split_dp = hint_shuffling(split_dp)

torchvision/prototype/datasets/_builtin/semeion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
DatasetType,
1919
)
2020
from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
21+
from torchvision.prototype.features import Image, Label
2122

2223

2324
class SEMEION(Dataset):
@@ -46,14 +47,13 @@ def _collate_and_decode_sample(
4647
label_data = [int(label) for label in data[256:] if label]
4748

4849
if decoder is raw:
49-
image = image_data.unsqueeze(0)
50+
image = Image(image_data.unsqueeze(0))
5051
else:
5152
image_buffer = image_buffer_from_array(image_data.numpy())
5253
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
5354

54-
label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
55-
category = self.info.categories[label]
56-
return dict(image=image, label=label, category=category)
55+
label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
56+
return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx]))
5757

5858
def _make_datapipe(
5959
self,

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,34 +30,50 @@
3030
hint_sharding,
3131
hint_shuffling,
3232
)
33+
from torchvision.prototype.features import BoundingBox
3334

34-
HERE = pathlib.Path(__file__).parent
35+
36+
class VOCDatasetInfo(DatasetInfo):
37+
def __init__(self, *args: Any, **kwargs: Any):
38+
super().__init__(*args, **kwargs)
39+
self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")
40+
41+
def make_config(self, **options: Any) -> DatasetConfig:
42+
config = super().make_config(**options)
43+
if config.split == "test" and config.year != "2007":
44+
raise ValueError("`split='test'` is only available for `year='2007'`")
45+
46+
return config
3547

3648

3749
class VOC(Dataset):
3850
def _make_info(self) -> DatasetInfo:
39-
return DatasetInfo(
51+
return VOCDatasetInfo(
4052
"voc",
4153
type=DatasetType.IMAGE,
4254
homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
4355
valid_options=dict(
44-
split=("train", "val", "test"),
45-
year=("2012",),
56+
split=("train", "val", "trainval", "test"),
57+
year=("2012", "2007", "2008", "2009", "2010", "2011"),
4658
task=("detection", "segmentation"),
4759
),
4860
)
4961

62+
_TRAIN_VAL_ARCHIVES = {
63+
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
64+
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
65+
"2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"),
66+
"2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"),
67+
"2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"),
68+
"2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"),
69+
}
70+
_TEST_ARCHIVES = {
71+
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
72+
}
73+
5074
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
51-
if config.year == "2012":
52-
if config.split == "train":
53-
archive = HttpResource(
54-
"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
55-
sha256="e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb",
56-
)
57-
else:
58-
raise RuntimeError("FIXME")
59-
else:
60-
raise RuntimeError("FIXME")
75+
file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year]
76+
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256)
6177
return [archive]
6278

6379
_ANNS_FOLDER = dict(
@@ -88,7 +104,7 @@ def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
88104
objects = result["annotation"]["object"]
89105
bboxes = [obj["bndbox"] for obj in objects]
90106
bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
91-
return torch.tensor(bboxes)
107+
return BoundingBox(bboxes)
92108

93109
def _collate_and_decode_sample(
94110
self,

0 commit comments

Comments
 (0)