@@ -32,17 +32,10 @@ def is_image_file(filename):
32
32
return has_file_allowed_extension (filename , IMG_EXTENSIONS )
33
33
34
34
35
- def find_classes (dir ):
36
- classes = [d for d in os .listdir (dir ) if os .path .isdir (os .path .join (dir , d ))]
37
- classes .sort ()
38
- class_to_idx = {classes [i ]: i for i in range (len (classes ))}
39
- return classes , class_to_idx
40
-
41
-
42
35
def make_dataset (dir , class_to_idx , extensions ):
43
36
images = []
44
37
dir = os .path .expanduser (dir )
45
- for target in sorted (os . listdir ( dir )):
38
+ for target in sorted (class_to_idx . keys ( )):
46
39
d = os .path .join (dir , target )
47
40
if not os .path .isdir (d ):
48
41
continue
@@ -86,7 +79,7 @@ class DatasetFolder(data.Dataset):
86
79
"""
87
80
88
81
def __init__ (self , root , loader , extensions , transform = None , target_transform = None ):
89
- classes , class_to_idx = find_classes (root )
82
+ classes , class_to_idx = self . _find_classes (root )
90
83
samples = make_dataset (root , class_to_idx , extensions )
91
84
if len (samples ) == 0 :
92
85
raise (RuntimeError ("Found 0 files in subfolders of: " + root + "\n "
@@ -104,6 +97,24 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
104
97
self .transform = transform
105
98
self .target_transform = target_transform
106
99
100
+ def _find_classes (self , dir ):
101
+ """
102
+ Finds the class folders in a dataset.
103
+
104
+ Args:
105
+ dir (string): Root directory path.
106
+
107
+ Returns:
108
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
109
+
110
+ Ensures:
111
+ No class is a subdirectory of another.
112
+ """
113
+ classes = [d for d in os .listdir (dir ) if os .path .isdir (os .path .join (dir , d ))]
114
+ classes .sort ()
115
+ class_to_idx = {classes [i ]: i for i in range (len (classes ))}
116
+ return classes , class_to_idx
117
+
107
118
def __getitem__ (self , index ):
108
119
"""
109
120
Args:
0 commit comments