From 684e73a9564d029e30f3668ae77038b30bef4c85 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 6 Apr 2021 19:56:45 +0100 Subject: [PATCH 1/8] Added KITTI dataset --- test/fakedata_generation.py | 33 ++++++ test/test_datasets.py | 20 +++- torchvision/datasets/__init__.py | 4 +- torchvision/datasets/kitti.py | 184 +++++++++++++++++++++++++++++++ 4 files changed, 239 insertions(+), 2 deletions(-) create mode 100644 torchvision/datasets/kitti.py diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index 314222dc43f..0fe2ad23dfd 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -404,3 +404,36 @@ def make_archive(stack, root, name): data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive) yield root, data + + +@contextlib.contextmanager +def kitti_root(): + def _make_image(file): + PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file) + + def _make_train_archive(root): + extracted_dir = os.path.join(root, 'training', 'image_2') + os.makedirs(extracted_dir) + _make_image(os.path.join(extracted_dir, '00000.png')) + + def _make_target_archive(root): + extracted_dir = os.path.join(root, 'training', 'label_2') + os.makedirs(extracted_dir) + target_contents = 'Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n' + target_file = os.path.join(extracted_dir, '00000.txt') + with open(target_file, "w") as txt_file: + txt_file.write(target_contents) + + def _make_test_archive(root): + extracted_dir = os.path.join(root, 'testing', 'image_2') + os.makedirs(extracted_dir) + _make_image(os.path.join(extracted_dir, '00001.png')) + + with get_tmp_dir() as root: + raw_dir = os.path.join(root, "Kitti", "raw") + os.makedirs(raw_dir) + _make_train_archive(raw_dir) + _make_target_archive(raw_dir) + _make_test_archive(raw_dir) + + yield root diff --git a/test/test_datasets.py b/test/test_datasets.py index db80b55a90f..67a68d326f8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -10,7 +10,12 @@ import torchvision from torchvision.datasets import utils from common_utils import get_tmp_dir -from fakedata_generation import svhn_root, places365_root, widerface_root, stl10_root +from fakedata_generation import ( + kitti_root, + places365_root, + stl10_root, + svhn_root, +) import xml.etree.ElementTree as ET from urllib.request import Request, urlopen import itertools @@ -155,6 +160,19 @@ def test_places365_repr_smoke(self): dataset = torchvision.datasets.Places365(root, download=True) self.assertIsInstance(repr(dataset), str) + def test_kitti(self): + with kitti_root() as root: + dataset = torchvision.datasets.Kitti(root) + self.assertEqual(len(dataset), 1) + img, target = dataset[0][0], dataset[0][1] + self.assertTrue(isinstance(img, PIL.Image.Image)) + + dataset = torchvision.datasets.Kitti(root, split='test') + self.assertEqual(len(dataset), 1) + img, target = dataset[0][0], dataset[0][1] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertEqual(target, None) + class STL10Tester(DatasetTestcase): @contextlib.contextmanager diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 0ce0fd6bd60..b60fc7c7964 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -24,6 +24,7 @@ from .hmdb51 import HMDB51 from .ucf101 import UCF101 from .places365 import Places365 +from .kitti import Kitti __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', @@ -34,4 +35,5 @@ 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', 'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset', 'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101', - 'Places365') + 'Places365', 'Kitti', + ) diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py new file mode 100644 index 00000000000..240329c61bc --- /dev/null +++ b/torchvision/datasets/kitti.py @@ -0,0 +1,184 @@ +import os +from collections import namedtuple +from typing import Any, Callable, NamedTuple, Optional, Tuple +from urllib.error import URLError + +import pandas as pd +from PIL import Image + +from .utils import download_and_extract_archive +from .vision import VisionDataset + + +class Kitti(VisionDataset): + """`KITTI `_ Dataset. + Args: + root (string): Root directory where images are downloaded to. + Expects the following folder structure if download=False: + + .. code:: + + + └─ Kitti + └─ raw + ├── training + | ├── image_2 + | └── label_2 + └── testing + └── image_2 + split (string): The dataset split to use. One of {``train``, ``test``}. + Defaults to ``train``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample + and its target as entry and returns a transformed version. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + mirrors = [ + "https://s3.eu-central-1.amazonaws.com/avg-kitti/", + ] + resources = [ + "data_object_image_2.zip", + "data_object_label_2.zip", + ] + image_dir_name = "image_2" + labels_dir_name = "label_2" + + def __init__( + self, + root: str, + split: str = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, + download: bool = False, + ): + super().__init__( + root, + transform=transform, + target_transform=target_transform, + transforms=transforms, + ) + self.TargetTuple = namedtuple( + "TargetTuple", + [ + "type", + "truncated", + "occluded", + "alpha", + "bbox", + "dimensions", + "location", + "rotation_y", + ], + ) + self.images = [] + self.targets = [] + self.root = root + self.split = split + self.transform = transform + self.target_transform = target_transform + self.transforms = transforms + + if download: + self.download() + if not self._check_exists(): + raise RuntimeError( + "Dataset not found. You may use download=True to download it." + ) + + location = "testing" if self.split == "test" else "training" + image_dir = os.path.join(self.raw_folder, location, self.image_dir_name) + if location == "training": + labels_dir = os.path.join(self.raw_folder, location, self.labels_dir_name) + for img_file in os.listdir(image_dir): + self.images.append(os.path.join(image_dir, img_file)) + if location == "training": + self.targets.append( + os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt") + ) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Get item at a given index. + Args: + index (int): Index + Returns: + tuple: (image, target), where + target is a namedtuple with the following fields: + type: Int64Tensor[N] + truncated: FloatTensor[N] + occluded: Int64Tensor[N] + alpha: FloatTensor[N] + bbox: FloatTensor[N, 4] + dimensions: FloatTensor[N, 3] + locations: FloatTensor[N, 3] + rotation_y: FloatTensor[N] + score: FloatTensor[N] + """ + image = Image.open(self.images[index]) + target = None if self.split == "test" else self._parse_target(index) + if self.transforms: + image, target = self.transforms(image, target) + if self.transform: + image = self.transform(image) + if self.target_transform: + target = self.target_transform(target) + return image, target + + def _parse_target(self, index: int) -> NamedTuple: + target_df = pd.read_csv(self.targets[index], delimiter=" ", header=None) + return self.TargetTuple( + type=target_df.iloc[:, 0].values, + truncated=target_df.iloc[:, 1].values, + occluded=target_df.iloc[:, 2].values, + alpha=target_df.iloc[:, 3].values, + bbox=target_df.iloc[:, 4:8].values, + dimensions=target_df.iloc[:, 8:11].values, + location=target_df.iloc[:, 11:14].values, + rotation_y=target_df.iloc[:, 14].values, + ) + + def __len__(self) -> int: + return len(self.images) + + @property + def raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "raw") + + def _check_exists(self) -> bool: + """Check if the data directory exists.""" + location = "testing" if self.split == "test" else "training" + folders = [self.image_dir_name] + if self.split != "test": + folders.append(self.labels_dir_name) + return all( + os.path.isdir(os.path.join(self.raw_folder, location, fname)) + for fname in folders + ) + + def download(self) -> None: + """Download the KITTI data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + # download files + for fname in self.resources: + for mirror in self.mirrors: + url = f"{mirror}{fname}" + try: + print(f"Downloading {url}") + download_and_extract_archive( + url=url, + download_root=self.raw_folder, + filename=fname, + ) + except URLError as error: + print(f"Error downloading {fname}: {error}") From 004357d84ca2b3f99055d96393a7f8ebbd327ede Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 7 Apr 2021 13:27:58 +0100 Subject: [PATCH 2/8] Addressed review comments --- torchvision/datasets/kitti.py | 54 ++++++++++++----------------------- 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index 240329c61bc..e40e5432da5 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -1,9 +1,9 @@ +import csv import os -from collections import namedtuple -from typing import Any, Callable, NamedTuple, Optional, Tuple +from collections import defaultdict +from typing import Any, Callable, Dict, Optional, Tuple from urllib.error import URLError -import pandas as pd from PIL import Image from .utils import download_and_extract_archive @@ -64,26 +64,10 @@ def __init__( target_transform=target_transform, transforms=transforms, ) - self.TargetTuple = namedtuple( - "TargetTuple", - [ - "type", - "truncated", - "occluded", - "alpha", - "bbox", - "dimensions", - "location", - "rotation_y", - ], - ) self.images = [] self.targets = [] self.root = root self.split = split - self.transform = transform - self.target_transform = target_transform - self.transforms = transforms if download: self.download() @@ -109,7 +93,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: index (int): Index Returns: tuple: (image, target), where - target is a namedtuple with the following fields: + target is a dictionary with the following keys: type: Int64Tensor[N] truncated: FloatTensor[N] occluded: Int64Tensor[N] @@ -124,24 +108,22 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: target = None if self.split == "test" else self._parse_target(index) if self.transforms: image, target = self.transforms(image, target) - if self.transform: - image = self.transform(image) - if self.target_transform: - target = self.target_transform(target) return image, target - def _parse_target(self, index: int) -> NamedTuple: - target_df = pd.read_csv(self.targets[index], delimiter=" ", header=None) - return self.TargetTuple( - type=target_df.iloc[:, 0].values, - truncated=target_df.iloc[:, 1].values, - occluded=target_df.iloc[:, 2].values, - alpha=target_df.iloc[:, 3].values, - bbox=target_df.iloc[:, 4:8].values, - dimensions=target_df.iloc[:, 8:11].values, - location=target_df.iloc[:, 11:14].values, - rotation_y=target_df.iloc[:, 14].values, - ) + def _parse_target(self, index: int) -> Dict[str, Any]: + target: Dict[str, Any] = defaultdict(list) + with open(self.targets[index]) as inp: + content = csv.reader(inp, delimiter=" ") + for line in content: + target["type"].append(line[0]) + target["truncated"].append(line[1]) + target["occluded"].append(line[2]) + target["alpha"].append(line[3]) + target["bbox"].append(line[4:8]) + target["dimensions"].append(line[8:11]) + target["location"].append(line[11:14]) + target["rotation_y"].append(line[14]) + return target def __len__(self) -> int: return len(self.images) From 4786e7b82bbc9eec9f403acaa0dc87b9aad45bae Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 7 Apr 2021 15:46:39 +0100 Subject: [PATCH 3/8] Changed type of target to List[Dict] and corrected the data types of the returned values. --- torchvision/datasets/kitti.py | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index e40e5432da5..776ccc4f15a 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -1,7 +1,6 @@ import csv import os -from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple from urllib.error import URLError from PIL import Image @@ -93,16 +92,15 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: index (int): Index Returns: tuple: (image, target), where - target is a dictionary with the following keys: - type: Int64Tensor[N] - truncated: FloatTensor[N] - occluded: Int64Tensor[N] - alpha: FloatTensor[N] - bbox: FloatTensor[N, 4] - dimensions: FloatTensor[N, 3] - locations: FloatTensor[N, 3] - rotation_y: FloatTensor[N] - score: FloatTensor[N] + target is a list of dictionaries with the following keys: + type: str + truncated: float + occluded: int + alpha: float + bbox: float[4] + dimensions: float[3] + locations: float[3] + rotation_y: float """ image = Image.open(self.images[index]) target = None if self.split == "test" else self._parse_target(index) @@ -110,19 +108,21 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: image, target = self.transforms(image, target) return image, target - def _parse_target(self, index: int) -> Dict[str, Any]: - target: Dict[str, Any] = defaultdict(list) + def _parse_target(self, index: int) -> List: + target = [] with open(self.targets[index]) as inp: content = csv.reader(inp, delimiter=" ") for line in content: - target["type"].append(line[0]) - target["truncated"].append(line[1]) - target["occluded"].append(line[2]) - target["alpha"].append(line[3]) - target["bbox"].append(line[4:8]) - target["dimensions"].append(line[8:11]) - target["location"].append(line[11:14]) - target["rotation_y"].append(line[14]) + target.append({ + "type": line[0], + "truncated": float(line[1]), + "occluded": int(line[2]), + "alpha": float(line[3]), + "bbox": [float(x) for x in line[4:8]], + "dimensions": [float(x) for x in line[8:11]], + "location": [float(x) for x in line[11:14]], + "rotation_y": float(line[14]), + }) return target def __len__(self) -> int: From eecbf87470968bd857cfd5ff4da4db584e537fd0 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 8 Apr 2021 12:21:12 +0100 Subject: [PATCH 4/8] Updated unit test to rely on ImageDatasetTestCase --- test/datasets_utils.py | 2 +- test/fakedata_generation.py | 33 ---------------------- test/test_datasets.py | 56 +++++++++++++++++++++++++------------ 3 files changed, 39 insertions(+), 52 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 658ef6640fe..07085b5d0ed 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -44,7 +44,7 @@ class UsageError(Exception): class LazyImporter: - r"""Lazy importer for additional dependicies. + r"""Lazy importer for additional dependencies. Some datasets require additional packages that are no direct dependencies of torchvision. Instances of this class provide modules listed in MODULES as attributes. They are only imported when accessed. diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index 0fe2ad23dfd..314222dc43f 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -404,36 +404,3 @@ def make_archive(stack, root, name): data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive) yield root, data - - -@contextlib.contextmanager -def kitti_root(): - def _make_image(file): - PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file) - - def _make_train_archive(root): - extracted_dir = os.path.join(root, 'training', 'image_2') - os.makedirs(extracted_dir) - _make_image(os.path.join(extracted_dir, '00000.png')) - - def _make_target_archive(root): - extracted_dir = os.path.join(root, 'training', 'label_2') - os.makedirs(extracted_dir) - target_contents = 'Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n' - target_file = os.path.join(extracted_dir, '00000.txt') - with open(target_file, "w") as txt_file: - txt_file.write(target_contents) - - def _make_test_archive(root): - extracted_dir = os.path.join(root, 'testing', 'image_2') - os.makedirs(extracted_dir) - _make_image(os.path.join(extracted_dir, '00001.png')) - - with get_tmp_dir() as root: - raw_dir = os.path.join(root, "Kitti", "raw") - os.makedirs(raw_dir) - _make_train_archive(raw_dir) - _make_target_archive(raw_dir) - _make_test_archive(raw_dir) - - yield root diff --git a/test/test_datasets.py b/test/test_datasets.py index 67a68d326f8..581ee0034c4 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1,23 +1,18 @@ import contextlib -import sys import os import unittest from unittest import mock import numpy as np import PIL from PIL import Image -from torch._utils_internal import get_file_path_2 import torchvision from torchvision.datasets import utils -from common_utils import get_tmp_dir from fakedata_generation import ( - kitti_root, places365_root, stl10_root, svhn_root, ) import xml.etree.ElementTree as ET -from urllib.request import Request, urlopen import itertools import datasets_utils import pathlib @@ -160,19 +155,6 @@ def test_places365_repr_smoke(self): dataset = torchvision.datasets.Places365(root, download=True) self.assertIsInstance(repr(dataset), str) - def test_kitti(self): - with kitti_root() as root: - dataset = torchvision.datasets.Kitti(root) - self.assertEqual(len(dataset), 1) - img, target = dataset[0][0], dataset[0][1] - self.assertTrue(isinstance(img, PIL.Image.Image)) - - dataset = torchvision.datasets.Kitti(root, split='test') - self.assertEqual(len(dataset), 1) - img, target = dataset[0][0], dataset[0][1] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertEqual(target, None) - class STL10Tester(DatasetTestcase): @contextlib.contextmanager @@ -1720,5 +1702,43 @@ def test_classes(self, config): self.assertSequenceEqual(dataset.classes, info["classes"]) +class KittiTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti + FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = os.path.join(tmpdir, "Kitti", "raw") + os.makedirs(kitti_dir) + + split_to_idx = split_to_num_examples = { + None: 1, + "train": 1, + "test": 2, + } + + # We need to create all folders regardless of the split in config + for split in ("train", "test"): + split_idx = split_to_idx[split] + num_examples = split_to_num_examples[split] + + datasets_utils.create_image_folder( + root=kitti_dir, + name=os.path.join(f"{split}ing", "image_2"), + file_name_fn=lambda image_idx: f"000{split_idx + image_idx}.png", + num_examples=num_examples, + ) + if split == "train": + for image_idx in range(num_examples): + target_file_dir = os.path.join(kitti_dir, f"{split}ing", "label_2") + os.makedirs(target_file_dir) + target_file_name = os.path.join(target_file_dir, f"000{split_idx + image_idx}.txt") + target_contents = "Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n" # noqa + with open(target_file_name, "w") as target_file: + target_file.write(target_contents) + + return split_to_num_examples[config["split"]] + + if __name__ == "__main__": unittest.main() From 8d0219bd12e9b6cbaa4bb89e9bc74bf2ecdd1629 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 8 Apr 2021 17:28:07 +0100 Subject: [PATCH 5/8] Added kitti to dataset documentation --- docs/source/datasets.rst | 7 ++++++ torchvision/datasets/kitti.py | 45 +++++++++++++++++++---------------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index ceb517ced8f..cb02f2bcaa3 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -149,6 +149,13 @@ Kinetics-400 :members: __getitem__ :special-members: +KITTI +~~~~~~~~~ + +.. autoclass:: Kitti + :members: __getitem__ + :special-members: + KMNIST ~~~~~~~~~~~~~ diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index 776ccc4f15a..de98c7c6d88 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -11,6 +11,7 @@ class Kitti(VisionDataset): """`KITTI `_ Dataset. + Args: root (string): Root directory where images are downloaded to. Expects the following folder structure if download=False: @@ -18,12 +19,12 @@ class Kitti(VisionDataset): .. code:: - └─ Kitti - └─ raw - ├── training - | ├── image_2 - | └── label_2 - └── testing + └── Kitti + └─ raw + ├── training + | ├── image_2 + | └── label_2 + └── testing └── image_2 split (string): The dataset split to use. One of {``train``, ``test``}. Defaults to ``train``. @@ -36,6 +37,7 @@ class Kitti(VisionDataset): download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + """ mirrors = [ @@ -76,9 +78,9 @@ def __init__( ) location = "testing" if self.split == "test" else "training" - image_dir = os.path.join(self.raw_folder, location, self.image_dir_name) + image_dir = os.path.join(self._raw_folder, location, self.image_dir_name) if location == "training": - labels_dir = os.path.join(self.raw_folder, location, self.labels_dir_name) + labels_dir = os.path.join(self._raw_folder, location, self.labels_dir_name) for img_file in os.listdir(image_dir): self.images.append(os.path.join(image_dir, img_file)) if location == "training": @@ -88,19 +90,22 @@ def __init__( def __getitem__(self, index: int) -> Tuple[Any, Any]: """Get item at a given index. + Args: index (int): Index Returns: tuple: (image, target), where target is a list of dictionaries with the following keys: - type: str - truncated: float - occluded: int - alpha: float - bbox: float[4] - dimensions: float[3] - locations: float[3] - rotation_y: float + + - type: str + - truncated: float + - occluded: int + - alpha: float + - bbox: float[4] + - dimensions: float[3] + - locations: float[3] + - rotation_y: float + """ image = Image.open(self.images[index]) target = None if self.split == "test" else self._parse_target(index) @@ -129,7 +134,7 @@ def __len__(self) -> int: return len(self.images) @property - def raw_folder(self) -> str: + def _raw_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__, "raw") def _check_exists(self) -> bool: @@ -139,7 +144,7 @@ def _check_exists(self) -> bool: if self.split != "test": folders.append(self.labels_dir_name) return all( - os.path.isdir(os.path.join(self.raw_folder, location, fname)) + os.path.isdir(os.path.join(self._raw_folder, location, fname)) for fname in folders ) @@ -149,7 +154,7 @@ def download(self) -> None: if self._check_exists(): return - os.makedirs(self.raw_folder, exist_ok=True) + os.makedirs(self._raw_folder, exist_ok=True) # download files for fname in self.resources: @@ -159,7 +164,7 @@ def download(self) -> None: print(f"Downloading {url}") download_and_extract_archive( url=url, - download_root=self.raw_folder, + download_root=self._raw_folder, filename=fname, ) except URLError as error: From 9c083f6c1027d0c7a1e8062c7b8827095f4461ff Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 9 Apr 2021 11:17:29 +0100 Subject: [PATCH 6/8] Cleaned up test and some minor changes --- test/datasets_utils.py | 2 +- test/test_datasets.py | 38 +++++++++++++++++------------------ torchvision/datasets/kitti.py | 25 +++++++++++------------ 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 07085b5d0ed..658ef6640fe 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -44,7 +44,7 @@ class UsageError(Exception): class LazyImporter: - r"""Lazy importer for additional dependencies. + r"""Lazy importer for additional dependicies. Some datasets require additional packages that are no direct dependencies of torchvision. Instances of this class provide modules listed in MODULES as attributes. They are only imported when accessed. diff --git a/test/test_datasets.py b/test/test_datasets.py index 581ee0034c4..1c45af1163c 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1,18 +1,18 @@ import contextlib +import sys import os import unittest from unittest import mock import numpy as np import PIL from PIL import Image +from torch._utils_internal import get_file_path_2 import torchvision from torchvision.datasets import utils -from fakedata_generation import ( - places365_root, - stl10_root, - svhn_root, -) +from common_utils import get_tmp_dir +from fakedata_generation import svhn_root, places365_root, widerface_root, stl10_root import xml.etree.ElementTree as ET +from urllib.request import Request, urlopen import itertools import datasets_utils import pathlib @@ -1705,39 +1705,37 @@ def test_classes(self, config): class KittiTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Kitti FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) def inject_fake_data(self, tmpdir, config): kitti_dir = os.path.join(tmpdir, "Kitti", "raw") os.makedirs(kitti_dir) - split_to_idx = split_to_num_examples = { - None: 1, - "train": 1, - "test": 2, + split_to_num_examples = { + True: 1, + False: 2, } - # We need to create all folders regardless of the split in config - for split in ("train", "test"): - split_idx = split_to_idx[split] - num_examples = split_to_num_examples[split] + # We need to create all folders(training and testing). + for is_training in (True, False): + num_examples = split_to_num_examples[is_training] datasets_utils.create_image_folder( root=kitti_dir, - name=os.path.join(f"{split}ing", "image_2"), - file_name_fn=lambda image_idx: f"000{split_idx + image_idx}.png", + name=os.path.join("training" if is_training else "testing", "image_2"), + file_name_fn=lambda image_idx: f"{image_idx:06d}.png", num_examples=num_examples, ) - if split == "train": + if is_training: for image_idx in range(num_examples): - target_file_dir = os.path.join(kitti_dir, f"{split}ing", "label_2") + target_file_dir = os.path.join(kitti_dir, "training", "label_2") os.makedirs(target_file_dir) - target_file_name = os.path.join(target_file_dir, f"000{split_idx + image_idx}.txt") + target_file_name = os.path.join(target_file_dir, f"{image_idx:06d}.txt") target_contents = "Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n" # noqa with open(target_file_name, "w") as target_file: target_file.write(target_contents) - return split_to_num_examples[config["split"]] + return split_to_num_examples[config["train"]] if __name__ == "__main__": diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index de98c7c6d88..5d51dd8f2ac 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -26,9 +26,9 @@ class Kitti(VisionDataset): | └── label_2 └── testing └── image_2 - split (string): The dataset split to use. One of {``train``, ``test``}. + train (bool, optional): Use ``train`` split if true, else ``test`` split. Defaults to ``train``. - transform (callable, optional): A function/transform that takes in an PIL image + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.ToTensor`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. @@ -53,7 +53,7 @@ class Kitti(VisionDataset): def __init__( self, root: str, - split: str = None, + train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, transforms: Optional[Callable] = None, @@ -68,7 +68,8 @@ def __init__( self.images = [] self.targets = [] self.root = root - self.split = split + self.train = train + self._location = "training" if self.train else "testing" if download: self.download() @@ -77,13 +78,12 @@ def __init__( "Dataset not found. You may use download=True to download it." ) - location = "testing" if self.split == "test" else "training" - image_dir = os.path.join(self._raw_folder, location, self.image_dir_name) - if location == "training": - labels_dir = os.path.join(self._raw_folder, location, self.labels_dir_name) + image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name) + if self.train: + labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name) for img_file in os.listdir(image_dir): self.images.append(os.path.join(image_dir, img_file)) - if location == "training": + if self.train: self.targets.append( os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt") ) @@ -108,7 +108,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ image = Image.open(self.images[index]) - target = None if self.split == "test" else self._parse_target(index) + target = self._parse_target(index) if self.train else None if self.transforms: image, target = self.transforms(image, target) return image, target @@ -139,12 +139,11 @@ def _raw_folder(self) -> str: def _check_exists(self) -> bool: """Check if the data directory exists.""" - location = "testing" if self.split == "test" else "training" folders = [self.image_dir_name] - if self.split != "test": + if self.train: folders.append(self.labels_dir_name) return all( - os.path.isdir(os.path.join(self._raw_folder, location, fname)) + os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders ) From 544c73d1bf0b477074fc747f26e6898ecf09f402 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 9 Apr 2021 13:29:15 +0100 Subject: [PATCH 7/8] Made data_url a string instead of a list --- torchvision/datasets/kitti.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index 5d51dd8f2ac..74ab644d2da 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -40,9 +40,7 @@ class Kitti(VisionDataset): """ - mirrors = [ - "https://s3.eu-central-1.amazonaws.com/avg-kitti/", - ] + data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/" resources = [ "data_object_image_2.zip", "data_object_label_2.zip", @@ -157,14 +155,13 @@ def download(self) -> None: # download files for fname in self.resources: - for mirror in self.mirrors: - url = f"{mirror}{fname}" - try: - print(f"Downloading {url}") - download_and_extract_archive( - url=url, - download_root=self._raw_folder, - filename=fname, - ) - except URLError as error: - print(f"Error downloading {fname}: {error}") + url = f"{self.data_url}{fname}" + try: + print(f"Downloading {url}") + download_and_extract_archive( + url=url, + download_root=self._raw_folder, + filename=fname, + ) + except URLError as error: + print(f"Error downloading {fname}: {error}") From 93e523749d5cf03196a746ec0ec049ecd1e77b8b Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 9 Apr 2021 15:12:54 +0100 Subject: [PATCH 8/8] Removed unnecessary try and print --- torchvision/datasets/kitti.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index 74ab644d2da..8db2e45b715 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -1,7 +1,6 @@ import csv import os from typing import Any, Callable, List, Optional, Tuple -from urllib.error import URLError from PIL import Image @@ -155,13 +154,8 @@ def download(self) -> None: # download files for fname in self.resources: - url = f"{self.data_url}{fname}" - try: - print(f"Downloading {url}") - download_and_extract_archive( - url=url, - download_root=self._raw_folder, - filename=fname, - ) - except URLError as error: - print(f"Error downloading {fname}: {error}") + download_and_extract_archive( + url=f"{self.data_url}{fname}", + download_root=self._raw_folder, + filename=fname, + )