|
1 | 1 | import enum
|
2 |
| -import functools |
3 | 2 | import pathlib
|
4 | 3 | import re
|
5 | 4 | from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union
|
@@ -126,8 +125,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl
|
126 | 125 | if num_children == 0
|
127 | 126 | ]
|
128 | 127 |
|
129 |
| - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str: |
130 |
| - return wnids[int(imagenet_label) - 1] |
| 128 | + def _imagenet_label_id_fn(self, imagenet_label: str) -> int: |
| 129 | + return int(imagenet_label) - 1 |
131 | 130 |
|
132 | 131 | _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
|
133 | 132 |
|
@@ -173,11 +172,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
|
173 | 172 | )
|
174 | 173 |
|
175 | 174 | meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
|
176 |
| - _, wnids = zip(*next(iter(meta_dp))) |
| 175 | + wnid_dp = meta_dp.flatmap().map(getitem(1)).enumerate().to_map_datapipe() |
177 | 176 |
|
178 | 177 | label_dp = LineReader(label_dp, decode=True, return_path=False)
|
179 | 178 | # We cannot use self._wnids here, since we use a different order than the dataset
|
180 |
| - label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) |
| 179 | + label_dp = label_dp.zip_with_map(wnid_dp, key_fn=self._imagenet_label_id_fn).map(getitem(1)) |
181 | 180 | label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
|
182 | 181 | label_dp = hint_shuffling(label_dp)
|
183 | 182 | label_dp = hint_sharding(label_dp)
|
|
0 commit comments