diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index bebeaccaadd..09d6238565f 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -3,6 +3,7 @@
 import pytest
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
+from torch.utils.data.dataloader_experimental import DataLoader2
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
@@ -80,27 +81,13 @@ def test_transformable(self, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
-    @parametrize_dataset_mocks(
-        DATASET_MOCKS,
-        marks={
-            "cub200": pytest.mark.xfail(
-                reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
-            )
-        },
-    )
+    @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_traversable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
         traverse(dataset)
 
-    @parametrize_dataset_mocks(
-        DATASET_MOCKS,
-        marks={
-            "cub200": pytest.mark.xfail(
-                reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
-            )
-        },
-    )
+    @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
     def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
         def scan(graph):
@@ -116,6 +103,15 @@ def scan(graph):
         else:
             raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_multi_epoch(self, dataset_mock, config):
+        dataset, _ = dataset_mock.load(config)
+        data_loader = DataLoader2(dataset)
+
+        for epoch in range(2):
+            for _ in data_loader:
+                pass
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py
index facd909f468..d0a1c2b5ad3 100644
--- a/torchvision/prototype/datasets/_builtin/cub200.py
+++ b/torchvision/prototype/datasets/_builtin/cub200.py
@@ -31,6 +31,7 @@
     getitem,
     path_comparator,
     path_accessor,
+    LazyDict,
 )
 from torchvision.prototype.features import Label, BoundingBox, Feature
 
@@ -94,6 +95,9 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
+    def _2011_image_key(self, rel_posix_path: str) -> str:
+        return rel_posix_path.rsplit("/", 1)[1]
+
     def _2011_filter_split(self, row: List[str], *, split: str) -> bool:
         _, split_id = row
         return {
@@ -173,9 +177,8 @@ def _make_datapipe(
             )
 
             image_files_dp = CSVParser(image_files_dp, dialect="cub200")
-            image_files_map = dict(
-                (image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp
-            )
+            image_files_dp = Mapper(image_files_dp, self._2011_image_key, input_col=1)
+            image_files_map = LazyDict(image_files_dp)
 
             split_dp = CSVParser(split_dp, dialect="cub200")
             split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split))
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index e21e8ffd25f..6641b798239 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -9,6 +9,7 @@
 import pathlib
 import pickle
 import platform
+from collections import UserDict
 from typing import BinaryIO
 from typing import (
     Sequence,
@@ -49,6 +50,7 @@
     "fromfile",
     "read_flo",
     "hint_sharding",
+    "LazyDict",
 ]
 
 K = TypeVar("K")
@@ -345,3 +347,21 @@ def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
 
 def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
     return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)
+
+
+class LazyDict(UserDict):
+    def __init__(self, datapipe: IterDataPipe[Tuple[K, D]], *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.datapipe = datapipe
+        self.loaded = False
+
+    def load(self) -> None:
+        for key, value in self.datapipe:
+            self.data[key] = value
+        self.loaded = True
+
+    def __getitem__(self, item: K) -> D:
+        if not self.loaded:
+            self.load()
+
+        return self.data[item]
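
A note on the cub200 hunk: it relies on torchdata's `Mapper` column semantics. When `output_col` is left unset, the result of the function replaces the element at `input_col`, so each `(image_id, rel_posix_path)` row coming out of `CSVParser` turns into `(image_id, filename)` one row at a time, without draining the pipe the way the old generator expression did. A minimal sketch of that behavior (the sample row is made up):

```python
from torchdata.datapipes.iter import IterableWrapper, Mapper

# Made-up row shaped like a parsed images.txt entry: (image_id, rel_posix_path).
rows = IterableWrapper([("1", "001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg")])

# With output_col unset, the mapped value replaces the element at input_col,
# so each row becomes (image_id, filename) lazily.
dp = Mapper(rows, lambda path: path.rsplit("/", 1)[1], input_col=1)
assert list(dp) == [("1", "Black_Footed_Albatross_0046_18.jpg")]
```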
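`LazyDict` is what lifts the cub200 `xfail` marks: the old `dict(...)` comprehension consumed `image_files_dp` while `_make_datapipe` was still assembling the graph, whereas the new class holds on to the datapipe and only materializes it into `self.data` on the first `__getitem__`. A quick sketch of the intended behavior, using made-up `(image_id, filename)` pairs:

```python
from torchdata.datapipes.iter import IterableWrapper
from torchvision.prototype.datasets.utils._internal import LazyDict

# Made-up pairs standing in for the mapped images.txt rows.
pairs = IterableWrapper([("1", "Black_Footed_Albatross_0046_18.jpg")])

image_files_map = LazyDict(pairs)
assert not image_files_map.loaded  # constructing the graph reads nothing

# The first lookup drains the datapipe into .data; later lookups are plain dict hits.
assert image_files_map["1"] == "Black_Footed_Albatross_0046_18.jpg"
assert image_files_map.loaded
```

Since no key is touched until iteration actually starts, `traverse()` and repeated epochs over the dataset (exercised by the new `test_multi_epoch` above) no longer hit an exhausted datapipe.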