Skip to content

Commit 6285b31

Browse files
committed
Fix bug in tests and in food101 dataset
1 parent 23f685a commit 6285b31

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

test/test_datasets.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,12 +2195,13 @@ def inject_fake_data(self, tmpdir: str, config):
21952195
num_examples=num_images_per_class,
21962196
)
21972197
metadata[cls] = [
2198-
fname.parent.name + "/" + fname.name for fname in random.choices(im_fnames, k=n_samples_per_class)
2198+
"/".join(fname.relative_to(image_folder).with_suffix("").parts)
2199+
for fname in random.choices(im_fnames, k=n_samples_per_class)
21992200
]
22002201

22012202
with open(meta_folder / f"{config['split']}.json", "w") as file:
22022203
file.write(json.dumps(metadata))
2203-
print(metadata)
2204+
22042205
return len(sampled_classes * n_samples_per_class)
22052206

22062207

torchvision/datasets/food101.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def __init__(
3939
) -> None:
4040
super().__init__(root, transform=transform, target_transform=target_transform)
4141
self._split = verify_str_arg(split, "split", ("train", "test"))
42-
self._root_path = Path(self.root)
43-
self._base_folder = self._root_path / "food-101"
42+
self._base_folder = Path(self._root_path) / "food-101"
4443
self._meta_folder = self._base_folder / "meta"
4544
self._images_folder = self._base_folder / "images"
4645

@@ -61,7 +60,7 @@ def __init__(
6160
for class_label, im_rel_paths in metadata.items():
6261
self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
6362
self._image_files += [
64-
self._images_folder.joinpath(*f"{im_rel_path}".split("/")) for im_rel_path in im_rel_paths
63+
self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
6564
]
6665

6766
def __len__(self) -> int:
@@ -88,4 +87,4 @@ def _check_exists(self) -> bool:
8887
def _download(self) -> None:
8988
if self._check_exists():
9089
return
91-
download_and_extract_archive(self._URL, download_root=str(self.root), md5=self._MD5)
90+
download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

0 commit comments

Comments
 (0)