@@ -23,7 +23,7 @@ def find_classes(dir):
23
23
24
24
def make_dataset (dir , class_to_idx ):
25
25
images = []
26
- for target in os . listdir ( dir ):
26
+ for target in class_to_idx . keys ( ):
27
27
d = os .path .join (dir , target )
28
28
if not os .path .isdir (d ):
29
29
continue
@@ -43,10 +43,33 @@ def default_loader(path):
43
43
44
44
45
45
class ImageFolder (data .Dataset ):
46
-
47
- def __init__ (self , root , transform = None , target_transform = None ,
48
- loader = default_loader ):
49
- classes , class_to_idx = find_classes (root )
46
+ """
47
+ A class representing a directory of images as a `Dataset`.
48
+
49
+ Args:
50
+ root (String): The path to the directory.
51
+ transform (Object): A callable object that transforms images. See: torchvision/transforms.py. By default no
52
+ transformations will be applied to images.
53
+ target_transform (Object): A callable object that transforms targets (labels). By default no transformations
54
+ will be applied to targets.
55
+ loader (Function): Loads an image and returns it in a usable form. By default loads images by
56
+ their path and returns a `PIL.Image` instance.
57
+ classes (List/Tuple): The sub-directories of `root` that correspond to the classes of this data set. By default
58
+ all sub-directories of `root` are used.
59
+
60
+ Example:
61
+ >>> dataset = folder.ImageFolder('./dataset', transform=transforms.Compose([
62
+ >>> transforms.Scale(size=224),
63
+ >>> transforms.RandomCrop(size=224),
64
+ >>> transforms.ToTensor()
65
+ >>> ]), classes=['cat', 'dog'])
66
+ """
67
+
68
+ def __init__ (self , root , transform = None , target_transform = None , loader = default_loader , classes = None ):
69
+ if not classes :
70
+ classes , class_to_idx = find_classes (root )
71
+ else :
72
+ class_to_idx = {classes [i ]: i for i in range (len (classes ))}
50
73
imgs = make_dataset (root , class_to_idx )
51
74
if len (imgs ) == 0 :
52
75
raise (RuntimeError ("Found 0 images in subfolders of: " + root + "\n "
0 commit comments