
Commit 6bae6c1

Merge branch 'main' into main
2 parents: 4922635 + c27bed4


41 files changed: 233 additions & 121 deletions (large commit; only a subset of the changed files is shown below)

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     PCAM
     PhotoTour
     Places365
+    RenderedSST2
     QMNIST
     SBDataset
     SBU

mypy.ini

Lines changed: 4 additions & 0 deletions
@@ -121,3 +121,7 @@ ignore_missing_imports = True
 [mypy-torchdata.*]
 
 ignore_missing_imports = True
+
+[mypy-h5py.*]
+
+ignore_missing_imports = True

test/test_datasets.py

Lines changed: 28 additions & 22 deletions
@@ -2281,11 +2281,6 @@ def inject_fake_data(self, tmpdir: str, config):
 class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.SUN397
 
-    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
-        split=("train", "test"),
-        partition=(1, 10, None),
-    )
-
     def inject_fake_data(self, tmpdir: str, config):
         data_dir = pathlib.Path(tmpdir) / "SUN397"
         data_dir.mkdir()
@@ -2308,18 +2303,7 @@ def inject_fake_data(self, tmpdir: str, config):
         with open(data_dir / "ClassName.txt", "w") as file:
             file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
 
-        if config["partition"] is not None:
-            num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
-
-            with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
-                file.writelines(
-                    "\n".join(
-                        f"/{f_path.relative_to(data_dir).as_posix()}"
-                        for f_path in random.choices(im_paths, k=num_samples)
-                    )
-                )
-        else:
-            num_samples = len(im_paths)
+        num_samples = len(im_paths)
 
         return num_samples
 
@@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.GTSRB
     FEATURE_TYPES = (PIL.Image.Image, int)
 
-    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
 
     def inject_fake_data(self, tmpdir: str, config):
-        root_folder = os.path.join(tmpdir, "GTSRB")
+        root_folder = os.path.join(tmpdir, "gtsrb")
         os.makedirs(root_folder, exist_ok=True)
 
         # Train data
-        train_folder = os.path.join(root_folder, "Training")
+        train_folder = os.path.join(root_folder, "GTSRB", "Training")
         os.makedirs(train_folder, exist_ok=True)
 
-        num_examples = 3
+        num_examples = 3 if config["split"] == "train" else 4
         classes = ("00000", "00042", "00012")
         for class_idx in classes:
             datasets_utils.create_image_folder(
@@ -2419,7 +2403,7 @@ def inject_fake_data(self, tmpdir: str, config):
 
         total_number_of_examples = num_examples * len(classes)
         # Test data
-        test_folder = os.path.join(root_folder, "Final_Test", "Images")
+        test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images")
        os.makedirs(test_folder, exist_ok=True)
 
         with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:
@@ -2665,5 +2649,27 @@ def inject_fake_data(self, tmpdir: str, config):
         return num_images
 
 
+class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.RenderedSST2
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
+    SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
+
+    def inject_fake_data(self, tmpdir: str, config):
+        root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
+        image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
+
+        num_images_per_class = {"train": 5, "test": 6, "val": 7}
+        sampled_classes = ["positive", "negative"]
+        for cls in sampled_classes:
+            datasets_utils.create_image_folder(
+                image_folder,
+                cls,
+                file_name_fn=lambda idx: f"{idx}.png",
+                num_examples=num_images_per_class[config["split"]],
+            )
+
+        return len(sampled_classes) * num_images_per_class[config["split"]]
+
+
 if __name__ == "__main__":
     unittest.main()
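
The GTSRB test now drives the dataset through a split string ("train"/"test") instead of a train boolean, and expects the archive layout under a lowercase gtsrb root. A minimal usage sketch of the updated interface, assuming the constructor mirrors the configuration exercised above; the root path is illustrative:

from torchvision import datasets

# split values follow GTSRBTestCase above; download support is assumed
train_set = datasets.GTSRB(root="data", split="train", download=True)
test_set = datasets.GTSRB(root="data", split="test", download=True)
image, label = train_set[0]  # PIL.Image.Image, int (see FEATURE_TYPES above)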

test/test_prototype_datasets_api.py

Lines changed: 7 additions & 4 deletions
@@ -126,13 +126,16 @@ def test_default_config(self, info):
         assert info.default_config == default_config
 
     @pytest.mark.parametrize(
-        ("options", "expected_error_msg"),
+        ("valid_options", "options", "expected_error_msg"),
         [
-            pytest.param(dict(unknown_option=None), "Unknown option 'unknown_option'", id="unknown_option"),
-            pytest.param(dict(split="unknown_split"), "Invalid argument 'unknown_split'", id="invalid_argument"),
+            (dict(), dict(any_option=None), "does not take any options"),
+            (dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"),
+            (dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"),
         ],
     )
-    def test_make_config_invalid_inputs(self, info, options, expected_error_msg):
+    def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg):
+        info = make_minimal_dataset_info(valid_options=valid_options)
+
         with pytest.raises(ValueError, match=expected_error_msg):
             info.make_config(**options)
 

test/test_prototype_models.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def test_naming_conventions(model_fn):
 )
 @run_if_test_with_prototype
 def test_schema_meta_validation(model_fn):
-    classification_fields = ["size", "categories", "acc@1", "acc@5"]
+    classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"]
     defaults = {
         "all": ["task", "architecture", "publication_year", "interpolation", "recipe", "num_params"],
         "models": classification_fields,

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
 from .pcam import PCAM
 from .phototour import PhotoTour
 from .places365 import Places365
+from .rendered_sst2 import RenderedSST2
 from .sbd import SBDataset
 from .sbu import SBU
 from .semeion import SEMEION
@@ -102,4 +103,5 @@
     "Country211",
     "FGVCAircraft",
     "EuroSAT",
+    "RenderedSST2",
 )
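
With the new import and __all__ entry, the dataset is reachable as torchvision.datasets.RenderedSST2. A minimal usage sketch, assuming the constructor takes the split values exercised by RenderedSST2TestCase ("train", "val", "test") and the usual download flag; the root path is illustrative:

from torchvision import datasets

# split names follow RenderedSST2TestCase; download support is assumed
dataset = datasets.RenderedSST2(root="data", split="train", download=True)
image, label = dataset[0]  # PIL image and a binary sentiment label (negative/positive)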

torchvision/datasets/clevr.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def __init__(
         split: str = "train",
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
-        download: bool = True,
+        download: bool = False,
     ) -> None:
         self._split = verify_str_arg(split, "split", ("train", "val", "test"))
         super().__init__(root, transform=transform, target_transform=target_transform)
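
Since download now defaults to False here, constructing the dataset against an empty root raises instead of silently downloading. A minimal sketch of the explicit opt-in, assuming the class defined in clevr.py is exported as CLEVRClassification; the root path is illustrative:

from torchvision import datasets

# download must be requested explicitly now that the default is False;
# the class name CLEVRClassification is assumed, not shown in this hunk
dataset = datasets.CLEVRClassification(root="data", split="train", download=True)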

torchvision/datasets/country211.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def __init__(
         split: str = "train",
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
-        download: bool = True,
+        download: bool = False,
     ) -> None:
         self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
 

torchvision/datasets/dtd.py

Lines changed: 4 additions & 4 deletions
@@ -21,12 +21,12 @@ class DTD(VisionDataset):
             The partition only changes which split each image belongs to. Thus, regardless of the selected
             partition, combining all splits will result in all images.
 
-        download (bool, optional): If True, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
             version. E.g, ``transforms.RandomCrop``.
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
     """
 
     _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
@@ -37,9 +37,9 @@ def __init__(
         root: str,
         split: str = "train",
         partition: int = 1,
-        download: bool = True,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        download: bool = False,
     ) -> None:
         self._split = verify_str_arg(split, "split", ("train", "val", "test"))
         if not isinstance(partition, int) and not (1 <= partition <= 10):
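
Besides flipping the default to False, download moved after transform and target_transform in the signature, so callers that passed these arguments positionally would now bind them to different parameters. A minimal sketch using keyword arguments, which is unaffected by the reordering (the root path and transform are illustrative):

from torchvision import datasets, transforms

dataset = datasets.DTD(
    root="data",
    split="train",
    partition=1,
    transform=transforms.ToTensor(),
    download=True,  # must be passed explicitly now that the default is False
)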

torchvision/datasets/eurosat.py

Lines changed: 12 additions & 10 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Callable, Optional
 
 from .folder import ImageFolder
 from .utils import download_and_extract_archive
@@ -10,23 +10,21 @@ class EuroSAT(ImageFolder):
 
     Args:
         root (string): Root directory of dataset where ``root/eurosat`` exists.
-        download (bool, optional): If True, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again. Default is False.
         transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
     """
 
-    url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
-    md5 = "c8fa014336c82ac7804f0398fcb19387"
-
     def __init__(
         self,
         root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
         download: bool = False,
-        **kwargs: Any,
     ) -> None:
         self.root = os.path.expanduser(root)
         self._base_folder = os.path.join(self.root, "eurosat")
@@ -38,7 +36,7 @@ def __init__(
         if not self._check_exists():
             raise RuntimeError("Dataset not found. You can use download=True to download it")
 
-        super().__init__(self._data_folder, **kwargs)
+        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
         self.root = os.path.expanduser(root)
 
     def __len__(self) -> int:
@@ -53,4 +51,8 @@ def download(self) -> None:
             return
 
         os.makedirs(self._base_folder, exist_ok=True)
-        download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
+        download_and_extract_archive(
+            "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
+            download_root=self._base_folder,
+            md5="c8fa014336c82ac7804f0398fcb19387",
+        )
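
The constructor now accepts explicit transform, target_transform, and download parameters instead of forwarding **kwargs to ImageFolder, and the download URL and MD5 are inlined into download(). A minimal usage sketch of the resulting interface (the root path is illustrative):

from torchvision import datasets

# only the parameters in the new signature are accepted; other ImageFolder
# keywords (e.g. loader) are no longer forwarded
dataset = datasets.EuroSAT(root="data", download=True)
image, label = dataset[0]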
