Allow users to choose whether to return Datapoint subclasses or pure Tensor#7825
Allow users to choose whether to return Datapoint subclasses or pure Tensor#7825NicolasHug merged 21 commits intopytorch:mainfrom
Conversation
| tensor = tensor.unsqueeze(0) | ||
| elif tensor.ndim != 2: | ||
| raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D") | ||
| def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes: # type: ignore[override] |
There was a problem hiding this comment.
Had to add the check_dims flag because in some cases like for bbox.sum() the dims won't be correct
There was a problem hiding this comment.
Should we move this checks to _wrap for images as well?
vision/torchvision/datapoints/_image.py
Lines 39 to 42 in 3065ad5
|
|
||
| For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are | ||
| listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` | ||
| listed in _FORCE_TORCHFUNCTION_SUBCLASS |
There was a problem hiding this comment.
Note: the docstring above is mostly wrong / obsolete now. If this is merged I would rewrite everything. Same for a lot of the comments.
There was a problem hiding this comment.
Not going to block, but I would prefer doing this here. I'm ok with the docs being updated later, since the default behavior doesn't change. But I feel the comments here should be updated right away, since they are wrong / obsolete now.
There was a problem hiding this comment.
I think I will be able to write better quality comments / docstring once I start writing the user-facing docs.
| _TORCHFUNCTION_SUBCLASS = False | ||
|
|
||
|
|
||
| def set_return_type(type="Tensor"): |
There was a problem hiding this comment.
This is the only function that is publicly exposed. We'll probably want to make this a context manager on top of a global flag switch. We can bikeshed on its name and refine the actual UX, but let's first focus on whether we want to expose this functionality.
|
|
||
|
|
||
| # For those ops we always want to preserve the original subclass instead of returning a pure Tensor | ||
| _FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_} |
There was a problem hiding this comment.
Note: I'm not super happy about our names ("unwrapping" vs "subclass"), a lot of it actually coming from the base implementations. But we can clean that up later.
| tensor = tensor.unsqueeze(0) | ||
| elif tensor.ndim != 2: | ||
| raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D") | ||
| def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes: # type: ignore[override] |
There was a problem hiding this comment.
Should we move this checks to _wrap for images as well?
vision/torchvision/datapoints/_image.py
Lines 39 to 42 in 3065ad5
|
|
||
| For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are | ||
| listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` | ||
| listed in _FORCE_TORCHFUNCTION_SUBCLASS |
There was a problem hiding this comment.
Not going to block, but I would prefer doing this here. I'm ok with the docs being updated later, since the default behavior doesn't change. But I feel the comments here should be updated right away, since they are wrong / obsolete now.
pmeier
left a comment
There was a problem hiding this comment.
Stamping. Thanks Nicolas!
|
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
… or pure Tensor (#7825) Reviewed By: matteobettini Differential Revision: D48642251 fbshipit-source-id: 9a59123410585c4b0523069089803784168ca707
This is basically the same as #7807, but preserve the current default behaviour i.e. we still return tensors by default.
This adds a
datapoints.set_return_type("datapoints")public switch that allows users to decide whether they want datapoints or tensors as output.This does NOT change anything to the unwrap/wrapping logic of our functional kernels.
cc @vfdev-5