
[proto] Enable lazy loading for the data pipeline of CUB200 and Imagenet #6569

Status: Closed · wants to merge 2 commits

12 changes: 7 additions & 5 deletions torchvision/prototype/datasets/_builtin/cub200.py
@@ -39,6 +39,10 @@ def _info() -> Dict[str, Any]:
     return dict(categories=read_categories_file(NAME))


+def image_filename_fn(rel_posix_path: str) -> str:
+    return rel_posix_path.rsplit("/", maxsplit=1)[1]
+
+
 @register_dataset(NAME)
 class CUB200(Dataset):
     """
@@ -185,17 +189,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
         )

         image_files_dp = CSVParser(image_files_dp, dialect="cub200")
-        image_files_map = dict(
-            (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp
-        )
+        image_files_map = image_files_dp.map(image_filename_fn, input_col=1).to_map_datapipe()

         split_dp = CSVParser(split_dp, dialect="cub200")
         split_dp = Filter(split_dp, self._2011_filter_split)
         split_dp = Mapper(split_dp, getitem(0))
-        split_dp = Mapper(split_dp, image_files_map.get)
+        split_dp = Mapper(split_dp, image_files_map.__getitem__)

         bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
-        bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0)
+        bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)

         anns_dp = IterKeyZipper(
             bounding_boxes_dp,
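
For context, a minimal sketch (not part of the PR, with an invented sample row) of the pattern the new CUB200 code relies on: `map(..., input_col=1)` rewrites only the path column of each `(image_id, rel_posix_path)` pair, and `to_map_datapipe()` wraps the pairs in a `MapDataPipe` whose backing dict is built on first lookup rather than eagerly at pipeline-construction time, which replaces the old eager dict comprehension.

# Hedged sketch; the sample row is invented, only the torchdata calls
# mirror the diff above.
from torchdata.datapipes.iter import IterableWrapper

# Rows shaped like CUB200's parsed images.txt: (image_id, rel_posix_path)
rows = IterableWrapper([("1", "001.Black_footed_Albatross/img_0001.jpg")])

# input_col=1 applies the function to the second column only, preserving the
# (key, value) shape that to_map_datapipe() expects. Nothing is evaluated yet;
# the underlying dict is materialized on the first lookup.
files_map = rows.map(lambda p: p.rsplit("/", maxsplit=1)[1], input_col=1).to_map_datapipe()

print(files_map["1"])  # img_0001.jpg

Because the resulting `MapDataPipe` raises on missing keys, the diff swaps `image_files_map.get` for `image_files_map.__getitem__` in the downstream `Mapper`s.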
9 changes: 4 additions & 5 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,5 +1,4 @@
 import enum
-import functools
 import pathlib
 import re
 from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union
@@ -126,8 +125,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl
             if num_children == 0
         ]

-    def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
-        return wnids[int(imagenet_label) - 1]
+    def _imagenet_label_id_fn(self, imagenet_label: str) -> int:
+        return int(imagenet_label) - 1

     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")

@@ -173,11 +172,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
         )

         meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
-        _, wnids = zip(*next(iter(meta_dp)))
+        wnid_dp = meta_dp.flatmap().map(getitem(1)).enumerate().to_map_datapipe()

         label_dp = LineReader(label_dp, decode=True, return_path=False)
         # We cannot use self._wnids here, since we use a different order than the dataset
-        label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
+        label_dp = label_dp.zip_with_map(wnid_dp, key_fn=self._imagenet_label_id_fn).map(getitem(1))
         label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
         label_dp = hint_shuffling(label_dp)
         label_dp = hint_sharding(label_dp)
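
A corresponding sketch (again not from the PR; the wnids and labels are invented) of the `zip_with_map` lookup that replaces the eager `wnids` tuple: each 1-based devkit label is mapped through `key_fn` to an index into the lazily built wnid `MapDataPipe`, and the trailing `map(getitem(1))` in the diff keeps only the looked-up wnid.

# Hedged sketch of the zip_with_map label lookup; wnids and labels are invented.
from torchdata.datapipes.iter import IterableWrapper

# index -> wnid MapDataPipe, standing in for
# meta_dp.flatmap().map(getitem(1)).enumerate().to_map_datapipe()
wnid_dp = IterableWrapper(["n01440764", "n01443537"]).enumerate().to_map_datapipe()

# The devkit labels are 1-based, hence the `- 1` in _imagenet_label_id_fn.
labels = IterableWrapper(["1", "2", "1"])
label_dp = labels.zip_with_map(wnid_dp, key_fn=lambda label: int(label) - 1)

# zip_with_map yields (item, map_value) pairs; the diff's .map(getitem(1))
# then keeps only the wnid.
print(list(label_dp))  # [('1', 'n01440764'), ('2', 'n01443537'), ('1', 'n01440764')]

Like the dict it replaces, the map is only materialized on the first lookup, so no data is read while the datapipe graph is being assembled.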