Skip to content

Commit 76c04d6

Browse files
dnikufmassa
authored andcommitted
Support empty target_type for CelebA dataset (#1351)
* Support empty target_type for CelebA dataset * Return (X, None) for interface consistency * Document behavior for target_type=[] * Simplify diff * Raise exception on meaningless parameters
1 parent ef67fd9 commit 76c04d6

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

torchvision/datasets/celeba.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CelebA(VisionDataset):
2121
``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
2222
``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
2323
righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
24-
Defaults to ``attr``.
24+
Defaults to ``attr``. If empty, ``None`` will be returned as target.
2525
transform (callable, optional): A function/transform that takes in an PIL image
2626
and returns a transformed version. E.g, ``transforms.ToTensor``
2727
target_transform (callable, optional): A function/transform that takes in the
@@ -59,6 +59,9 @@ def __init__(self, root, split="train", target_type="attr", transform=None,
5959
else:
6060
self.target_type = [target_type]
6161

62+
if not self.target_type and self.target_transform is not None:
63+
raise RuntimeError('target_transform is specified but target_type is empty')
64+
6265
if download:
6366
self.download()
6467

@@ -133,13 +136,17 @@ def __getitem__(self, index):
133136
else:
134137
# TODO: refactor with utils.verify_str_arg
135138
raise ValueError("Target type \"{}\" is not recognized.".format(t))
136-
target = tuple(target) if len(target) > 1 else target[0]
137139

138140
if self.transform is not None:
139141
X = self.transform(X)
140142

141-
if self.target_transform is not None:
142-
target = self.target_transform(target)
143+
if target:
144+
target = tuple(target) if len(target) > 1 else target[0]
145+
146+
if self.target_transform is not None:
147+
target = self.target_transform(target)
148+
else:
149+
target = None
143150

144151
return X, target
145152

0 commit comments

Comments
 (0)