Skip to content

Commit 3f6c23c

Browse files
Choco31415fmassa
authored andcommitted
Addresses issue #145 as per @fmessa's suggestion. (#527)
* Addresses issue #145 as per @fmessa's suggestion. * Removed blank line for styling.
1 parent 5a0d079 commit 3f6c23c

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

torchvision/datasets/folder.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,10 @@ def is_image_file(filename):
3232
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
3333

3434

35-
def find_classes(dir):
36-
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
37-
classes.sort()
38-
class_to_idx = {classes[i]: i for i in range(len(classes))}
39-
return classes, class_to_idx
40-
41-
4235
def make_dataset(dir, class_to_idx, extensions):
4336
images = []
4437
dir = os.path.expanduser(dir)
45-
for target in sorted(os.listdir(dir)):
38+
for target in sorted(class_to_idx.keys()):
4639
d = os.path.join(dir, target)
4740
if not os.path.isdir(d):
4841
continue
@@ -86,7 +79,7 @@ class DatasetFolder(data.Dataset):
8679
"""
8780

8881
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
89-
classes, class_to_idx = find_classes(root)
82+
classes, class_to_idx = self._find_classes(root)
9083
samples = make_dataset(root, class_to_idx, extensions)
9184
if len(samples) == 0:
9285
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
@@ -104,6 +97,24 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
10497
self.transform = transform
10598
self.target_transform = target_transform
10699

100+
def _find_classes(self, dir):
101+
"""
102+
Finds the class folders in a dataset.
103+
104+
Args:
105+
dir (string): Root directory path.
106+
107+
Returns:
108+
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
109+
110+
Ensures:
111+
No class is a subdirectory of another.
112+
"""
113+
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
114+
classes.sort()
115+
class_to_idx = {classes[i]: i for i in range(len(classes))}
116+
return classes, class_to_idx
117+
107118
def __getitem__(self, index):
108119
"""
109120
Args:

0 commit comments

Comments
 (0)