@@ -58,6 +58,16 @@ def new_like(
             **kwargs,
         )
 
+    _NO_WRAPPING_EXCEPTIONS = {
+        torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
+        torch.Tensor.to: lambda cls, input, output: cls.new_like(
+            input, output, dtype=output.dtype, device=output.device
+        ),
+        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
+        # retains the type automatically
+        torch.Tensor.requires_grad_: lambda cls, input, output: output,
+    }
+
     @classmethod
     def __torch_function__(
         cls,
@@ -73,19 +83,15 @@ def __torch_function__(
         ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
         ``args`` and ``kwargs`` of the original call.
 
-        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature`
+        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature`
         use case, this has two downsides:
 
         1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
            ``return cls(func(*args, **kwargs))``, will fail for them.
         2. For most operations, there is no way of knowing if the input type is still valid for the output.
 
-        For these reasons, the automatic output wrapping is turned off for most operators.
-
-        Exceptions to this are:
-
-        - :meth:`torch.Tensor.clone`
-        - :meth:`torch.Tensor.to`
+        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
+        listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS`
         """
         # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
         # need to reimplement the functionality.
@@ -96,18 +102,21 @@ def __torch_function__(
         with DisableTorchFunction():
             output = func(*args, **kwargs or dict())
 
-        # The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
-        # the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
-        # `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a
-        # case.
-        if not isinstance(args[0], cls):
-            return output
+        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
+        # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
+        # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
+        # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
+        # `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with
+        # `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would
+        # be wrapped into a `features.Image`.
+        if wrapper and isinstance(args[0], cls):
+            return wrapper(cls, args[0], output)  # type: ignore[no-any-return]
+
+        # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
+        # will retain the input type. Thus, we need to unwrap here.
+        if isinstance(output, cls):
+            return output.as_subclass(torch.Tensor)  # type: ignore[arg-type]
 
-        if func is torch.Tensor.clone:
-            return cls.new_like(args[0], output)
-        elif func is torch.Tensor.to:
-            return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
-        else:
         return output
 
     def _make_repr(self, **kwargs: Any) -> str:
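Illustrative only, not part of the commit: a rough sketch of the behavior the new dispatch table encodes, assuming the prototype `features.Image` subclass of `_Feature` mentioned in the comments above and a `torchvision.prototype.features` import path.

import torch
from torchvision.prototype import features  # assumed prototype import path

image = features.Image(torch.rand(3, 32, 32))

# `clone` and `to` have entries in `_NO_WRAPPING_EXCEPTIONS`, so their outputs are re-wrapped via `new_like`.
assert isinstance(image.clone(), features.Image)
assert isinstance(image.to(torch.float64), features.Image)

# Any operator without an entry falls through and returns a plain tensor.
assert not isinstance(image + 1, features.Image)

# The return value of a regular inplace op is unwrapped as well, even though `image` itself keeps its type.
assert not isinstance(image.add_(1), features.Image)
assert isinstance(image, features.Image)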