Skip to content

Commit 654ee06

Browse files
committed
Enable data lazy loading for imagenet
1 parent 3d2a2f9 commit 654ee06

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -126,8 +126,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl
126 126
if num_children == 0
127 127
]
128 128

129-
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
130-
return wnids[int(imagenet_label) - 1]
129+
def _imagenet_label_id_fn(self, imagenet_label: str) -> int:
130+
return int(imagenet_label) - 1
131 131

132 132
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
133 133

@@ -173,11 +173,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
173 173
)
174 174

175 175
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
176-
_, wnids = zip(*next(iter(meta_dp)))
176+
wnid_dp = meta_dp.flatmap().map(getitem(1)).enumerate().to_map_datapipe()
177 177

178 178
label_dp = LineReader(label_dp, decode=True, return_path=False)
179179
# We cannot use self._wnids here, since we use a different order than the dataset
180-
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
180+
label_dp = label_dp.zip_with_map(wnid_dp, key_fn=self._imagenet_label_id_fn).map(getitem(1))
181 181
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
182 182
label_dp = hint_shuffling(label_dp)
183 183
label_dp = hint_sharding(label_dp)

0 commit comments

Comments
 (0)