@@ -126,8 +126,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl
126
126
if num_children == 0
127
127
]
128
128
129
- def _imagenet_label_to_wnid (self , imagenet_label : str , * , wnids : Tuple [ str , ...] ) -> str :
130
- return wnids [ int (imagenet_label ) - 1 ]
129
+ def _imagenet_label_id_fn (self , imagenet_label : str ) -> int :
130
+ return int (imagenet_label ) - 1
131
131
132
132
_VAL_TEST_IMAGE_NAME_PATTERN = re .compile (r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG" )
133
133
@@ -173,11 +173,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
173
173
)
174
174
175
175
meta_dp = Mapper (meta_dp , self ._extract_categories_and_wnids )
176
- _ , wnids = zip ( * next ( iter ( meta_dp )) )
176
+ wnid_dp = meta_dp . flatmap (). map ( getitem ( 1 )). enumerate (). to_map_datapipe ( )
177
177
178
178
label_dp = LineReader (label_dp , decode = True , return_path = False )
179
179
# We cannot use self._wnids here, since we use a different order than the dataset
180
- label_dp = Mapper ( label_dp , functools . partial ( self . _imagenet_label_to_wnid , wnids = wnids ))
180
+ label_dp = label_dp . zip_with_map ( wnid_dp , key_fn = self . _imagenet_label_id_fn ). map ( getitem ( 1 ))
181
181
label_dp : IterDataPipe [Tuple [int , str ]] = Enumerator (label_dp , 1 )
182
182
label_dp = hint_shuffling (label_dp )
183
183
label_dp = hint_sharding (label_dp )
0 commit comments