Skip to content

[PROTOTYPE] add support for categories in DatasetInfo #4432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def run(self):
# Package info
packages=find_packages(exclude=('test',)),
package_data={
package_name: ['*.dll', '*.dylib', '*.so']
package_name: ['*.dll', '*.dylib', '*.so', 'prototype/datasets/_builtin/*.categories']
},
zip_safe=False,
install_requires=requirements,
Expand Down
10 changes: 10 additions & 0 deletions test/prototype/datasets/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest
from torchvision.prototype.datasets.utils import DatasetInfo


@pytest.fixture
def make_minimal_dataset_info():
def make(name="name", categories=None, **kwargs):
return DatasetInfo(name, categories=categories or [], **kwargs)

return make
8 changes: 4 additions & 4 deletions test/prototype/datasets/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchvision.prototype import datasets
from torchvision.prototype.datasets import _api
from torchvision.prototype.datasets import _builtin
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig


@pytest.fixture
Expand All @@ -14,9 +14,9 @@ def patch_datasets(monkeypatch):


@pytest.fixture
def dataset(mocker):
info = DatasetInfo(
"name", valid_options=dict(split=("train", "test"), foo=("bar", "baz"))
def dataset(mocker, make_minimal_dataset_info):
info = make_minimal_dataset_info(
valid_options=dict(split=("train", "test"), foo=("bar", "baz"))
)

class DatasetMock(Dataset):
Expand Down
38 changes: 20 additions & 18 deletions test/prototype/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,52 +122,53 @@ class TestDatasetInfo:
def valid_options():
return dict(split=("train", "test"), foo=("bar", "baz"))

def test_no_valid_options(self):
info = utils.DatasetInfo("name")
@staticmethod
@pytest.fixture
def info(make_minimal_dataset_info, valid_options):
return make_minimal_dataset_info(valid_options=valid_options)

def test_no_valid_options(self, make_minimal_dataset_info):
info = make_minimal_dataset_info()
assert info.default_config.split == "train"

def test_valid_options_no_split(self):
info = utils.DatasetInfo("name", valid_options=dict(option=("argument",)))
def test_valid_options_no_split(self, make_minimal_dataset_info):
info = make_minimal_dataset_info(valid_options=dict(option=("argument",)))
assert info.default_config.split == "train"

def test_valid_options_no_train(self):
def test_valid_options_no_train(self, make_minimal_dataset_info):
with pytest.raises(ValueError):
utils.DatasetInfo("name", valid_options=dict(split=("test",)))
make_minimal_dataset_info(valid_options=dict(split=("test",)))

def test_default_config(self, valid_options):
def test_default_config(self, make_minimal_dataset_info, valid_options):
default_config = utils.DatasetConfig(
{key: values[0] for key, values in valid_options.items()}
)

assert (
utils.DatasetInfo("name", valid_options=valid_options).default_config
make_minimal_dataset_info(valid_options=valid_options).default_config
== default_config
)

def test_make_config_unknown_option(self, valid_options):
info = utils.DatasetInfo("name", valid_options=valid_options)

def test_make_config_unknown_option(self, info):
with pytest.raises(ValueError):
info.make_config(unknown_option=None)

def test_make_config_invalid_argument(self, valid_options):
info = utils.DatasetInfo("name", valid_options=valid_options)

def test_make_config_invalid_argument(self, info):
with pytest.raises(ValueError):
info.make_config(split="unknown_split")

def test_repr(self, valid_options):
output = repr(utils.DatasetInfo("name", valid_options=valid_options))
def test_repr(self, make_minimal_dataset_info, valid_options):
output = repr(make_minimal_dataset_info(valid_options=valid_options))

assert isinstance(output, str)
assert "DatasetInfo" in output
for key, value in valid_options.items():
assert f"{key}={value}" in output

@pytest.mark.parametrize("optional_info", ("citation", "homepage", "license"))
def test_repr_optional_info(self, optional_info):
def test_repr_optional_info(self, make_minimal_dataset_info, optional_info):
sentinel = "sentinel"
info = utils.DatasetInfo("name", **{optional_info: sentinel})
info = make_minimal_dataset_info(**{optional_info: sentinel})

assert f"{optional_info}={sentinel}" in repr(info)

Expand All @@ -183,6 +184,7 @@ def make(name="name", valid_options=None, resources=None):
dict(
info=utils.DatasetInfo(
name,
categories=[],
valid_options=valid_options or dict(split=("train", "test")),
),
resources=mocker.Mock(return_value=[])
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._home import *
from . import decoder, utils, datapipes

# Load this last, since itself but especially _builtin/* depends on the above being available
Expand Down
25 changes: 3 additions & 22 deletions torchvision/prototype/datasets/_api.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,17 @@
import difflib
import io
import os
import pathlib
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.hub import _get_torch_home
from torch.utils.data import IterDataPipe

from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.datasets.utils._internal import add_suggestion
from . import _builtin

__all__ = ["home", "register", "list", "info", "load"]


# TODO: This needs a better default
HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"


def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
global HOME
if home is not None:
HOME = pathlib.Path(home).expanduser().resolve()
return HOME

home = os.getenv("TORCHVISION_DATASETS_HOME")
if home is not None:
return pathlib.Path(home)

return HOME
__all__ = ["register", "list", "info", "load"]


DATASETS: Dict[str, Dataset] = {}
Expand Down
Loading