Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8f8f936
move stuff out of CM
NicolasHug Aug 7, 2023
b1018a9
Call wrap_like for all exceptions
NicolasHug Aug 7, 2023
ddd88cd
Get rid of __torchfunction__ and the whole wrapping/unwrapping logic
NicolasHug Aug 7, 2023
4e8b53d
bbox tests
NicolasHug Aug 7, 2023
7471271
Put back wrapping / unwrapping in kernels
NicolasHug Aug 8, 2023
c5b44a9
Merge branch 'main' of github.com:pytorch/vision into lajenfljanfeljnfe
NicolasHug Aug 8, 2023
23b9704
Fix tests
NicolasHug Aug 8, 2023
f12fee1
preserve metadata on bboxes
NicolasHug Aug 9, 2023
1962124
Merge branch 'main' of github.com:pytorch/vision into lajenfljanfeljnfe
NicolasHug Aug 9, 2023
e9c1173
Merge branch 'main' of github.com:pytorch/vision into lajenfljanfeljnfe
NicolasHug Aug 9, 2023
854b01c
mypy
NicolasHug Aug 10, 2023
ec17580
Merge branch 'main' of github.com:pytorch/vision into lajenfljanfeljnfe
NicolasHug Aug 10, 2023
57e1b87
Merge branch 'main' of github.com:pytorch/vision into lajenfljanfeljnfe
NicolasHug Aug 12, 2023
2d39f6c
Allow switch
NicolasHug Aug 12, 2023
12b237b
Revert unnecessary changes
NicolasHug Aug 12, 2023
3bcfa91
fix
NicolasHug Aug 12, 2023
c7b10cc
Add required files
NicolasHug Aug 12, 2023
494fcaf
Merge branch 'main' of github.com:pytorch/vision into allow_return_su…
NicolasHug Aug 14, 2023
8a9645d
Address comments
NicolasHug Aug 14, 2023
6d41b69
Add support for context manager
NicolasHug Aug 14, 2023
8dc6add
mypy
NicolasHug Aug 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 131 additions & 18 deletions test/test_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from torchvision import datapoints


@pytest.fixture(autouse=True)
def preserve_default_wrapping_behaviour():
    """Reset the global datapoints return type to the default ("Tensor") after each test.

    Autouse so that a test calling ``datapoints.set_return_type("datapoint")``
    cannot leak the global flag into subsequent tests.
    """
    yield
    datapoints.set_return_type("Tensor")


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
image = datapoints.Image(data)
Expand Down Expand Up @@ -80,72 +86,89 @@ def test_to_wrapping():
assert image_to.dtype is torch.float64


def test_to_datapoint_reference():
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
image = datapoints.Image(tensor)

datapoints.set_return_type(return_type)
tensor_to = tensor.to(image)

assert type(tensor_to) is torch.Tensor
assert type(tensor_to) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert tensor_to.dtype is torch.float64


def test_clone_wrapping():
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_clone_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))

datapoints.set_return_type(return_type)
image_clone = image.clone()

assert type(image_clone) is datapoints.Image
assert image_clone.data_ptr() != image.data_ptr()


def test_requires_grad__wrapping():
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_requires_grad__wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))

assert not image.requires_grad

datapoints.set_return_type(return_type)
image_requires_grad = image.requires_grad_(True)

assert type(image_requires_grad) is datapoints.Image
assert image.requires_grad
assert image_requires_grad.requires_grad


def test_detach_wrapping():
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_detach_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)

datapoints.set_return_type(return_type)
image_detached = image.detach()

assert type(image_detached) is datapoints.Image


def test_no_wrapping_exceptions_with_metadata():
# Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
format, canvas_size = "XYXY", (32, 32)
bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

datapoints.set_return_type(return_type)

bbox = bbox.clone()
assert bbox.format, bbox.canvas_size == (format, canvas_size)
if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)

bbox = bbox.to(torch.float64)
assert bbox.format, bbox.canvas_size == (format, canvas_size)
if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)

bbox = bbox.detach()
assert bbox.format, bbox.canvas_size == (format, canvas_size)
if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)

