@@ -39,6 +39,10 @@ def _info() -> Dict[str, Any]:
39
39
return dict (categories = read_categories_file (NAME ))
40
40
41
41
42
+ def image_filename_fn (rel_posix_path : int ) -> int :
43
+ return rel_posix_path .rsplit ("/" , maxsplit = 1 )[1 ]
44
+
45
+
42
46
@register_dataset (NAME )
43
47
class CUB200 (Dataset ):
44
48
"""
@@ -185,17 +189,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
185
189
)
186
190
187
191
image_files_dp = CSVParser (image_files_dp , dialect = "cub200" )
188
- image_files_map = dict (
189
- (image_id , rel_posix_path .rsplit ("/" , maxsplit = 1 )[1 ]) for image_id , rel_posix_path in image_files_dp
190
- )
192
+ image_files_map = image_files_dp .map (image_filename_fn , input_col = 1 ).to_map_datapipe ()
191
193
192
194
split_dp = CSVParser (split_dp , dialect = "cub200" )
193
195
split_dp = Filter (split_dp , self ._2011_filter_split )
194
196
split_dp = Mapper (split_dp , getitem (0 ))
195
- split_dp = Mapper (split_dp , image_files_map .get )
197
+ split_dp = Mapper (split_dp , image_files_map .__getitem__ )
196
198
197
199
bounding_boxes_dp = CSVParser (bounding_boxes_dp , dialect = "cub200" )
198
- bounding_boxes_dp = Mapper (bounding_boxes_dp , image_files_map .get , input_col = 0 )
200
+ bounding_boxes_dp = Mapper (bounding_boxes_dp , image_files_map .__getitem__ , input_col = 0 )
199
201
200
202
anns_dp = IterKeyZipper (
201
203
bounding_boxes_dp ,
0 commit comments