Skip to content

Commit dffbf2d

Browse files
Raise validation error when no transforms passed to RandomApply, RandomChoice and RandomOrder. (#9130)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 98f8b37 commit dffbf2d

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

test/test_transforms_v2.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2520,14 +2520,29 @@ def test_errors(self):
25202520
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
25212521
cls(lambda x: x)
25222522

2523-
with pytest.raises(ValueError, match="at least one transform"):
2524-
transforms.Compose([])
2523+
for cls in (
2524+
transforms.Compose,
2525+
transforms.RandomApply,
2526+
transforms.RandomChoice,
2527+
transforms.RandomOrder,
2528+
):
2529+
2530+
with pytest.raises(ValueError, match="at least one transform"):
2531+
cls([])
25252532

25262533
for p in [-1, 2]:
25272534
with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
25282535
transforms.RandomApply([lambda x: x], p=p)
25292536

2530-
for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
2537+
for transforms_, p in [
2538+
([lambda x: x], []),
2539+
(
2540+
[lambda x: x, lambda x: x],
2541+
[
2542+
1.0,
2543+
],
2544+
),
2545+
]:
25312546
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
25322547
transforms.RandomChoice(transforms_, p=p)
25332548

torchvision/transforms/v2/_container.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: floa
8787

8888
if not isinstance(transforms, (Sequence, nn.ModuleList)):
8989
raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
90+
elif not transforms:
91+
raise ValueError("Pass at least one transform")
9092
self.transforms = transforms
9193

9294
if not (0.0 <= p <= 1.0):
@@ -133,7 +135,8 @@ def __init__(
133135
) -> None:
134136
if not isinstance(transforms, Sequence):
135137
raise TypeError("Argument transforms should be a sequence of callables")
136-
138+
elif not transforms:
139+
raise ValueError("Pass at least one transform")
137140
if p is None:
138141
p = [1] * len(transforms)
139142
elif len(p) != len(transforms):
@@ -163,6 +166,8 @@ class RandomOrder(Transform):
163166
def __init__(self, transforms: Sequence[Callable]) -> None:
164167
if not isinstance(transforms, Sequence):
165168
raise TypeError("Argument transforms should be a sequence of callables")
169+
elif not transforms:
170+
raise ValueError("Pass at least one transform")
166171
super().__init__()
167172
self.transforms = transforms
168173

0 commit comments

Comments
 (0)