@@ -21,7 +21,7 @@ class CelebA(VisionDataset):
21
21
``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
22
22
``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
23
23
righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
24
- Defaults to ``attr``.
24
+ Defaults to ``attr``. If empty, ``None`` will be returned as target.
25
25
transform (callable, optional): A function/transform that takes in an PIL image
26
26
and returns a transformed version. E.g, ``transforms.ToTensor``
27
27
target_transform (callable, optional): A function/transform that takes in the
@@ -59,6 +59,9 @@ def __init__(self, root, split="train", target_type="attr", transform=None,
59
59
else :
60
60
self .target_type = [target_type ]
61
61
62
+ if not self .target_type and self .target_transform is not None :
63
+ raise RuntimeError ('target_transform is specified but target_type is empty' )
64
+
62
65
if download :
63
66
self .download ()
64
67
@@ -133,13 +136,17 @@ def __getitem__(self, index):
133
136
else :
134
137
# TODO: refactor with utils.verify_str_arg
135
138
raise ValueError ("Target type \" {}\" is not recognized." .format (t ))
136
- target = tuple (target ) if len (target ) > 1 else target [0 ]
137
139
138
140
if self .transform is not None :
139
141
X = self .transform (X )
140
142
141
- if self .target_transform is not None :
142
- target = self .target_transform (target )
143
+ if target :
144
+ target = tuple (target ) if len (target ) > 1 else target [0 ]
145
+
146
+ if self .target_transform is not None :
147
+ target = self .target_transform (target )
148
+ else :
149
+ target = None
143
150
144
151
return X , target
145
152
0 commit comments