From b351430718be6a90d53d150d2d3af6619868cf7d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 17 Dec 2021 19:46:54 +0100 Subject: [PATCH 1/8] add DTD as prototype dataset --- .../prototype/datasets/_builtin/__init__.py | 1 + .../datasets/_builtin/dtd.categories | 47 +++++++ .../prototype/datasets/_builtin/dtd.py | 130 ++++++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 torchvision/prototype/datasets/_builtin/dtd.categories create mode 100644 torchvision/prototype/datasets/_builtin/dtd.py diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 62abc3119f6..7e5fd788466 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -2,6 +2,7 @@ from .celeba import CelebA from .cifar import Cifar10, Cifar100 from .coco import Coco +from .dtd import DTD from .imagenet import ImageNet from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST from .sbd import SBD diff --git a/torchvision/prototype/datasets/_builtin/dtd.categories b/torchvision/prototype/datasets/_builtin/dtd.categories new file mode 100644 index 00000000000..7f3df8a2b00 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/dtd.categories @@ -0,0 +1,47 @@ +banded +blotchy +braided +bubbly +bumpy +chequered +cobwebbed +cracked +crosshatched +crystalline +dotted +fibrous +flecked +freckled +frilly +gauzy +grid +grooved +honeycombed +interlaced +knitted +lacelike +lined +marbled +matted +meshed +paisley +perforated +pitted +pleated +polka-dotted +porous +potholed +scaly +smeared +spiralled +sprinkled +stained +stratified +striped +studded +swirly +veined +waffled +woven +wrinkled +zigzagged diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py new file mode 100644 index 00000000000..e78ab88da27 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -0,0 +1,130 @@ +import 
io +import pathlib +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torchdata.datapipes.iter import ( + IterDataPipe, + Mapper, + Shuffler, + Filter, + IterKeyZipper, + Demultiplexer, + LineReader, + CSVParser, +) +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, + DatasetType, +) +from torchvision.prototype.datasets.utils._internal import ( + INFINITE_BUFFER_SIZE, + hint_sharding, + path_comparator, + getitem, +) +from torchvision.prototype.features import Label + + +class DTD(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "dtd", + type=DatasetType.IMAGE, + homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", + valid_options=dict( + split=("train", "test", "val"), + fold=tuple(str(fold) for fold in range(1, 11)), + ), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + archive = HttpResource( + "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", + sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", + decompress=True, + ) + return [archive] + + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: + path = pathlib.Path(data[0]) + if path.parent.name == "labels": + if path.name == "labels_joint_anno.txt": + return 1 + + return 0 + elif path.parents[1].name == "images": + return 2 + else: + return None + + def _image_key_fn(self, data: Tuple[str, Any]) -> str: + path = pathlib.Path(data[0]) + return str(path.relative_to(path.parents[1])) + + def _collate_and_decode_sample( + self, + data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> Dict[str, Any]: + (_, joint_categories_data), image_data = data + _, *joint_categories = joint_categories_data + path, buffer = image_data + + category = pathlib.Path(path).parent.name + + return dict( + joint_categories={category for 
category in joint_categories if category}, + label=Label(self.info.categories.index(category), category=category), + path=path, + image=decoder(buffer) if decoder else buffer, + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] + + splits_dp, joint_categories_dp, images_dp = Demultiplexer( + archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE + ) + + splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt")) + splits_dp = LineReader(splits_dp, decode=True, return_path=False) + splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE) + splits_dp = hint_sharding(splits_dp) + + joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ") + + dp = IterKeyZipper( + splits_dp, + joint_categories_dp, + key_fn=getitem(), + ref_key_fn=getitem(0), + buffer_size=INFINITE_BUFFER_SIZE, + ) + dp = IterKeyZipper( + dp, + images_dp, + key_fn=getitem(0), + ref_key_fn=self._image_key_fn, + buffer_size=INFINITE_BUFFER_SIZE, + ) + return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + + def _filter_images(self, data: Tuple[str, Any]) -> bool: + return self._classify_archive(data) == 2 + + def _generate_categories(self, root: pathlib.Path) -> List[str]: + dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) + dp = Filter(dp, self._filter_images) + return sorted({pathlib.Path(path).parent.name for path, _ in dp}) From e9b4c1cba7d1e40cee5a2a269f54e0c31ccf1ddd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 19 Dec 2021 20:53:54 +0100 Subject: [PATCH 2/8] add old style dataset --- torchvision/datasets/__init__.py | 2 + torchvision/datasets/dtd.py | 94 ++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 torchvision/datasets/dtd.py diff 
--git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 80859791004..6bd9be9b338 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -4,6 +4,7 @@ from .cifar import CIFAR10, CIFAR100 from .cityscapes import Cityscapes from .coco import CocoCaptions, CocoDetection +from .dtd import DTD from .fakedata import FakeData from .flickr import Flickr8k, Flickr30k from .folder import ImageFolder, DatasetFolder @@ -77,4 +78,5 @@ "FlyingChairs", "FlyingThings3D", "HD1K", + "DTD", ) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py new file mode 100644 index 00000000000..cf4da3ea3ef --- /dev/null +++ b/torchvision/datasets/dtd.py @@ -0,0 +1,94 @@ +import os +from typing import Optional, Callable, Union + +import PIL.Image + +from .utils import verify_str_arg, download_and_extract_archive +from .vision import VisionDataset + + +class DTD(VisionDataset): + """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_. + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. + fold (string or int, optional): The dataset fold. Should be ``1 <= fold <= 10``. Defaults to ``1``. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
+ """ + + _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" + _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1" + + def __init__( + self, + root: str, + split: str = "train", + fold: Union[str, int] = 1, + download: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform, target_transform) + self._split = verify_str_arg(split, "split", ("train", "val", "test")) + self._fold = verify_str_arg(str(fold), "fold", [str(i) for i in range(1, 11)]) + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + image_files = [] + categories = [] + with open(os.path.join(self._meta_folder, f"{self._split}{self._fold}.txt")) as file: + for line in file: + category, name = line.strip().split("/") + image_files.append(os.path.join(self._images_folder, category, name)) + categories.append(category) + self._image_files, self._categories = image_files, categories + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx): + image_file, label = self._image_files[idx], self._categories[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def extra_repr(self) -> str: + return f"split={self._split}, fold={self._fold}" + + @property + def _base_folder(self): + return os.path.join(self.root, type(self).__name__.lower()) + + @property + def _data_folder(self) -> str: + return os.path.join(self._base_folder, "dtd") + + @property + def _meta_folder(self) -> str: + return os.path.join(self._data_folder, "labels") + + @property + def _images_folder(self) -> str: + return os.path.join(self._data_folder, "images") + + def _check_exists(self) -> bool: + return os.path.exists(self._data_folder) and 
os.path.isdir(self._data_folder) + + def _download(self) -> None: + if self._check_exists(): + return + + download_and_extract_archive(self._URL, download_root=self._base_folder, md5=self._MD5) From 52ab24974482a6258f029f96d62a532a699aa09b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 19 Dec 2021 21:17:39 +0100 Subject: [PATCH 3/8] add test for old dataset --- test/test_datasets.py | 36 ++++++++++++++++++++++++++++++++++++ torchvision/datasets/dtd.py | 19 +++++++++++-------- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 761f11d77dc..d5f2f246304 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2168,5 +2168,41 @@ def inject_fake_data(self, tmpdir, config): return num_sequences * (num_examples_per_sequence - 1) +class DTDTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.DTD + FEATURE_TYPES = (PIL.Image.Image, int) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "test", "val"), + # There is no need to test the whole matrix here, since each fold is treated exactly the same + fold=(5,), + ) + + def inject_fake_data(self, tmpdir: str, config): + data_folder = os.path.join(tmpdir, "dtd", "dtd") + os.makedirs(data_folder) + + num_images_per_class = 3 + image_folder = os.path.join(data_folder, "images") + image_files = [] + for cls in ("banded", "marbled", "zigzagged"): + image_files.extend( + datasets_utils.create_image_folder( + image_folder, + cls, + file_name_fn=lambda idx: f"{cls}_{idx:04d}.jpg", + num_examples=num_images_per_class, + ) + ) + + meta_folder = os.path.join(data_folder, "labels") + os.makedirs(meta_folder) + image_files_in_config = random.choices(image_files, k=len(image_files) // 2) + with open(os.path.join(meta_folder, f"{config['split']}{config['fold']}.txt"), "w") as file: + file.write("\n".join(str(path.relative_to(path.parents[1])) for path in image_files_in_config) + "\n") + + return 
len(image_files_in_config) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index cf4da3ea3ef..a2fe9ae9003 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -31,7 +31,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: - super().__init__(root, transform, target_transform) + super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) self._fold = verify_str_arg(str(fold), "fold", [str(i) for i in range(1, 11)]) @@ -42,19 +42,23 @@ def __init__( raise RuntimeError("Dataset not found. You can use download=True to download it") image_files = [] - categories = [] + classes = [] with open(os.path.join(self._meta_folder, f"{self._split}{self._fold}.txt")) as file: for line in file: - category, name = line.strip().split("/") - image_files.append(os.path.join(self._images_folder, category, name)) - categories.append(category) - self._image_files, self._categories = image_files, categories + cls, name = line.strip().split("/") + image_files.append(os.path.join(self._images_folder, cls, name)) + classes.append(cls) + self._image_files = image_files + + self.classes = sorted(set(classes)) + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + self._labels = [self.class_to_idx[cls] for cls in classes] def __len__(self) -> int: return len(self._image_files) def __getitem__(self, idx): - image_file, label = self._image_files[idx], self._categories[idx] + image_file, label = self._image_files[idx], self._labels[idx] image = PIL.Image.open(image_file).convert("RGB") if self.transform: @@ -90,5 +94,4 @@ def _check_exists(self) -> bool: def _download(self) -> None: if self._check_exists(): return - download_and_extract_archive(self._URL, download_root=self._base_folder, md5=self._MD5) From 
b7be23f0d98b6659772aa256459a7bd0f0a2d5b2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 20 Dec 2021 08:50:27 +0100 Subject: [PATCH 4/8] fix tests for windows --- test/test_datasets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index d5f2f246304..0aa875c9c2b 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2197,11 +2197,12 @@ def inject_fake_data(self, tmpdir: str, config): meta_folder = os.path.join(data_folder, "labels") os.makedirs(meta_folder) - image_files_in_config = random.choices(image_files, k=len(image_files) // 2) + image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files] + image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2) with open(os.path.join(meta_folder, f"{config['split']}{config['fold']}.txt"), "w") as file: - file.write("\n".join(str(path.relative_to(path.parents[1])) for path in image_files_in_config) + "\n") + file.write("\n".join(image_ids_in_config) + "\n") - return len(image_files_in_config) + return len(image_ids_in_config) if __name__ == "__main__": From e044a8e7de15e0eb3edfe4e126aaaa858a2f75de Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 21 Dec 2021 17:48:30 +0100 Subject: [PATCH 5/8] add dataset to docs --- docs/source/datasets.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 7f09ff245ca..5ceb191e86b 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes Date: Wed, 22 Dec 2021 09:29:13 +0100 Subject: [PATCH 6/8] remove properties and use pathlib --- torchvision/datasets/dtd.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index a2fe9ae9003..37652478601 100644 --- 
a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -1,4 +1,5 @@ import os +import pathlib from typing import Optional, Callable, Union import PIL.Image @@ -31,24 +32,28 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: - super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) self._fold = verify_str_arg(str(fold), "fold", [str(i) for i in range(1, 11)]) + super().__init__(root, transform=transform, target_transform=target_transform) + self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower() + self._data_folder = self._base_folder / "dtd" + self._meta_folder = self._data_folder / "labels" + self._images_folder = self._data_folder / "images" + if download: self._download() if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - image_files = [] + self._image_files = [] classes = [] - with open(os.path.join(self._meta_folder, f"{self._split}{self._fold}.txt")) as file: + with open(self._meta_folder / f"{self._split}{self._fold}.txt") as file: for line in file: cls, name = line.strip().split("/") - image_files.append(os.path.join(self._images_folder, cls, name)) + self._image_files.append(self._images_folder.joinpath(cls, name)) classes.append(cls) - self._image_files = image_files self.classes = sorted(set(classes)) self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) @@ -72,26 +77,10 @@ def __getitem__(self, idx): def extra_repr(self) -> str: return f"split={self._split}, fold={self._fold}" - @property - def _base_folder(self): - return os.path.join(self.root, type(self).__name__.lower()) - - @property - def _data_folder(self) -> str: - return os.path.join(self._base_folder, "dtd") - - @property - def _meta_folder(self) -> str: - return os.path.join(self._data_folder, "labels") - - @property - def 
_images_folder(self) -> str: - return os.path.join(self._data_folder, "images") - def _check_exists(self) -> bool: return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder) def _download(self) -> None: if self._check_exists(): return - download_and_extract_archive(self._URL, download_root=self._base_folder, md5=self._MD5) + download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5) From 6df6403688b17d980b87ac9faf87408dfb621b29 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 5 Jan 2022 13:51:49 +0100 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Nicolas Hug --- torchvision/datasets/dtd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index 37652478601..ded41cbaa2e 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -15,7 +15,10 @@ class DTD(VisionDataset): root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. fold (string or int, optional): The dataset fold. Should be ``1 <= fold <= 10``. Defaults to ``1``. - transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
""" From 7d21152728177363b4b5676dfaea18541271fdb4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 5 Jan 2022 14:48:37 +0100 Subject: [PATCH 8/8] fold -> partition --- test/test_datasets.py | 4 ++-- torchvision/datasets/dtd.py | 23 +++++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 952858c4c19..914a899afaa 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2212,7 +2212,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( split=("train", "test", "val"), # There is no need to test the whole matrix here, since each fold is treated exactly the same - fold=(5,), + partition=(1, 5, 10), ) def inject_fake_data(self, tmpdir: str, config): @@ -2235,7 +2235,7 @@ def inject_fake_data(self, tmpdir: str, config): meta_folder.mkdir() image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files] image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2) - with open(meta_folder / f"{config['split']}{config['fold']}.txt", "w") as file: + with open(meta_folder / f"{config['split']}{config['partition']}.txt", "w") as file: file.write("\n".join(image_ids_in_config) + "\n") return len(image_ids_in_config) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index ded41cbaa2e..ceacc64eedb 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -1,6 +1,6 @@ import os import pathlib -from typing import Optional, Callable, Union +from typing import Optional, Callable import PIL.Image @@ -14,7 +14,13 @@ class DTD(VisionDataset): Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. - fold (string or int, optional): The dataset fold. Should be ``1 <= fold <= 10``. Defaults to ``1``. + partition (int, optional): The dataset partition. 
Should be ``1 <= partition <= 10``. Defaults to ``1``. + + .. note:: + + The partition only changes which split each image belongs to. Thus, regardless of the selected + partition, combining all splits will result in all images. + download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. @@ -30,13 +36,18 @@ def __init__( self, root: str, split: str = "train", - fold: Union[str, int] = 1, + partition: int = 1, download: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) - self._fold = verify_str_arg(str(fold), "fold", [str(i) for i in range(1, 11)]) + if not isinstance(partition, int) or not (1 <= partition <= 10): + raise ValueError( + f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, " + f"but got {partition} instead" + ) + self._partition = partition super().__init__(root, transform=transform, target_transform=target_transform) self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower() @@ -52,7 +63,7 @@ def __init__( self._image_files = [] classes = [] - with open(self._meta_folder / f"{self._split}{self._fold}.txt") as file: + with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file: for line in file: cls, name = line.strip().split("/") self._image_files.append(self._images_folder.joinpath(cls, name)) @@ -78,7 +89,7 @@ def __getitem__(self, idx): return image, label def extra_repr(self) -> str: - return f"split={self._split}, fold={self._fold}" + return f"split={self._split}, partition={self._partition}" def _check_exists(self) -> bool: return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)