Allow users to choose whether to return Datapoint subclasses or pure Tensor #7825
Changes from 17 commits
torchvision/datapoints/_datapoint.py:

```diff
@@ -4,7 +4,8 @@
 import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size

+from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass

 D = TypeVar("D", bound="Datapoint")
@@ -33,9 +34,14 @@ def _to_tensor(
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)

-    # The ops in this set are those that should *preserve* the Datapoint type,
-    # i.e. they are exceptions to the "no wrapping" rule.
-    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
+    @classmethod
+    def _wrap_output(
+        cls,
+        output: torch.Tensor,
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> D:
+        return torch._tensor._convert(output, cls)

     @classmethod
     def __torch_function__(
```
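The `wrap_like` / `_wrap_output` pair relies on reinterpreting an existing tensor as a subclass without copying its data. A minimal pure-Python sketch of that idea, with hypothetical `Tensor`/`Image` stand-ins (not torchvision code; the real implementation uses `torch.Tensor.as_subclass`):

```python
# Hypothetical stand-ins: `Tensor` for torch.Tensor, `Image` for a Datapoint subclass.
class Tensor:
    def __init__(self, data):
        self.data = data

class Image(Tensor):
    pass

def wrap_like(cls, tensor):
    # Like torch.Tensor.as_subclass: re-view the same object as `cls`
    # without copying the underlying data.
    tensor.__class__ = cls
    return tensor

img = wrap_like(Image, Tensor([1, 2, 3]))
```

The key property is that wrapping is an O(1) reinterpretation, which is why it is cheap to do (or undo) on every op.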
```diff
@@ -60,27 +66,29 @@ def __torch_function__(
         2. For most operations, there is no way of knowing if the input type is still valid for the output.

         For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
-        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
+        listed in _FORCE_TORCHFUNCTION_SUBCLASS
         """
         # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
         # need to reimplement the functionality.

         if not all(issubclass(cls, t) for t in types):
             return NotImplemented

         # Like in the base Tensor.__torch_function__ implementation, it's easier to always use
         # DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())

-        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
+        if _must_return_subclass() or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
             # We also require the primary operand, i.e. `args[0]`, to be
             # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
             # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
             # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
-            return cls.wrap_like(args[0], output)
+            return cls._wrap_output(output, args, kwargs)

-        if isinstance(output, cls):
+        if not _must_return_subclass() and isinstance(output, cls):
             # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
             # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
             return output.as_subclass(torch.Tensor)
```

Review discussion on the docstring above:

NicolasHug (Member, Author): Note: the docstring above is mostly wrong / obsolete now. If this is merged I would rewrite everything. Same for a lot of the comments.

Contributor: Not going to block, but I would prefer doing this here. I'm ok with the docs being updated later, since the default behavior doesn't change. But I feel the comments here should be updated right away, since they are wrong / obsolete now.

NicolasHug (Member, Author): I think I will be able to write better quality comments / docstring once I start writing the user-facing docs.
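The branching above can be mimicked without torch. In this sketch (hypothetical names; `dispatch` stands in for `__torch_function__`), the output is computed as the base type and only re-wrapped when the global flag is set, or when the op is in the exception set and the primary operand is the subclass:

```python
# Hypothetical stand-ins for torch.Tensor and a Datapoint subclass.
class Tensor:
    def __init__(self, value):
        self.value = value

class Datapoint(Tensor):
    pass

_FORCE_SUBCLASS = {"clone"}   # stand-in for _FORCE_TORCHFUNCTION_SUBCLASS
MUST_RETURN_SUBCLASS = False  # stand-in for the _must_return_subclass() flag

def dispatch(func_name, arg):
    # Compute the raw result as the base type, like running the op under
    # DisableTorchFunctionSubclass does.
    output = Tensor(arg.value)
    # Re-wrap only for exception ops whose primary operand is the subclass,
    # or unconditionally when the global flag is on.
    if MUST_RETURN_SUBCLASS or (func_name in _FORCE_SUBCLASS and isinstance(arg, Datapoint)):
        output.__class__ = type(arg)
    return output
```

With the flag off, `dispatch("clone", Datapoint(1))` keeps the subclass while `dispatch("sum", Datapoint(1))` falls back to the base type, mirroring the PR's "unwrap by default, with exceptions" behavior.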
torchvision/datapoints/_torch_function_helpers.py (new file):

```diff
@@ -0,0 +1,16 @@
+import torch
+
+_TORCHFUNCTION_SUBCLASS = False
+
+
+def set_return_type(type="Tensor"):
+    global _TORCHFUNCTION_SUBCLASS
+    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "datapoint": True}[type.lower()]
+
+
+def _must_return_subclass():
+    return _TORCHFUNCTION_SUBCLASS
+
+
+# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
+_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
```
NicolasHug (Member, Author): Note: I'm not super happy about our names ("unwrapping" vs "subclass"), a lot of it actually coming from the base implementations. But we can clean that up later.
Review discussion on the bounding-box wrapping:

Comment: Had to add the `check_dims` flag because in some cases, like for `bbox.sum()`, the dims won't be correct.

Reply: Should we move these checks to `_wrap` for images as well? (vision/torchvision/datapoints/_image.py, Lines 39 to 42 in 3065ad5)

Reply: It's not needed.
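The `check_dims` concern can be illustrated with a hypothetical shape check (not the torchvision implementation): a bounding-box result should keep its 4 coordinates in the last dimension, which a reduction like `bbox.sum()` destroys, so such outputs must not be re-wrapped as bounding boxes.

```python
def check_bbox_dims(shape):
    # A bounding-box tensor is expected to carry its coordinates in the last
    # dimension, e.g. (N, 4). After a full reduction such as `bbox.sum()` the
    # shape collapses to (), so re-wrapping the result as a bounding box would
    # produce an invalid datapoint.
    return len(shape) >= 1 and shape[-1] == 4
```

A caller would run this check before deciding to wrap: `check_bbox_dims((8, 4))` passes, while the scalar shape `()` produced by a sum does not.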