assert not bbox.requires_grad
bbox.requires_grad_(True)
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad
if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad


def test_other_op_no_wrapping():
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_other_op_no_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))

# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
datapoints.set_return_type(return_type)
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2

assert type(output) is torch.Tensor
assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)


@pytest.mark.parametrize(
Expand All @@ -164,19 +187,21 @@ def test_no_tensor_output_op_no_wrapping(op):
assert type(output) is not datapoints.Image


def test_inplace_op_no_wrapping():
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_inplace_op_no_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))

datapoints.set_return_type(return_type)
output = image.add_(0)

assert type(output) is torch.Tensor
assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert type(image) is datapoints.Image


def test_wrap_like():
image = datapoints.Image(torch.rand(3, 16, 16))

# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2

image_new = datapoints.Image.wrap_like(image, output)
Expand Down Expand Up @@ -209,3 +234,91 @@ def test_deepcopy(datapoint, requires_grad):

assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad


@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_operations(return_type):
    """Check the output type of common tensor operations on datapoints.

    With ``return_type="datapoint"`` the outputs must keep the datapoint
    subclass (and, for BoundingBoxes, the format/canvas_size metadata);
    with ``return_type="Tensor"`` they must be unwrapped to plain Tensors.
    """
    datapoints.set_return_type(return_type)

    def common_ops(dp, other):
        # The shared battery of binary/unary/factory operations applied to every
        # datapoint type below. Kept in one place instead of three copy-pasted lists.
        return (
            [
                dp + other,
                other + dp,
                dp * other,
                other * dp,
                dp + 3,
                3 + dp,
                dp * 3,
                3 * dp,
                dp + dp,
                dp.sum(),
                dp.reshape(-1),
                dp.float(),
                torch.stack([dp, dp]),
            ]
            + list(torch.chunk(dp, 2))
            + list(torch.unbind(dp))
        )

    img = datapoints.Image(torch.rand(3, 10, 10))
    t = torch.rand(3, 10, 10)
    mask = datapoints.Mask(torch.rand(1, 10, 10))

    for out in common_ops(img, t):
        assert type(out) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)

    for out in common_ops(mask, t):
        assert type(out) is (datapoints.Mask if return_type == "datapoint" else torch.Tensor)

    # Mixing two different datapoint types is unsupported regardless of return type.
    with pytest.raises(TypeError, match="unsupported operand type"):
        img + mask

    with pytest.raises(TypeError, match="unsupported operand type"):
        img * mask

    bboxes = datapoints.BoundingBoxes(
        [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
    )
    t = torch.rand(2, 4)

    for out in common_ops(bboxes, t):
        if return_type == "Tensor":
            assert type(out) is torch.Tensor
        else:
            # BoundingBoxes outputs must also carry the restored metadata.
            assert isinstance(out, datapoints.BoundingBoxes)
            assert hasattr(out, "format")
            assert hasattr(out, "canvas_size")
2 changes: 2 additions & 0 deletions torchvision/datapoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import Datapoint
from ._image import Image
from ._mask import Mask
from ._torch_function_helpers import set_return_type
from ._video import Video

if _WARN_ABOUT_BETA_TRANSFORMS:
Expand Down
38 changes: 32 additions & 6 deletions torchvision/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Optional, Tuple, Union
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch.utils._pytree import tree_flatten

from ._datapoint import Datapoint

Expand Down Expand Up @@ -48,11 +49,12 @@ class BoundingBoxes(Datapoint):
canvas_size: Tuple[int, int]

@classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 2:
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes: # type: ignore[override]
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to add the check_dims flag because in some cases like for bbox.sum() the dims won't be correct

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move this checks to _wrap for images as well?

if tensor.ndim < 2:
raise ValueError
elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not needed

if check_dims:
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 2:
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
bounding_boxes = tensor.as_subclass(cls)
Expand Down Expand Up @@ -99,5 +101,29 @@ def wrap_like(
canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
)

@classmethod
def _wrap_output(
    cls,
    output: torch.Tensor,
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> BoundingBoxes:
    """Re-wrap a ``__torch_function__`` output as BoundingBoxes, restoring metadata.

    The format/canvas_size metadata is taken from the first BoundingBoxes instance
    found among the (flattened) positional and keyword arguments of the call.
    """
    # If there are BoundingBoxes instances in the output, their metadata got lost when we called
    # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
    # the metadata from the first bbox in the parameters.
    # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
    # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
    flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
    first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes))
    format, canvas_size = first_bbox_from_args.format, first_bbox_from_args.canvas_size

    # check_dims=False because ops like bbox.sum() legitimately produce tensors
    # that are not 1D/2D, so _wrap must not enforce its usual shape check here.
    if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
        output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False)
    elif isinstance(output, (tuple, list)):
        # E.g. torch.chunk / torch.unbind return tuples of tensors: wrap each part.
        output = type(output)(
            BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output
        )
    return output

