Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ Conversion
v2.ConvertImageDtype
v2.ToDtype
v2.ConvertBoundingBoxFormat
v2.ToPureTensor

Auto-Augmentation
-----------------
Expand Down
6 changes: 6 additions & 0 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __init__(
if random_erase_prob > 0:
transforms.append(T.RandomErasing(p=random_erase_prob))

if use_v2:
transforms.append(T.ToPureTensor())

self.transforms = T.Compose(transforms)

def __call__(self, img):
Expand Down Expand Up @@ -107,6 +110,9 @@ def __init__(
T.Normalize(mean=mean, std=std),
]

if use_v2:
transforms.append(T.ToPureTensor())

self.transforms = T.Compose(transforms)

def __call__(self, img):
Expand Down
5 changes: 5 additions & 0 deletions references/detection/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(),
T.ToPureTensor(),
]

self.transforms = T.Compose(transforms)
Expand All @@ -103,6 +104,10 @@ def __init__(self, backend="pil", use_v2=False):
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]

if use_v2:
transforms += [T.ToPureTensor()]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
Expand Down
5 changes: 5 additions & 0 deletions references/segmentation/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(
transforms += [T.ConvertImageDtype(torch.float)]

transforms += [T.Normalize(mean=mean, std=std)]
if use_v2:
transforms += [T.ToPureTensor()]

self.transforms = T.Compose(transforms)

Expand Down Expand Up @@ -98,6 +100,9 @@ def __init__(
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
if use_v2:
transforms += [T.ToPureTensor()]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
Expand Down
21 changes: 21 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,3 +2270,24 @@ def test_image_correctness(self, permutation, batch_dims):
expected = self.reference_image_correctness(image, permutation=permutation)

torch.testing.assert_close(actual, expected)


class TestToPureTensor:
    """Tests for the ``ToPureTensor`` transform."""

    def test_correctness(self):
        """Datapoints come back as plain tensors; every other value keeps its type."""
        # A single sample holding a mix of datapoint and non-datapoint values.
        sample = {
            "img": make_image(),
            "img_tensor": make_image_tensor(),
            "img_pil": make_image_pil(),
            "mask": make_detection_mask(),
            "video": make_video(),
            "bbox": make_bounding_box(),
            "str": "str",
        }

        converted = transforms.ToPureTensor()(sample)

        for original, result in zip(sample.values(), converted.values()):
            if not isinstance(original, datapoints.Datapoint):
                # Anything that was not a datapoint must keep its original type.
                assert isinstance(result, type(original))
                continue

            # Datapoints must be unwrapped: still a tensor, but no longer a datapoint.
            assert isinstance(result, torch.Tensor)
            assert not isinstance(result, datapoints.Datapoint)
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage, ToPureTensor

from ._deprecated import ToTensor # usort: skip

Expand Down
14 changes: 14 additions & 0 deletions torchvision/transforms/v2/_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,17 @@ def _transform(
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL


class ToPureTensor(Transform):
    """[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).

    .. v2betastatus:: ToPureTensor transform

    This doesn't scale or change the values, only the type.
    """

    # Match on the common base class instead of enumerating concrete datapoint types, so that
    # every datapoint (including subclasses not listed here) is unwrapped, as the docstring promises.
    _transformed_types = (datapoints.Datapoint,)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        """Return ``inpt`` re-typed as a plain :class:`torch.Tensor`.

        Args:
            inpt: A datapoint instance selected through ``_transformed_types``.
            params: Unused by this transform.

        Returns:
            The same tensor data viewed as a plain ``torch.Tensor``.
        """
        # ``as_subclass`` only changes the Python type; the values are left untouched.
        return inpt.as_subclass(torch.Tensor)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need or want a functional for this. The functional would just wrap `.as_subclass(torch.Tensor)`, which is already public, so users should simply call that directly.