Skip to content

Commit 176f919

Browse files
committed
Enable data lazy loading for imagenet
1 parent ae375a6 commit 176f919

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import enum
2-
import functools
32
import pathlib
43
import re
54
from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union
@@ -126,8 +125,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl
126125
if num_children == 0
127126
]
128127

129-
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
130-
return wnids[int(imagenet_label) - 1]
128+
def _imagenet_label_id_fn(self, imagenet_label: str) -> int:
129+
return int(imagenet_label) - 1
131130

132131
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
133132

@@ -173,11 +172,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
173172
)
174173

175174
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
176-
_, wnids = zip(*next(iter(meta_dp)))
175+
wnid_dp = meta_dp.flatmap().map(getitem(1)).enumerate().to_map_datapipe()
177176

178177
label_dp = LineReader(label_dp, decode=True, return_path=False)
179178
# We cannot use self._wnids here, since we use a different order than the dataset
180-
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
179+
label_dp = label_dp.zip_with_map(wnid_dp, key_fn=self._imagenet_label_id_fn).map(getitem(1))
181180
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
182181
label_dp = hint_shuffling(label_dp)
183182
label_dp = hint_sharding(label_dp)

0 commit comments

Comments
 (0)