
Expand tests for prototype datasets #5187


Merged · 20 commits · Jan 19, 2022
735 changes: 702 additions & 33 deletions test/builtin_dataset_mocks.py

Large diffs are not rendered by default.
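
For orientation, since the mock file's diff is not rendered: the tests below rely on DATASET_MOCKS, a mapping from dataset names to mock objects that can materialize fake data on disk. The following is a self-contained sketch of that registry pattern; the names DatasetMock and register_mock and the qmnist payload are illustrative only, not the PR's actual API.

import pathlib
from typing import Callable, Dict


class DatasetMock:
    """Pairs a dataset name with a callable that writes fake data to disk."""

    def __init__(self, name: str, create_files: Callable[[pathlib.Path], int]):
        self.name = name
        self._create_files = create_files  # writes mock files, returns num_samples

    def prepare(self, root: pathlib.Path) -> int:
        root.mkdir(parents=True, exist_ok=True)
        return self._create_files(root)


DATASET_MOCKS: Dict[str, DatasetMock] = {}


def register_mock(fn: Callable[[pathlib.Path], int]) -> Callable[[pathlib.Path], int]:
    # Register the mock under the dataset's name, taken from the function name.
    DATASET_MOCKS[fn.__name__] = DatasetMock(fn.__name__, fn)
    return fn


@register_mock
def qmnist(root: pathlib.Path) -> int:
    (root / "qmnist-train-images").write_bytes(b"\x00" * 16)  # fake payload
    return 2  # number of samples the mock data encodes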

37 changes: 34 additions & 3 deletions test/test_prototype_builtin_datasets.py
@@ -6,17 +6,28 @@
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision.prototype import transforms
from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str


@parametrize_dataset_mocks(DATASET_MOCKS)
def test_coverage():
untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
if untested_datasets:
raise AssertionError(
f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
f"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. "
f"Please add mock data to `test/builtin_dataset_mocks.py`."
)


class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

@@ -31,6 +42,7 @@ def test_sample(self, dataset_mock, config):
if not sample:
raise AssertionError("Sample dictionary is empty.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config)

@@ -40,6 +52,7 @@ def test_num_samples(self, dataset_mock, config):

assert num_samples == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

@@ -50,6 +63,7 @@ def test_decoding(self, dataset_mock, config):
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

@@ -60,16 +74,33 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

next(iter(dataset.map(transforms.Identity())))

@parametrize_dataset_mocks(
DATASET_MOCKS,
marks={
"cub200": pytest.mark.xfail(
reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
)
},
)
def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

traverse(dataset)

@parametrize_dataset_mocks(
DATASET_MOCKS,
marks={
"cub200": pytest.mark.xfail(
reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
)
},
)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
def scan(graph):
@@ -86,8 +117,8 @@ def scan(graph):
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"])
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

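The new test_coverage above is, at its core, a set difference between the datasets exposed through datasets.list() and the keys of DATASET_MOCKS. A minimal standalone illustration with made-up dataset names:

def find_untested(available, mocked):
    """Return dataset names that are exposed but have no mock data."""
    return sorted(set(available) - set(mocked))


assert find_untested(["mnist", "qmnist", "voc"], ["mnist", "qmnist"]) == ["voc"]
assert find_untested(["mnist"], ["mnist"]) == []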
14 changes: 8 additions & 6 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -26,6 +26,7 @@
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Feature, Label, BoundingBox

csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)

@@ -67,6 +68,7 @@ def _make_info(self) -> DatasetInfo:
"celeba",
type=DatasetType.IMAGE,
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
valid_options=dict(split=("train", "val", "test")),
)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
@@ -104,7 +106,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:

_SPLIT_ID_TO_NAME = {
"0": "train",
"1": "valid",
"1": "val",
"2": "test",
}

@@ -117,22 +119,22 @@ def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[s

def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]],
data: Tuple[Tuple[str, Tuple[Tuple[str, Dict[str, Any]], Tuple[str, io.IOBase]]], Tuple[str, Dict[str, Any]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, _, image_data = split_and_image_data
_, (_, image_data) = split_and_image_data
path, buffer = image_data
_, ann = ann_data

image = decoder(buffer) if decoder else buffer

identity = int(ann["identity"]["identity"])
identity = Label(int(ann["identity"]["identity"]))
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
bbox = BoundingBox([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = {
landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
landmark: Feature((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
for landmark in {key[:-2] for key in ann["landmarks"].keys()}
}

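The corrected type annotation and unpacking above reflect the nested tuples that the zipped datapipes actually emit. A plain-Python illustration with dummy values (the file names and dict contents are made up):

# data flows in as (key, ((split_path, split_row), (image_path, buffer)))
split_and_image_data = (
    "000001.jpg",
    (("list_eval_partition.txt", {"split": "0"}), ("img/000001.jpg", b"<jpeg bytes>")),
)
_, (_, image_data) = split_and_image_data  # discard the key and the split row
path, buffer = image_data
assert path == "img/000001.jpg"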
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_builtin/cub200.py
@@ -105,7 +105,7 @@ def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name

def _2011_decode_ann(
def _2011_load_ann(
self,
data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]],
*,
@@ -126,7 +126,7 @@ def _2010_anns_key(self, data: Tuple[str, io.IOBase]) -> Tuple[str, Tuple[str, i
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name, data

def _2010_decode_ann(
def _2010_load_ann(
self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
) -> Dict[str, Any]:
_, (path, buffer) = data
@@ -154,7 +154,7 @@ def _collate_and_decode_sample(
label_str, category = dir_name.split(".")

return dict(
(self._2011_decode_ann if year == "2011" else self._2010_decode_ann)(anns_data, decoder=decoder),
(self._2011_load_ann if year == "2011" else self._2010_load_ann)(anns_data, decoder=decoder),
image=decoder(buffer) if decoder else buffer,
label=Label(int(label_str), category=category),
)
@@ -196,7 +196,7 @@ def _make_datapipe(
else: # config.year == "2010"
split_dp, images_dp, anns_dp = resource_dps

split_dp = Filter(split_dp, path_comparator("stem", config.split))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = Mapper(split_dp, self._2010_split_key)

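Switching the 2010 split filter from the path's stem to its full name avoids matching unrelated files that merely share the stem. A simplified stand-in for the path_comparator helper (assuming it compares one pathlib.Path attribute against a fixed value) makes the difference concrete:

import pathlib
from typing import Any, Callable


def path_comparator(attr: str, value: Any) -> Callable[[str], bool]:
    # Simplified stand-in for torchvision.prototype.datasets.utils._internal.path_comparator.
    return lambda path: getattr(pathlib.Path(path), attr) == value


assert path_comparator("stem", "train")("lists/train.txt")
assert path_comparator("stem", "train")("images/train.jpg")  # unwanted match on the stem
assert path_comparator("name", "train.txt")("lists/train.txt")
assert not path_comparator("name", "train.txt")("images/train.jpg")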
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/dtd.py
@@ -1,4 +1,5 @@
import enum
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -126,7 +127,7 @@ def _make_datapipe(
ref_key_fn=self._image_key_fn,
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == DTDDemux.IMAGES
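Replacing Mapper's fn_kwargs with functools.partial binds the decoder up front and yields a plain one-argument callable. A minimal demonstration of the same pattern outside torchdata:

import functools


def collate(sample, *, decoder=None):
    # Toy stand-in for _collate_and_decode_sample: apply the decoder if one is given.
    return decoder(sample) if decoder else sample


collate_decoded = functools.partial(collate, decoder=str.upper)  # decoder is baked in
assert collate_decoded("abc") == "ABC"
assert collate("abc") == "abc"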
6 changes: 4 additions & 2 deletions torchvision/prototype/datasets/_builtin/sbd.py
@@ -31,6 +31,7 @@
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Feature


class SBD(Dataset):
@@ -83,11 +84,11 @@ def _decode_ann(

# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries = (
torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
if decode_boundaries
else None
)
segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None
segmentation = Feature(raw_segmentation) if decode_segmentation else None

return boundaries, segmentation

@@ -140,6 +141,7 @@ def _make_datapipe(

if config.split == "train_noval":
split_dp = extra_split_dp
split_dp = Filter(split_dp, path_comparator("stem", config.split))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
split_dp = hint_shuffling(split_dp)
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_builtin/semeion.py
@@ -18,6 +18,7 @@
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label


class SEMEION(Dataset):
@@ -46,14 +47,13 @@ def _collate_and_decode_sample(
label_data = [int(label) for label in data[256:] if label]

if decoder is raw:
image = image_data.unsqueeze(0)
image = Image(image_data.unsqueeze(0))
else:
image_buffer = image_buffer_from_array(image_data.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]

label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
category = self.info.categories[label]
return dict(image=image, label=label, category=category)
label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx]))

def _make_datapipe(
self,
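The refactor packs the index and its category into a single Label instead of returning them as separate keys. The one-hot-to-index step itself is plain Python:

# Plain-Python illustration of the one-hot decoding above (categories are made up):
categories = ["0", "1", "2"]
one_hot = [0, 0, 1]

label_idx = next(idx for idx, flag in enumerate(one_hot) if flag)
assert label_idx == 2
assert categories[label_idx] == "2"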
46 changes: 31 additions & 15 deletions torchvision/prototype/datasets/_builtin/voc.py
@@ -30,34 +30,50 @@
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox

HERE = pathlib.Path(__file__).parent

class VOCDatasetInfo(DatasetInfo):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")

def make_config(self, **options: Any) -> DatasetConfig:
config = super().make_config(**options)
if config.split == "test" and config.year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")

return config


class VOC(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
return VOCDatasetInfo(
"voc",
type=DatasetType.IMAGE,
homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
valid_options=dict(
split=("train", "val", "test"),
year=("2012",),
split=("train", "val", "trainval", "test"),
year=("2012", "2007", "2008", "2009", "2010", "2011"),
task=("detection", "segmentation"),
),
)

_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
"2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"),
"2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"),
"2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"),
"2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"),
}
_TEST_ARCHIVES = {
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.year == "2012":
if config.split == "train":
archive = HttpResource(
"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
sha256="e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb",
)
else:
raise RuntimeError("FIXME")
else:
raise RuntimeError("FIXME")
file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256)
return [archive]

_ANNS_FOLDER = dict(
@@ -88,7 +104,7 @@ def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
objects = result["annotation"]["object"]
bboxes = [obj["bndbox"] for obj in objects]
bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
return torch.tensor(bboxes)
return BoundingBox(bboxes)

def _collate_and_decode_sample(
self,
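The rewritten resources method replaces the FIXME branches with a table lookup keyed by year and split. A self-contained sketch of that selection logic (URL scheme and file names taken from the diff; checksums omitted, helper name illustrative):

TRAIN_VAL_ARCHIVES = {
    "2007": "VOCtrainval_06-Nov-2007.tar",
    "2012": "VOCtrainval_11-May-2012.tar",
}
TEST_ARCHIVES = {"2007": "VOCtest_06-Nov-2007.tar"}


def resource_url(year: str, split: str) -> str:
    table = TEST_ARCHIVES if split == "test" else TRAIN_VAL_ARCHIVES
    # A missing year doubles as "unsupported combination"; the VOCDatasetInfo
    # guard above already rejects split="test" for years other than 2007.
    return f"http://host.robots.ox.ac.uk/pascal/VOC/voc{year}/{table[year]}"


assert resource_url("2012", "train").endswith("voc2012/VOCtrainval_11-May-2012.tar")
assert resource_url("2007", "test").endswith("voc2007/VOCtest_06-Nov-2007.tar")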