33 changes: 33 additions & 0 deletions test/fakedata_generation.py
@@ -404,3 +404,36 @@ def make_archive(stack, root, name):
    data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)

    yield root, data


@contextlib.contextmanager
def kitti_root():
    def _make_image(file):
        PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

    def _make_train_archive(root):
        extracted_dir = os.path.join(root, 'training', 'image_2')
        os.makedirs(extracted_dir)
        _make_image(os.path.join(extracted_dir, '00000.png'))

    def _make_target_archive(root):
        extracted_dir = os.path.join(root, 'training', 'label_2')
        os.makedirs(extracted_dir)
        # One annotation row in KITTI label format: type, truncated, occluded,
        # alpha, bbox (4 values), dimensions (3), location (3), rotation_y.
        target_contents = 'Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n'
        target_file = os.path.join(extracted_dir, '00000.txt')
        with open(target_file, "w") as txt_file:
            txt_file.write(target_contents)

    def _make_test_archive(root):
        extracted_dir = os.path.join(root, 'testing', 'image_2')
        os.makedirs(extracted_dir)
        _make_image(os.path.join(extracted_dir, '00001.png'))

    with get_tmp_dir() as root:
        raw_dir = os.path.join(root, "Kitti", "raw")
        os.makedirs(raw_dir)
        _make_train_archive(raw_dir)
        _make_target_archive(raw_dir)
        _make_test_archive(raw_dir)

        yield root
20 changes: 19 additions & 1 deletion test/test_datasets.py
@@ -10,7 +10,12 @@
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import svhn_root, places365_root, widerface_root, stl10_root
from fakedata_generation import (
    kitti_root,
    places365_root,
    stl10_root,
    svhn_root,
)
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
@@ -155,6 +160,19 @@ def test_places365_repr_smoke(self):
        dataset = torchvision.datasets.Places365(root, download=True)
        self.assertIsInstance(repr(dataset), str)

    def test_kitti(self):
        with kitti_root() as root:
            dataset = torchvision.datasets.Kitti(root)
            self.assertEqual(len(dataset), 1)
            img, target = dataset[0]
            self.assertIsInstance(img, PIL.Image.Image)

            dataset = torchvision.datasets.Kitti(root, split='test')
            self.assertEqual(len(dataset), 1)
            img, target = dataset[0]
            self.assertIsInstance(img, PIL.Image.Image)
            self.assertIsNone(target)


class STL10Tester(DatasetTestcase):
    @contextlib.contextmanager
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
@@ -24,6 +24,7 @@
from .hmdb51 import HMDB51
from .ucf101 import UCF101
from .places365 import Places365
from .kitti import Kitti

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -34,4 +35,5 @@
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
           'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
           'Places365', 'Kitti',
           )
166 changes: 166 additions & 0 deletions torchvision/datasets/kitti.py
@@ -0,0 +1,166 @@
import csv
import os
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Tuple
from urllib.error import URLError

from PIL import Image

from .utils import download_and_extract_archive
from .vision import VisionDataset


class Kitti(VisionDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── Kitti
                        └── raw
                            ├── training
                            │   ├── image_2
                            │   └── label_2
                            └── testing
                                └── image_2
        split (string): The dataset split to use. One of {``train``, ``test``}.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes an input sample
            and its target as entry and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    mirrors = [
        "https://s3.eu-central-1.amazonaws.com/avg-kitti/",
    ]
    resources = [
        "data_object_image_2.zip",
        "data_object_label_2.zip",
    ]
    image_dir_name = "image_2"
    labels_dir_name = "label_2"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        self.images = []
        self.targets = []
        self.root = root
        self.split = split

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. You may use download=True to download it."
            )

        location = "testing" if self.split == "test" else "training"
        image_dir = os.path.join(self.raw_folder, location, self.image_dir_name)
        if location == "training":
            labels_dir = os.path.join(self.raw_folder, location, self.labels_dir_name)
        for img_file in os.listdir(image_dir):
            self.images.append(os.path.join(image_dir, img_file))
            if location == "training":
                self.targets.append(
                    os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")
                )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get item at a given index.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target), where ``target`` is a dictionary mapping
            each of the following keys to a list with one entry per labeled
            object in the image:

                type: str
                truncated: float
                occluded: int
                alpha: float
                bbox: float[4]
                dimensions: float[3]
                location: float[3]
                rotation_y: float
        """
        image = Image.open(self.images[index])
        target = None if self.split == "test" else self._parse_target(index)
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, target

    def _parse_target(self, index: int) -> Dict[str, Any]:
        target: Dict[str, Any] = defaultdict(list)
        with open(self.targets[index]) as inp:
            content = csv.reader(inp, delimiter=" ")
            for line in content:
                target["type"].append(line[0])
                target["truncated"].append(float(line[1]))
                target["occluded"].append(int(line[2]))
                target["alpha"].append(float(line[3]))
                target["bbox"].append([float(x) for x in line[4:8]])
                target["dimensions"].append([float(x) for x in line[8:11]])
                target["location"].append([float(x) for x in line[11:14]])
                target["rotation_y"].append(float(line[14]))
        return target
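
    # Sketch of the parsed output for the single-row label used in the test
    # fixture ("Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48
    # 1.20 1.84 1.47 8.41 0.01"), assuming the float/int casts above:
    #   {"type": ["Pedestrian"], "truncated": [0.0], "occluded": [0],
    #    "alpha": [-0.2], "bbox": [[712.4, 143.0, 810.73, 307.92]],
    #    "dimensions": [[1.89, 0.48, 1.2]], "location": [[1.84, 1.47, 8.41]],
    #    "rotation_y": [0.01]}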

    def __len__(self) -> int:
        return len(self.images)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "raw")

    def _check_exists(self) -> bool:
        """Check if the data directory exists."""
        location = "testing" if self.split == "test" else "training"
        folders = [self.image_dir_name]
        if self.split != "test":
            folders.append(self.labels_dir_name)
        return all(
            os.path.isdir(os.path.join(self.raw_folder, location, fname))
            for fname in folders
        )

    def download(self) -> None:
        """Download the KITTI data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for fname in self.resources:
            for mirror in self.mirrors:
                url = f"{mirror}{fname}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(
                        url=url,
                        download_root=self.raw_folder,
                        filename=fname,
                    )
                    break  # stop trying further mirrors once the download succeeds
                except URLError as error:
                    print(f"Error downloading {fname}: {error}")