Skip to content

Commit ae375a6

Browse files
committed
Enable data lazy loading for CUB200
1 parent 1d6a259 commit ae375a6

File tree

1 file changed

+7
-5
lines changed
  • torchvision/prototype/datasets/_builtin

1 file changed

+7
-5
lines changed

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def _info() -> Dict[str, Any]:
3939
return dict(categories=read_categories_file(NAME))
4040

4141

42+
def image_filename_fn(rel_posix_path: int) -> int:
43+
return rel_posix_path.rsplit("/", maxsplit=1)[1]
44+
45+
4246
@register_dataset(NAME)
4347
class CUB200(Dataset):
4448
"""
@@ -185,17 +189,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
185189
)
186190

187191
image_files_dp = CSVParser(image_files_dp, dialect="cub200")
188-
image_files_map = dict(
189-
(image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp
190-
)
192+
image_files_map = image_files_dp.map(image_filename_fn, input_col=1).to_map_datapipe()
191193

192194
split_dp = CSVParser(split_dp, dialect="cub200")
193195
split_dp = Filter(split_dp, self._2011_filter_split)
194196
split_dp = Mapper(split_dp, getitem(0))
195-
split_dp = Mapper(split_dp, image_files_map.get)
197+
split_dp = Mapper(split_dp, image_files_map.__getitem__)
196198

197199
bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
198-
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0)
200+
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)
199201

200202
anns_dp = IterKeyZipper(
201203
bounding_boxes_dp,

0 commit comments

Comments
 (0)