From 57ab9b503c4ec38e4c6a30dea214ef96f3acfced Mon Sep 17 00:00:00 2001 From: erjia Date: Mon, 12 Sep 2022 18:38:15 +0000 Subject: [PATCH 1/2] Enable data lazy loading for CUB200 --- torchvision/prototype/datasets/_builtin/cub200.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index bb3f712c59d..6bd201609c6 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -39,6 +39,10 @@ def _info() -> Dict[str, Any]: return dict(categories=read_categories_file(NAME)) +def image_filename_fn(rel_posix_path: str) -> str: + return rel_posix_path.rsplit("/", maxsplit=1)[1] + + @register_dataset(NAME) class CUB200(Dataset): """ @@ -185,17 +189,15 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, ) image_files_dp = CSVParser(image_files_dp, dialect="cub200") - image_files_map = dict( - (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp - ) + image_files_map = image_files_dp.map(image_filename_fn, input_col=1).to_map_datapipe() split_dp = CSVParser(split_dp, dialect="cub200") split_dp = Filter(split_dp, self._2011_filter_split) split_dp = Mapper(split_dp, getitem(0)) - split_dp = Mapper(split_dp, image_files_map.get) + split_dp = Mapper(split_dp, image_files_map.__getitem__) bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200") - bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0) + bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0) anns_dp = IterKeyZipper( bounding_boxes_dp, From 88d4fca45d3614699e736dce6c7d21283a085eae Mon Sep 17 00:00:00 2001 From: erjia Date: Mon, 12 Sep 2022 20:02:52 +0000 Subject: [PATCH 2/2] Enable data lazy loading for imagenet --- torchvision/prototype/datasets/_builtin/imagenet.py | 9 ++++----- 1 
file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 062e240a8b8..07eeb869c73 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,5 +1,4 @@ import enum -import functools import pathlib import re from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union @@ -126,8 +125,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl if num_children == 0 ] - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str: - return wnids[int(imagenet_label) - 1] + def _imagenet_label_id_fn(self, imagenet_label: str) -> int: + return int(imagenet_label) - 1 _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") @@ -173,11 +172,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, ) meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) - _, wnids = zip(*next(iter(meta_dp))) + wnid_dp = meta_dp.flatmap().map(getitem(1)).enumerate().to_map_datapipe() label_dp = LineReader(label_dp, decode=True, return_path=False) # We cannot use self._wnids here, since we use a different order than the dataset - label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) + label_dp = label_dp.zip_with_map(wnid_dp, key_fn=self._imagenet_label_id_fn).map(getitem(1)) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp = hint_shuffling(label_dp) label_dp = hint_sharding(label_dp)