Commit d7d90f5

handle inplace operations in _Feature.__torch_function__ (#6671)
* prevent feature wrapping for inplace ops
* cleanup
* mypy
* refactor __torch_function__ to be more concise
* avoid double lookup
* fix normalize
* refactor normalize
* mypy
1 parent f7f38f1 commit d7d90f5
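
In short: after this change, in-place operators no longer leak the feature type through their return value, while the operated-on feature keeps its type, and requires_grad_ still returns the wrapped type because it mutates its input. A quick sketch of the resulting behavior, using the prototype features.Label that the new tests below exercise (prototype API as of this commit):

import torch
from torchvision.prototype import features

label = features.Label(torch.tensor([0, 1, 0], dtype=torch.float32), categories=["foo", "bar"])

assert type(label.add_(0)) is torch.Tensor  # in-place result comes back as a plain tensor ...
assert type(label) is features.Label  # ... while the feature itself keeps its type
assert type(label.requires_grad_(True)) is features.Label  # requires_grad_ still returns the wrapped type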

5 files changed: +82 -28 lines

test/test_prototype_features.py

Lines changed: 41 additions & 0 deletions

@@ -1,3 +1,4 @@
+import pytest
 import torch
 from torchvision.prototype import features

@@ -48,6 +49,19 @@ def test_clone_wrapping():
     assert label_clone.categories is label.categories


+def test_requires_grad__wrapping():
+    tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
+    label = features.Label(tensor, categories=["foo", "bar"])
+
+    assert not label.requires_grad
+
+    label_requires_grad = label.requires_grad_(True)
+
+    assert type(label_requires_grad) is features.Label
+    assert label.requires_grad
+    assert label_requires_grad.requires_grad
+
+
 def test_other_op_no_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = features.Label(tensor, categories=["foo", "bar"])
@@ -58,6 +72,33 @@ def test_other_op_no_wrapping():
     assert type(output) is torch.Tensor


+@pytest.mark.parametrize(
+    "op",
+    [
+        lambda t: t.numpy(),
+        lambda t: t.tolist(),
+        lambda t: t.max(dim=-1),
+    ],
+)
+def test_no_tensor_output_op_no_wrapping(op):
+    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
+    label = features.Label(tensor, categories=["foo", "bar"])
+
+    output = op(label)
+
+    assert type(output) is not features.Label
+
+
+def test_inplace_op_no_wrapping():
+    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
+    label = features.Label(tensor, categories=["foo", "bar"])
+
+    output = label.add_(0)
+
+    assert type(output) is torch.Tensor
+    assert type(label) is features.Label
+
+
 def test_new_like():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = features.Label(tensor, categories=["foo", "bar"])

test/test_prototype_transforms_functional.py

Lines changed: 3 additions & 3 deletions

@@ -907,15 +907,15 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
     torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


-def test_midlevel_normalize_output_type():
+def test_normalize_output_type():
     inpt = torch.rand(1, 3, 32, 32)
     output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
-    assert isinstance(output, torch.Tensor)
+    assert type(output) is torch.Tensor
     torch.testing.assert_close(inpt - 0.5, output)

     inpt = make_image(color_space=features.ColorSpace.RGB)
     output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
-    assert isinstance(output, torch.Tensor)
+    assert type(output) is torch.Tensor
     torch.testing.assert_close(inpt - 0.5, output)

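The switch from isinstance to a strict type(...) is torch.Tensor check is what actually verifies the unwrapping: features.Image is a subclass of torch.Tensor, so the old isinstance assertion would pass even if F.normalize leaked a wrapped Image. A minimal illustration with a stand-in subclass (MyImage is hypothetical, not torchvision API):

import torch

class MyImage(torch.Tensor):  # hypothetical stand-in for a tensor subclass such as features.Image
    pass

wrapped = torch.rand(3, 4, 4).as_subclass(MyImage)

assert isinstance(wrapped, torch.Tensor)  # passes even though the value is still a wrapped subclass
assert type(wrapped) is not torch.Tensor  # only the strict type check notices the difference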

torchvision/prototype/features/_feature.py

Lines changed: 27 additions & 18 deletions

@@ -58,6 +58,16 @@ def new_like(
             **kwargs,
         )

+    _NO_WRAPPING_EXCEPTIONS = {
+        torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
+        torch.Tensor.to: lambda cls, input, output: cls.new_like(
+            input, output, dtype=output.dtype, device=output.device
+        ),
+        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
+        # retains the type automatically
+        torch.Tensor.requires_grad_: lambda cls, input, output: output,
+    }
+
     @classmethod
     def __torch_function__(
         cls,
@@ -73,19 +83,15 @@ def __torch_function__(
         ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
         ``args`` and ``kwargs`` of the original call.

-        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature`
+        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature`
         use case, this has two downsides:

         1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
            ``return cls(func(*args, **kwargs))``, will fail for them.
         2. For most operations, there is no way of knowing if the input type is still valid for the output.

-        For these reasons, the automatic output wrapping is turned off for most operators.
-
-        Exceptions to this are:
-
-        - :meth:`torch.Tensor.clone`
-        - :meth:`torch.Tensor.to`
+        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
+        listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS`
         """
         # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
         # need to reimplement the functionality.
@@ -96,18 +102,21 @@ def __torch_function__(
         with DisableTorchFunction():
             output = func(*args, **kwargs or dict())

-        # The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
-        # the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
-        # `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a
-        # case.
-        if not isinstance(args[0], cls):
-            return output
+        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
+        # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
+        # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
+        # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
+        # `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with
+        # `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would
+        # be wrapped into a `features.Image`.
+        if wrapper and isinstance(args[0], cls):
+            return wrapper(cls, args[0], output)  # type: ignore[no-any-return]
+
+        # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
+        # will retain the input type. Thus, we need to unwrap here.
+        if isinstance(output, cls):
+            return output.as_subclass(torch.Tensor)  # type: ignore[arg-type]

-        if func is torch.Tensor.clone:
-            return cls.new_like(args[0], output)
-        elif func is torch.Tensor.to:
-            return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
-        else:
         return output

     def _make_repr(self, **kwargs: Any) -> str:

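The control flow above can be reduced to a small self-contained sketch for readers outside the torchvision code base. MyTensor below is a toy stand-in for _Feature (the class name and the re-wrapping lambdas are illustrative, not torchvision API); it mirrors the new __torch_function__: look the operator up in the exception table, require the primary operand to be an instance of the dispatching class, and unwrap everything else, including the return values of in-place ops. DisableTorchFunction is the same context manager _feature.py imports from torch._C at this commit (torch ~1.13).

import torch
from torch._C import DisableTorchFunction  # same import used by _feature.py at this commit


class MyTensor(torch.Tensor):  # toy stand-in for _Feature
    # Only these operators re-wrap their output into the subclass; everything else is unwrapped.
    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: output.as_subclass(cls),
        torch.Tensor.to: lambda cls, input, output: output.as_subclass(cls),
        # requires_grad_ is inplace, so its output already carries the right type.
        torch.Tensor.requires_grad_: lambda cls, input, output: output,
    }

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        with DisableTorchFunction():
            output = func(*args, **kwargs or dict())

        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
        # Only wrap if the primary operand is an instance of the class being dispatched on.
        if wrapper and isinstance(args[0], cls):
            return wrapper(cls, args[0], output)

        # Inplace ops return the input object, which still carries the subclass type; unwrap it.
        if isinstance(output, cls):
            return output.as_subclass(torch.Tensor)

        return output


t = torch.tensor([0.0, 1.0]).as_subclass(MyTensor)

assert type(t.clone()) is MyTensor      # listed exception: output is re-wrapped
assert type(t + 1) is torch.Tensor      # regular op: output is unwrapped
assert type(t.add_(1)) is torch.Tensor  # inplace op: the returned value is unwrapped ...
assert type(t) is MyTensor              # ... but the operand keeps its type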

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 10 additions & 5 deletions

@@ -12,12 +12,17 @@
 def normalize(
     inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
 ) -> torch.Tensor:
-    if not isinstance(inpt, torch.Tensor):
-        raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
+    if torch.jit.is_scripting():
+        correct_type = isinstance(inpt, torch.Tensor)
     else:
-        # Image instance after normalization is not Image anymore due to unknown data range
-        # Thus we return Tensor for input Image
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+        correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, features.Image)
+        inpt = inpt.as_subclass(torch.Tensor)  # type: ignore[arg-type]
+    if not correct_type:
+        raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
+
+    # Image instance after normalization is not Image anymore due to unknown data range
+    # Thus we return Tensor for input Image
+    return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)


 def gaussian_blur_image_tensor(

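The reworked check keeps normalize scriptable while widening the accepted eager-mode inputs: under torch.jit.is_scripting() the only test TorchScript can express is isinstance(inpt, torch.Tensor), whereas in eager mode the input is narrowed to simple tensors or features.Image and then downcast with as_subclass so the kernel always operates on a plain tensor. A usage sketch of the guaranteed behavior (the prototype import path and the Image constructor are assumptions based on the test file above):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F  # assumed prototype import path

mean, std = [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]

plain = torch.rand(1, 3, 32, 32)
image = features.Image(torch.rand(3, 32, 32))  # assumed constructor, mirroring make_image() in the tests

# Both simple tensors and Image features are accepted, and both come back as plain tensors,
# because the value range after normalization no longer describes a valid Image.
assert type(F.normalize(plain, mean=mean, std=std)) is torch.Tensor
assert type(F.normalize(image, mean=mean, std=std)) is torch.Tensor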

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 2 deletions

@@ -937,8 +937,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
         mean = mean.view(-1, 1, 1)
     if std.ndim == 1:
         std = std.view(-1, 1, 1)
-    tensor.sub_(mean).div_(std)
-    return tensor
+    return tensor.sub_(mean).div_(std)


 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:

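Returning the result of the chained in-place calls, rather than the original tensor reference, matters once _Feature.__torch_function__ starts unwrapping in-place outputs: the value returned by sub_().div_() is the plain-tensor view of the data, while the original name can still be bound to a wrapped feature. A sketch of the difference, with Label standing in for any wrapped feature (prototype API as of this commit):

import torch
from torchvision.prototype import features

label = features.Label(torch.tensor([1.0, 2.0]), categories=["foo", "bar"])

result = label.sub_(0.5)  # the in-place op mutates label, but its return value is unwrapped
assert type(result) is torch.Tensor
assert type(label) is features.Label

# Hence `return tensor.sub_(mean).div_(std)` hands back the plain tensor,
# while the old `tensor.sub_(mean).div_(std); return tensor` would return the still-wrapped input.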
