Skip to content

Commit 2f38774

Browse files
committed
Add optional classes param to ImageFolder and add docs.
1 parent 429dbeb commit 2f38774

File tree

1 file changed

+28
-5
lines changed

1 file changed

+28
-5
lines changed

torchvision/datasets/folder.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def find_classes(dir):
2323

2424
def make_dataset(dir, class_to_idx):
2525
images = []
26-
for target in os.listdir(dir):
26+
for target in class_to_idx.keys():
2727
d = os.path.join(dir, target)
2828
if not os.path.isdir(d):
2929
continue
@@ -43,10 +43,33 @@ def default_loader(path):
4343

4444

4545
class ImageFolder(data.Dataset):
46-
47-
def __init__(self, root, transform=None, target_transform=None,
48-
loader=default_loader):
49-
classes, class_to_idx = find_classes(root)
46+
"""
47+
A class representing a directory of images as a `Dataset`.
48+
49+
Args:
50+
root (String): The path to the directory.
51+
transform (Object): A callable object that transforms images. See: torchvision/transforms.py. By default no
52+
transformations will be applied to images.
53+
target_transform (Object): A callable object that transforms targets (labels). By default no transformations
54+
will be applied to targets.
55+
loader (Function): Loads an image and returns it in a usable form. By default loads images by
56+
their path and returns a `PIL.Image` instance.
57+
classes (List/Tuple): The sub-directories of `root` that correspond to the classes of this data set. By default
58+
all sub-directories of `root` are used.
59+
60+
Example:
61+
>>> dataset = folder.ImageFolder('./dataset', transform=transforms.Compose([
62+
>>> transforms.Scale(size=224),
63+
>>> transforms.RandomCrop(size=224),
64+
>>> transforms.ToTensor()
65+
>>> ]), classes=['cat', 'dog'])
66+
"""
67+
68+
def __init__(self, root, transform=None, target_transform=None, loader=default_loader, classes=None):
69+
if not classes:
70+
classes, class_to_idx = find_classes(root)
71+
else:
72+
class_to_idx = {classes[i]: i for i in range(len(classes))}
5073
imgs = make_dataset(root, class_to_idx)
5174
if len(imgs) == 0:
5275
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"

0 commit comments

Comments
 (0)