Skip to content

Commit 71073cb

Browse files
authored
add sequence fill support for ElasticTransform (#7141)
1 parent 2bc8a14 commit 71073cb

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

test/test_transforms_tensor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,3 +858,35 @@ def test_gaussian_blur(device, channels, meth_kwargs):
858858
agg_method="max",
859859
tol=tol,
860860
)
861+
862+
863+
@pytest.mark.parametrize("device", cpu_and_gpu())
864+
@pytest.mark.parametrize(
865+
"fill",
866+
[
867+
1,
868+
1.0,
869+
[1],
870+
[1.0],
871+
(1,),
872+
(1.0,),
873+
[1, 2, 3],
874+
[1.0, 2.0, 3.0],
875+
(1, 2, 3),
876+
(1.0, 2.0, 3.0),
877+
],
878+
)
879+
@pytest.mark.parametrize("channels", [1, 3])
880+
def test_elastic_transform(device, channels, fill):
881+
if isinstance(fill, (list, tuple)) and len(fill) > 1 and channels == 1:
882+
# For this the test would correctly fail, since the number of channels in the image does not match `fill`.
883+
# Thus, this is not an issue in the transform, but rather a problem of parametrization that just gives the
884+
# product of `fill` and `channels`.
885+
return
886+
887+
_test_class_op(
888+
T.ElasticTransform,
889+
meth_kwargs=dict(fill=fill),
890+
channels=channels,
891+
device=device,
892+
)

torchvision/transforms/functional.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,8 +1539,6 @@ def elastic_transform(
15391539
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
15401540
If a tuple of length 3, it is used to fill R, G, B channels respectively.
15411541
This value is only used when the padding_mode is constant.
1542-
Only number is supported for torch Tensor.
1543-
Only int or str or tuple value is supported for PIL Image.
15441542
"""
15451543
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
15461544
_log_api_usage_once(elastic_transform)

torchvision/transforms/transforms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,8 +2104,12 @@ def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINE
21042104
interpolation = _interpolation_modes_from_int(interpolation)
21052105
self.interpolation = interpolation
21062106

2107-
if not isinstance(fill, (int, float)):
2108-
raise TypeError(f"fill should be int or float. Got {type(fill)}")
2107+
if isinstance(fill, (int, float)):
2108+
fill = [float(fill)]
2109+
elif isinstance(fill, (list, tuple)):
2110+
fill = [float(f) for f in fill]
2111+
else:
2112+
raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
21092113
self.fill = fill
21102114

21112115
@staticmethod

0 commit comments

Comments
 (0)