From 049c0466dfb9edae046b447f86ecb4957468c9d6 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 31 Dec 2021 08:40:37 +0100
Subject: [PATCH 1/2] add SVHN prototype dataset

---
 .../prototype/datasets/_builtin/__init__.py | 1 +
 .../prototype/datasets/_builtin/svhn.py | 96 +++++++++++++++++++
 2 files changed, 97 insertions(+)
 create mode 100644 torchvision/prototype/datasets/_builtin/svhn.py

diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 62abc3119f6..a694d03051f 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -6,4 +6,5 @@
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .sbd import SBD
 from .semeion import SEMEION
+from .svhn import SVHN
 from .voc import VOC
diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py
new file mode 100644
index 00000000000..7f9c019e92e
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/svhn.py
@@ -0,0 +1,96 @@
+import functools
+import io
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchdata.datapipes.iter import (
+    IterDataPipe,
+    Mapper,
+    UnBatcher,
+)
+from torchvision.prototype.datasets.decoder import raw
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    read_mat,
+    hint_sharding,
+    hint_shuffling,
+    image_buffer_from_array,
+)
+from torchvision.prototype.features import Label, Image
+
+
+class SVHN(Dataset):
+    """SVHN (Street View House Numbers) prototype dataset.
+
+    Reads the cropped-digits ``{split}_32x32.mat`` files; every sample is a
+    32x32 RGB image of a single digit together with its 0-9 label.
+    """
+
+    def _make_info(self) -> DatasetInfo:
+        # scipy is declared as a dependency because the data ships as
+        # MATLAB .mat files (parsed via read_mat below).
+        return DatasetInfo(
+            "svhn",
+            type=DatasetType.RAW,
+            dependencies=("scipy",),
+            categories=10,
+            homepage="http://ufldl.stanford.edu/housenumbers/",
+            valid_options=dict(split=("train", "test", "extra")),
+        )
+
+    # SHA256 checksums of the downloaded .mat files, keyed by split name.
+    _CHECKSUMS = {
+        "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
+        "test": "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b",
+        "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
+    }
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        """Return the single .mat archive for the configured split."""
+        data = HttpResource(
+            f"http://ufldl.stanford.edu/housenumbers/{config.split}_32x32.mat",
+            sha256=self._CHECKSUMS[config.split],
+        )
+
+        return [data]
+
+    def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """Parse the .mat buffer into (HWC image array, label scalar) pairs.
+
+        ``X`` is stored as (H, W, C, N); the transpose moves the sample axis
+        to the front so zip() pairs each image with the squeezed ``y`` entry.
+        """
+        _, buffer = data
+        content = read_mat(buffer)
+        return list(
+            zip(
+                content["X"].transpose((3, 0, 1, 2)),
+                content["y"].squeeze(),
+            )
+        )
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[np.ndarray, np.ndarray],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        """Turn one (image array, label) pair into the output sample dict."""
+        image_array, label_array = data
+
+        if decoder is raw:
+            # Raw decoding: hand the pixels out directly, HWC -> CHW.
+            image = Image(image_array.transpose((2, 0, 1)))
+        else:
+            # Otherwise re-encode the array into an image file buffer and let
+            # the supplied decoder (if any) consume it.
+            image_buffer = image_buffer_from_array(image_array)
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
+
+        return dict(
+            image=image,
+            # SVHN's .mat files label the digit "0" as class 10 (dataset
+            # convention), so % 10 remaps labels into the range 0-9.
+            label=Label(int(label_array) % 10),
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        """Build the sample pipeline for one split."""
+        dp = resource_dps[0]
+        # The whole split lives in a single .mat file, so parse it once and
+        # then unbatch the resulting list into per-sample elements.
+        dp = Mapper(dp, self._read_images_and_labels)
+        dp = UnBatcher(dp)
+        dp = hint_sharding(dp)
+        dp = hint_shuffling(dp)
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
From 64a28b023bd7b722639a880350ce44ba7bec7591 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 20 Jan 2022 14:26:12 +0100
Subject: [PATCH 2/2] add test

---
 test/builtin_dataset_mocks.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index fc980326307..2de4c63498d
100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1298,3 +1298,23 @@ def generate(cls, root):
 def cub200(info, root, config):
     num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
     return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
+
+
+@DATASET_MOCKS.set_from_named_callable
+def svhn(info, root, config):
+    """Write a tiny fake SVHN .mat file into root and return its sample count."""
+    import scipy.io as sio
+
+    # Different counts per split so tests can tell the splits apart.
+    num_samples = {
+        "train": 2,
+        "test": 3,
+        "extra": 4,
+    }[config.split]
+
+    # Mirror the real on-disk layout consumed by the SVHN dataset:
+    # X is (H, W, C, N) uint8 pixels, y is the (N,) label vector.
+    sio.savemat(
+        root / f"{config.split}_32x32.mat",
+        {
+            "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
+            "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
+        },
+    )
+    return num_samples