def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, canvas_size=self.canvas_size)
24 changes: 16 additions & 8 deletions torchvision/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size

from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass


D = TypeVar("D", bound="Datapoint")
Expand Down Expand Up @@ -33,9 +34,14 @@ def _to_tensor(
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)

# The ops in this set are those that should *preserve* the Datapoint type,
# i.e. they are exceptions to the "no wrapping" rule.
_NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
@classmethod
def _wrap_output(
    cls,
    output: torch.Tensor,
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> D:
    """Default re-wrapping of a ``__torch_function__`` output into ``cls``.

    ``args``/``kwargs`` are unused here; the signature matches subclass
    overrides (e.g. BoundingBoxes) that need the call arguments to restore
    metadata.

    NOTE(review): relies on the private helper ``torch._tensor._convert``,
    which may change across torch versions — presumably it converts plain
    Tensors (including inside tuple/list outputs) to the given subclass;
    confirm against the torch version in use.
    """
    return torch._tensor._convert(output, cls)

@classmethod
def __torch_function__(
Expand All @@ -60,27 +66,29 @@ def __torch_function__(
2. For most operations, there is no way of knowing if the input type is still valid for the output.

For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
listed in _FORCE_TORCHFUNCTION_SUBCLASS
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the docstring above is mostly wrong / obsolete now. If this is merged I would rewrite everything. Same for a lot of the comments.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not going to block, but I would prefer doing this here. I'm ok with the docs being updated later, since the default behavior doesn't change. But I feel the comments here should be updated right away, since they are wrong / obsolete now.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I will be able to write better quality comments / docstring once I start writing the user-facing docs.

"""
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.

if not all(issubclass(cls, t) for t in types):
return NotImplemented

# Like in the base Tensor.__torch_function__ implementation, it's easier to always use
# DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict())

if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
if _must_return_subclass() or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
# We also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`.
return cls.wrap_like(args[0], output)
return cls._wrap_output(output, args, kwargs)

if isinstance(output, cls):
if not _must_return_subclass() and isinstance(output, cls):
# DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
# so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
return output.as_subclass(torch.Tensor)
Expand Down
16 changes: 16 additions & 0 deletions torchvision/datapoints/_torch_function_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch

_TORCHFUNCTION_SUBCLASS = False


def set_return_type(type="Tensor"):
Copy link
Copy Markdown
Member Author

@NicolasHug NicolasHug Aug 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only function that is publicly exposed. We'll probably want to make this a context manager on top of a global flag switch. We can bikeshed on its name and refine the actual UX, but let's first focus on whether we want to expose this functionality.

global _TORCHFUNCTION_SUBCLASS
_TORCHFUNCTION_SUBCLASS = {"tensor": False, "datapoint": True}[type.lower()]


def _must_return_subclass():
return _TORCHFUNCTION_SUBCLASS


# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I'm not super happy about our names ("unwrapping" vs "subclass"), a lot of it actually coming from the base implementations. But we can clean that up later.

2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def forward(self, *inputs: Any) -> Any:
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

params = dict(valid=valid, labels=labels)
params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxeses and the labels
Expand Down
Loading