
Commit 4d00ae0

add download functionality to prototype datasets (#5035)
* add download functionality to prototype datasets
* fix annotation
* fix test
* remove iopath
* add comments
1 parent 4282c9f commit 4d00ae0

File tree

16 files changed: +225 −105 lines changed

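Before the per-file diffs, a note on the common thread: OnlineResource.to_datapipe() becomes OnlineResource.load(root), subclasses provide a _download() hook, and the explicit TarArchiveReader/ZipArchiveReader steps disappear from the individual datasets. The utils diff itself is not part of this excerpt, so the following is a rough sketch of the interface inferred from its usage below, not the actual torchvision code:

    import hashlib
    import pathlib


    class OnlineResource:  # hypothetical reconstruction, inferred from the diffs below
        def __init__(self, *, file_name, sha256=None):
            self.file_name = file_name
            self.sha256 = sha256

        def _download(self, root):
            # subclass hook: an HTTP resource would fetch its URL here
            raise NotImplementedError

        def load(self, root):
            path = pathlib.Path(root) / self.file_name
            if not path.exists():
                self._download(root)  # the new download-on-demand behavior
            if self.sha256 is not None:
                assert hashlib.sha256(path.read_bytes()).hexdigest() == self.sha256
            ...  # return a datapipe over `path` with the matching archive reader applied

The last step would explain why the datasets below no longer wrap their resource datapipes in archive readers themselves.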

test/builtin_dataset_mocks.py

Lines changed: 24 additions & 15 deletions
@@ -29,6 +29,19 @@
 DEFAULT_TEST_DECODER = object()


+class TestResource(datasets.utils.OnlineResource):
+    def __init__(self, *, dataset_name, dataset_config, **kwargs):
+        super().__init__(**kwargs)
+        self.dataset_name = dataset_name
+        self.dataset_config = dataset_config
+
+    def _download(self, _):
+        raise pytest.UsageError(
+            f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
+            f"but this file does not exist."
+        )
+
+
 class DatasetMocks:
     def __init__(self):
         self._mock_data_fns = {}
@@ -72,7 +85,7 @@ def _parse_mock_info(self, mock_info, *, name):
         )
         return mock_info

-    def _get(self, dataset, config):
+    def _get(self, dataset, config, root):
         name = dataset.info.name
         resources_and_mock_info = self._cache.get((name, config))
         if resources_and_mock_info:
@@ -87,20 +100,12 @@ def _get(self, dataset, config):
                 f"Did you register the mock data function with `@DatasetMocks.register_mock_data_fn`?"
             )

-        root = self._tmp_home / name
-        root.mkdir(exist_ok=True)
+        mock_resources = [
+            TestResource(dataset_name=name, dataset_config=config, file_name=resource.file_name)
+            for resource in dataset.resources(config)
+        ]
         mock_info = self._parse_mock_info(fakedata_fn(dataset.info, root, config), name=name)

-        mock_resources = []
-        for resource in dataset.resources(config):
-            path = root / resource.file_name
-            if not path.exists() and path.is_file():
-                raise pytest.UsageError(
-                    f"Dataset '{name}' requires the file {path.name} for {config}, but this file does not exist."
-                )
-
-            mock_resources.append(datasets.utils.LocalResource(path))
-
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info

@@ -109,9 +114,13 @@ def load(
     ) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
         config = dataset.info.make_config(split=split, **options)
-        resources, mock_info = self._get(dataset, config)
+
+        root = self._tmp_home / name
+        root.mkdir(exist_ok=True)
+        resources, mock_info = self._get(dataset, config, root)
+
         datapipe = dataset._make_datapipe(
-            [resource.to_datapipe() for resource in resources],
+            [resource.load(root) for resource in resources],
             config=config,
             decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
         )
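Two things follow from the TestResource class at the top of this file: the mock-data function now writes its files into root before load(root) runs, so resources resolve locally and _download() is only reached when a file is genuinely missing, at which point the pytest.UsageError above fires. This also retires a latent bug in the deleted eager check, whose condition could never be true for a missing file. A minimal stdlib-only demonstration:

    import pathlib

    # The removed check used `not path.exists() and path.is_file()`. A path that
    # does not exist is never a file, so the conjunction is always False and the
    # UsageError it guarded was unreachable.
    missing = pathlib.Path("/tmp/does-not-exist.tar")
    assert not (not missing.exists() and missing.is_file())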

test/test_prototype_datasets_api.py

Lines changed: 9 additions & 8 deletions
@@ -211,10 +211,10 @@ def test_default_config(self):
             pytest.param(make_minimal_dataset_info().default_config, None, id="default"),
         ],
     )
-    def test_to_datapipe_config(self, config, kwarg):
+    def test_load_config(self, config, kwarg):
         dataset = self.DatasetMock()

-        dataset.to_datapipe("", config=kwarg)
+        dataset.load("", config=kwarg)

         dataset.resources.assert_called_with(config)

@@ -225,18 +225,19 @@ def test_missing_dependencies(self):
         dependency = "fake_dependency"
         dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,)))
         with pytest.raises(ModuleNotFoundError, match=dependency):
-            dataset.to_datapipe("root")
+            dataset.load("root")

     def test_resources(self, mocker):
-        resource_mock = mocker.Mock(spec=["to_datapipe"])
+        resource_mock = mocker.Mock(spec=["load"])
         sentinel = object()
-        resource_mock.to_datapipe.return_value = sentinel
+        resource_mock.load.return_value = sentinel
         dataset = self.DatasetMock(resources=[resource_mock])

         root = "root"
-        dataset.to_datapipe(root)
+        dataset.load(root)

-        resource_mock.to_datapipe.assert_called_with(root)
+        (call_args, _) = resource_mock.load.call_args
+        assert call_args[0] == root

         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel
@@ -245,7 +246,7 @@ def test_decoder(self):
         dataset = self.DatasetMock()

         sentinel = object()
-        dataset.to_datapipe("", decoder=sentinel)
+        dataset.load("", decoder=sentinel)

         (_, call_kwargs) = dataset._make_datapipe.call_args
         assert call_kwargs["decoder"] is sentinel
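The switch from assert_called_with(root) to unpacking call_args presumably keeps this test agnostic to whatever keyword arguments Dataset.load now forwards to resource.load; only the positional root is pinned. A self-contained illustration with a stand-in mock (not the test's fixture; the extra kwarg is hypothetical):

    from unittest import mock

    resource = mock.Mock(spec=["load"])
    resource.load("root", skip_integrity_check=True)  # extra kwarg load() might pass

    (call_args, _) = resource.load.call_args
    assert call_args[0] == "root"  # passes; assert_called_with("root") would not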

torchvision/prototype/datasets/_api.py

Lines changed: 3 additions & 3 deletions
@@ -61,16 +61,16 @@ def load(
     name: str,
     *,
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
+    skip_integrity_check: bool = False,
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
-    name = name.lower()
     dataset = find(name)

     if decoder is DEFAULT_DECODER:
         decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)

     config = dataset.info.make_config(split=split, **options)
-    root = os.path.join(home(), name)
+    root = os.path.join(home(), dataset.name)

-    return dataset.to_datapipe(root, config=config, decoder=decoder)
+    return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
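A minimal usage sketch of the reworked entry point, assuming the package re-exports load() under torchvision.prototype.datasets and that "cifar10" resolves to the builtin touched below; that skip_integrity_check bypasses the sha256 comparison is inferred from the name:

    from torchvision.prototype import datasets

    # First use should exercise the new download functionality for missing resources.
    dp = datasets.load("cifar10", split="train", skip_integrity_check=True)
    sample = next(iter(dp))

Note the dropped name = name.lower(): the storage directory now comes from dataset.name, so it follows the canonical dataset name rather than the caller's casing (presumably find() handles the normalization).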

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 3 additions & 8 deletions
@@ -8,7 +8,6 @@
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    TarArchiveReader,
     Shuffler,
     Filter,
     IterKeyZipper,
@@ -38,6 +37,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         images = HttpResource(
             "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
             sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
+            decompress=True,
         )
         anns = HttpResource(
             "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
@@ -119,11 +119,9 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps

-        images_dp = TarArchiveReader(images_dp)
         images_dp = Filter(images_dp, self._is_not_background_image)
         images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)

-        anns_dp = TarArchiveReader(anns_dp)
         anns_dp = Filter(anns_dp, self._is_ann)

         dp = IterKeyZipper(
@@ -137,8 +135,7 @@ def _make_datapipe(
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dp = Filter(dp, self._is_not_background_image)
         return sorted({pathlib.Path(path).parent.name for path, _ in dp})

@@ -185,13 +182,11 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = TarArchiveReader(dp)
         dp = Filter(dp, self._is_not_rogue_file)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
         return [name.split(".")[1] for name in sorted(dir_names)]
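The new decompress=True flag pairs with the dropped TarArchiveReader calls: by the look of it, HttpResource strips the outer .gz after download so that load() can hand the dataset an already-iterable .tar stream. A toy stdlib equivalent of that presumed post-download step (the real resource logic is outside this excerpt):

    import gzip
    import pathlib
    import shutil

    def decompress_in_place(path: pathlib.Path) -> pathlib.Path:
        """Strip a trailing .gz, e.g. 101_ObjectCategories.tar.gz -> .tar."""
        target = path.with_suffix("")  # drops only the final ".gz"
        with gzip.open(path, "rb") as src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)
        return target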

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 0 additions & 3 deletions
@@ -8,7 +8,6 @@
     Mapper,
     Shuffler,
     Filter,
-    ZipArchiveReader,
     Zipper,
     IterKeyZipper,
 )
@@ -154,8 +153,6 @@ def _make_datapipe(
         splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
         splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

-        images_dp = ZipArchiveReader(images_dp)
-
         anns_dp = Zipper(
             *[
                 CelebACSVParser(dp, fieldnames=fieldnames)

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 1 addition & 4 deletions
@@ -11,7 +11,6 @@
     IterDataPipe,
     Filter,
     Mapper,
-    TarArchiveReader,
     Shuffler,
 )
 from torchvision.prototype.datasets.decoder import raw
@@ -85,16 +84,14 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = TarArchiveReader(dp)
         dp = Filter(dp, functools.partial(self._is_data_file, config=config))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp = Mapper(dp, self._unpickle)
         return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 1 addition & 6 deletions
@@ -11,7 +11,6 @@
     Shuffler,
     Filter,
     Demultiplexer,
-    ZipArchiveReader,
     Grouper,
     IterKeyZipper,
     JsonParser,
@@ -180,13 +179,10 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps

-        images_dp = ZipArchiveReader(images_dp)
-
         if config.annotations is None:
             dp = Shuffler(images_dp)
             return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))

-        meta_dp = ZipArchiveReader(meta_dp)
         meta_dp = Filter(
             meta_dp,
             self._filter_meta_files,
@@ -234,8 +230,7 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
         config = self.default_config
         resources = self.resources(config)

-        dp = resources[1].to_datapipe(pathlib.Path(root) / self.name)
-        dp = ZipArchiveReader(dp)
+        dp = resources[1].load(pathlib.Path(root) / self.name)
         dp = Filter(
             dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
         )

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 10 additions & 9 deletions
@@ -9,8 +9,8 @@
     Dataset,
     DatasetConfig,
     DatasetInfo,
-    HttpResource,
     OnlineResource,
+    ManualDownloadResource,
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
@@ -25,6 +25,11 @@
 from torchvision.prototype.utils._internal import FrozenMapping


+class ImageNetResource(ManualDownloadResource):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
+
+
 class ImageNetLabel(Label):
     wnid: Optional[str]

@@ -81,10 +86,10 @@ def wnid_to_category(self) -> Dict[str, str]:

     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         name = "test_v10102019" if config.split == "test" else config.split
-        images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
+        images = ImageNetResource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])

-        devkit = HttpResource(
-            "ILSVRC2012_devkit_t12.tar.gz",
+        devkit = ImageNetResource(
+            file_name="ILSVRC2012_devkit_t12.tar.gz",
             sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
         )

@@ -139,15 +144,12 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, devkit_dp = resource_dps

-        images_dp = TarArchiveReader(images_dp)
-
         if config.split == "train":
             # the train archive is a tar of tars
             dp = TarArchiveReader(images_dp)
             dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
             dp = Mapper(dp, self._collate_train_data)
         elif config.split == "val":
-            devkit_dp = TarArchiveReader(devkit_dp)
             devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
             devkit_dp = LineReader(devkit_dp, return_path=False)
             devkit_dp = Mapper(devkit_dp, int)
@@ -177,8 +179,7 @@ def _make_datapipe(

     def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
         resources = self.resources(self.default_config)
-        devkit_dp = resources[1].to_datapipe(root / self.name)
-        devkit_dp = TarArchiveReader(devkit_dp)
+        devkit_dp = resources[1].load(root / self.name)
         devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))

         meta = next(iter(devkit_dp))[1]
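ImageNet switches from HttpResource to ManualDownloadResource because the archives cannot be fetched anonymously; the class itself is added in the utils module, outside this excerpt. A hypothetical reconstruction in the spirit of the OnlineResource sketch above (the error type and wording are guesses):

    class ManualDownloadResource(OnlineResource):  # hypothetical reconstruction
        def __init__(self, instructions, **kwargs):
            super().__init__(**kwargs)
            self.instructions = instructions

        def _download(self, root):
            raise RuntimeError(
                f"{self.file_name} cannot be downloaded automatically. "
                f"{self.instructions} Then place the file in {root}."
            )

Also note that only the outer TarArchiveReader calls were removed here: the train split keeps an inner one because, per the retained comment, the train archive is a tar of tars, and load() presumably unpacks just the outer layer.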

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 0 additions & 2 deletions
@@ -11,7 +11,6 @@
     IterDataPipe,
     Demultiplexer,
     Mapper,
-    ZipArchiveReader,
     Zipper,
     Shuffler,
 )
@@ -310,7 +309,6 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
-        archive_dp = ZipArchiveReader(archive_dp)
         images_dp, labels_dp = Demultiplexer(
             archive_dp,
             2,

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 1 addition & 4 deletions
@@ -8,7 +8,6 @@
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    TarArchiveReader,
     Shuffler,
     Demultiplexer,
     Filter,
@@ -129,7 +128,6 @@ def _make_datapipe(
         archive_dp, extra_split_dp = resource_dps

         archive_dp = resource_dps[0]
-        archive_dp = TarArchiveReader(archive_dp)
         split_dp, images_dp, anns_dp = Demultiplexer(
             archive_dp,
             3,
@@ -155,8 +153,7 @@ def _make_datapipe(
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dp = Filter(dp, path_comparator("name", "category_names.m"))
         dp = LineReader(dp)
         dp = Mapper(dp, bytes.decode, input_col=1)
