diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py new file mode 100644 index 00000000000..554088b912a --- /dev/null +++ b/torchvision/datapoints/__init__.py @@ -0,0 +1,8 @@ +from ._bounding_box import BoundingBox, BoundingBoxFormat +from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT +from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT +from ._label import Label, OneHotLabel +from ._mask import Mask +from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT + +from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py new file mode 100644 index 00000000000..b904dd5e5aa --- /dev/null +++ b/torchvision/datapoints/_bounding_box.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Sequence, Tuple, Union + +import torch +from torchvision._utils import StrEnum +from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms + +from ._datapoint import Datapoint, FillTypeJIT + + +class BoundingBoxFormat(StrEnum): + XYXY = StrEnum.auto() + XYWH = StrEnum.auto() + CXCYWH = StrEnum.auto() + + +class BoundingBox(Datapoint): + format: BoundingBoxFormat + spatial_size: Tuple[int, int] + + @classmethod + def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox: + bounding_box = tensor.as_subclass(cls) + bounding_box.format = format + bounding_box.spatial_size = spatial_size + return bounding_box + + def __new__( + cls, + data: Any, + *, + format: Union[BoundingBoxFormat, str], + spatial_size: Tuple[int, int], + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ) -> BoundingBox: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + + if isinstance(format, str): + format = BoundingBoxFormat.from_str(format.upper()) + + return cls._wrap(tensor, format=format, spatial_size=spatial_size) + + @classmethod + def wrap_like( + cls, + other: BoundingBox, + tensor: torch.Tensor, + *, + format: Optional[BoundingBoxFormat] = None, + spatial_size: Optional[Tuple[int, int]] = None, + ) -> BoundingBox: + return cls._wrap( + tensor, + format=format if format is not None else other.format, + spatial_size=spatial_size if spatial_size is not None else other.spatial_size, + ) + + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr(format=self.format, spatial_size=self.spatial_size) + + def horizontal_flip(self) -> BoundingBox: + output = self._F.horizontal_flip_bounding_box( + self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size + ) + return BoundingBox.wrap_like(self, output) + + def vertical_flip(self) -> BoundingBox: + output = self._F.vertical_flip_bounding_box( + self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size + ) + return BoundingBox.wrap_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> BoundingBox: + output, spatial_size = self._F.resize_bounding_box( + 
self.as_subclass(torch.Tensor), + spatial_size=self.spatial_size, + size=size, + max_size=max_size, + ) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + + def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: + output, spatial_size = self._F.crop_bounding_box( + self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width + ) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + + def center_crop(self, output_size: List[int]) -> BoundingBox: + output, spatial_size = self._F.center_crop_bounding_box( + self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size + ) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> BoundingBox: + output, spatial_size = self._F.resized_crop_bounding_box( + self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size + ) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + + def pad( + self, + padding: Union[int, Sequence[int]], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", + ) -> BoundingBox: + output, spatial_size = self._F.pad_bounding_box( + self.as_subclass(torch.Tensor), + format=self.format, + spatial_size=self.spatial_size, + padding=padding, + padding_mode=padding_mode, + ) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + + def rotate( + self, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: FillTypeJIT = None, + ) -> BoundingBox: + output, spatial_size = self._F.rotate_bounding_box( + self.as_subclass(torch.Tensor), + format=self.format, + spatial_size=self.spatial_size, + angle=angle, + expand=expand, + center=center, + ) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + + def affine( + self, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + center: Optional[List[float]] = None, + ) -> BoundingBox: + output = self._F.affine_bounding_box( + self.as_subclass(torch.Tensor), + self.format, + self.spatial_size, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return BoundingBox.wrap_like(self, output) + + def perspective( + self, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + coefficients: Optional[List[float]] = None, + ) -> BoundingBox: + output = self._F.perspective_bounding_box( + self.as_subclass(torch.Tensor), + format=self.format, + spatial_size=self.spatial_size, + startpoints=startpoints, + endpoints=endpoints, + coefficients=coefficients, + ) + return BoundingBox.wrap_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + ) -> BoundingBox: + output = self._F.elastic_bounding_box( + self.as_subclass(torch.Tensor), self.format, self.spatial_size, 
displacement=displacement + ) + return BoundingBox.wrap_like(self, output) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py new file mode 100644 index 00000000000..5f4a0d96ea2 --- /dev/null +++ b/torchvision/datapoints/_datapoint.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union + +import PIL.Image +import torch +from torch._C import DisableTorchFunctionSubclass +from torch.types import _device, _dtype, _size +from torchvision.transforms import InterpolationMode + + +D = TypeVar("D", bound="Datapoint") +FillType = Union[int, float, Sequence[int], Sequence[float], None] +FillTypeJIT = Optional[List[float]] + + +class Datapoint(torch.Tensor): + __F: Optional[ModuleType] = None + + @staticmethod + def _to_tensor( + data: Any, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ) -> torch.Tensor: + if requires_grad is None: + requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False + return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) + + @classmethod + def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: + raise NotImplementedError + + _NO_WRAPPING_EXCEPTIONS = { + torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), + torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output), + # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus + # retains the type automatically + torch.Tensor.requires_grad_: lambda cls, input, output: output, + } + + @classmethod + def __torch_function__( + cls, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + """For general information about how the __torch_function__ protocol works, + see https://pytorch.org/docs/stable/notes/extending.html#extending-torch + + TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the + ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the + ``args`` and ``kwargs`` of the original call. + + The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint` + use case, this has two downsides: + + 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e. + ``return cls(func(*args, **kwargs))``, will fail for them. + 2. For most operations, there is no way of knowing if the input type is still valid for the output. + + For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are + listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` + """ + # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we + # need to reimplement the functionality. + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + with DisableTorchFunctionSubclass(): + output = func(*args, **kwargs or dict()) + + wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) + # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be + # an instance of the class that `__torch_function__` was invoked on. 
The __torch_function__ protocol will + # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, + # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with + # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would + # be wrapped into a `datapoints.Image`. + if wrapper and isinstance(args[0], cls): + return wrapper(cls, args[0], output) + + # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`, + # will retain the input type. Thus, we need to unwrap here. + if isinstance(output, cls): + return output.as_subclass(torch.Tensor) + + return output + + def _make_repr(self, **kwargs: Any) -> str: + # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. + # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class. + extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items()) + return f"{super().__repr__()[:-1]}, {extra_repr})" + + @property + def _F(self) -> ModuleType: + # This implements a lazy import of the functional to get around the cyclic import. This import is deferred + # until the first time we need reference to the functional module and it's shared across all instances of + # the class. This approach avoids the DataLoader issue described at + # https://github.com/pytorch/vision/pull/6476#discussion_r953588621 + if Datapoint.__F is None: + from ..transforms import functional + + Datapoint.__F = functional + return Datapoint.__F + + # Add properties for common attributes like shape, dtype, device, ndim etc + # this way we return the result without passing into __torch_function__ + @property + def shape(self) -> _size: # type: ignore[override] + with DisableTorchFunctionSubclass(): + return super().shape + + @property + def ndim(self) -> int: # type: ignore[override] + with DisableTorchFunctionSubclass(): + return super().ndim + + @property + def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override] + with DisableTorchFunctionSubclass(): + return super().device + + @property + def dtype(self) -> _dtype: # type: ignore[override] + with DisableTorchFunctionSubclass(): + return super().dtype + + def horizontal_flip(self) -> Datapoint: + return self + + def vertical_flip(self) -> Datapoint: + return self + + # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize + # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593 + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Datapoint: + return self + + def crop(self, top: int, left: int, height: int, width: int) -> Datapoint: + return self + + def center_crop(self, output_size: List[int]) -> Datapoint: + return self + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Datapoint: + return self + + def pad( + self, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: 
str = "constant",
+    ) -> Datapoint:
+        return self
+
+    def rotate(
+        self,
+        angle: float,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
+        expand: bool = False,
+        center: Optional[List[float]] = None,
+        fill: FillTypeJIT = None,
+    ) -> Datapoint:
+        return self
+
+    def affine(
+        self,
+        angle: Union[int, float],
+        translate: List[float],
+        scale: float,
+        shear: List[float],
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
+        fill: FillTypeJIT = None,
+        center: Optional[List[float]] = None,
+    ) -> Datapoint:
+        return self
+
+    def perspective(
+        self,
+        startpoints: Optional[List[List[int]]],
+        endpoints: Optional[List[List[int]]],
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+        fill: FillTypeJIT = None,
+        coefficients: Optional[List[float]] = None,
+    ) -> Datapoint:
+        return self
+
+    def elastic(
+        self,
+        displacement: torch.Tensor,
+        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+        fill: FillTypeJIT = None,
+    ) -> Datapoint:
+        return self
+
+    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
+        return self
+
+    def adjust_brightness(self, brightness_factor: float) -> Datapoint:
+        return self
+
+    def adjust_saturation(self, saturation_factor: float) -> Datapoint:
+        return self
+
+    def adjust_contrast(self, contrast_factor: float) -> Datapoint:
+        return self
+
+    def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
+        return self
+
+    def adjust_hue(self, hue_factor: float) -> Datapoint:
+        return self
+
+    def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
+        return self
+
+    def posterize(self, bits: int) -> Datapoint:
+        return self
+
+    def solarize(self, threshold: float) -> Datapoint:
+        return self
+
+    def autocontrast(self) -> Datapoint:
+        return self
+
+    def equalize(self) -> Datapoint:
+        return self
+
+    def invert(self) -> Datapoint:
+        return self
+
+    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
+        return self
+
+
+InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
+InputTypeJIT = torch.Tensor
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
new file mode 100644
index 00000000000..74f83095177
--- /dev/null
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -0,0 +1,426 @@
+# type: ignore
+
+from __future__ import annotations
+
+import contextlib
+from collections import defaultdict
+
+import torch
+from torch.utils.data import Dataset
+
+from torchvision import datapoints, datasets
+from torchvision.transforms.v2 import functional as F
+
+__all__ = ["wrap_dataset_for_transforms_v2"]
+
+
+# TODO: naming!
+def wrap_dataset_for_transforms_v2(dataset):
+    return VisionDatasetDatapointWrapper(dataset)
+
+
+class WrapperFactories(dict):
+    def register(self, dataset_cls):
+        def decorator(wrapper_factory):
+            self[dataset_cls] = wrapper_factory
+            return wrapper_factory
+
+        return decorator
+
+
+# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on
+# the dataset instance rather than just the class, because they require user-defined instance attributes. Thus, we
+# can map the dataset class to the factory here, but can only instantiate the wrapper at runtime, when we have
+# access to the dataset instance.
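+#
+# As a rough illustration of the two-stage design (hypothetical `MyDataset`, not a real torchvision dataset):
+#
+#     @WRAPPER_FACTORIES.register(MyDataset)
+#     def my_dataset_wrapper_factory(dataset):  # stage 1: runs once and sees the dataset instance
+#         def wrapper(idx, sample):  # stage 2: runs per sample inside `__getitem__`
+#             image, target = sample
+#             return image, datapoints.Mask(target)
+#
+#         return wrapper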
+WRAPPER_FACTORIES = WrapperFactories() + + +class VisionDatasetDatapointWrapper(Dataset): + def __init__(self, dataset): + dataset_cls = type(dataset) + + if not isinstance(dataset, datasets.VisionDataset): + raise TypeError( + f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, " + f"but got a '{dataset_cls.__name__}' instead." + ) + + for cls in dataset_cls.mro(): + if cls in WRAPPER_FACTORIES: + wrapper_factory = WRAPPER_FACTORIES[cls] + break + elif cls is datasets.VisionDataset: + # TODO: If we have documentation on how to do that, put a link in the error message. + msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself." + if dataset_cls in datasets.__dict__.values(): + msg = ( + f"{msg} If an automated wrapper for this dataset would be useful for you, " + f"please open an issue at https://github.com/pytorch/vision/issues." + ) + raise TypeError(msg) + + self._dataset = dataset + self._wrapper = wrapper_factory(dataset) + + # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them. + # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint + # `transforms` + # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54 + # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to + # disable all three here to be able to extract the untransformed sample to wrap. + self.transform, dataset.transform = dataset.transform, None + self.target_transform, dataset.target_transform = dataset.target_transform, None + self.transforms, dataset.transforms = dataset.transforms, None + + def __getattr__(self, item): + with contextlib.suppress(AttributeError): + return object.__getattribute__(self, item) + + return getattr(self._dataset, item) + + def __getitem__(self, idx): + # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor + # of this class + sample = self._dataset[idx] + + sample = self._wrapper(idx, sample) + + # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`) + # or joint (`transforms`), we can access the full functionality through `transforms` + if self.transforms is not None: + sample = self.transforms(*sample) + + return sample + + def __len__(self): + return len(self._dataset) + + +def raise_not_supported(description): + raise RuntimeError( + f"{description} is currently not supported by this wrapper. " + f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." 
+ ) + + +def identity(item): + return item + + +def identity_wrapper_factory(dataset): + def wrapper(idx, sample): + return sample + + return wrapper + + +def pil_image_to_mask(pil_image): + return datapoints.Mask(pil_image) + + +def list_of_dicts_to_dict_of_lists(list_of_dicts): + dict_of_lists = defaultdict(list) + for dct in list_of_dicts: + for key, value in dct.items(): + dict_of_lists[key].append(value) + return dict(dict_of_lists) + + +def wrap_target_by_type(target, *, target_types, type_wrappers): + if not isinstance(target, (tuple, list)): + target = [target] + + wrapped_target = tuple( + type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target) + ) + + if len(wrapped_target) == 1: + wrapped_target = wrapped_target[0] + + return wrapped_target + + +def classification_wrapper_factory(dataset): + return identity_wrapper_factory(dataset) + + +for dataset_cls in [ + datasets.Caltech256, + datasets.CIFAR10, + datasets.CIFAR100, + datasets.ImageNet, + datasets.MNIST, + datasets.FashionMNIST, + datasets.GTSRB, + datasets.DatasetFolder, + datasets.ImageFolder, +]: + WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory) + + +def segmentation_wrapper_factory(dataset): + def wrapper(idx, sample): + image, mask = sample + return image, pil_image_to_mask(mask) + + return wrapper + + +for dataset_cls in [ + datasets.VOCSegmentation, +]: + WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory) + + +def video_classification_wrapper_factory(dataset): + if dataset.video_clips.output_format == "THWC": + raise RuntimeError( + f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, " + f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead." 
+ ) + + def wrapper(idx, sample): + video, audio, label = sample + + video = datapoints.Video(video) + + return video, audio, label + + return wrapper + + +for dataset_cls in [ + datasets.HMDB51, + datasets.Kinetics, + datasets.UCF101, +]: + WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory) + + +@WRAPPER_FACTORIES.register(datasets.Caltech101) +def caltech101_wrapper_factory(dataset): + if "annotation" in dataset.target_type: + raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`") + + return classification_wrapper_factory(dataset) + + +@WRAPPER_FACTORIES.register(datasets.CocoDetection) +def coco_dectection_wrapper_factory(dataset): + def segmentation_to_mask(segmentation, *, spatial_size): + from pycocotools import mask + + segmentation = ( + mask.frPyObjects(segmentation, *spatial_size) + if isinstance(segmentation, dict) + else mask.merge(mask.frPyObjects(segmentation, *spatial_size)) + ) + return torch.from_numpy(mask.decode(segmentation)) + + def wrapper(idx, sample): + image_id = dataset.ids[idx] + + image, target = sample + + if not target: + return image, dict(image_id=image_id) + + batched_target = list_of_dicts_to_dict_of_lists(target) + + batched_target["image_id"] = image_id + + spatial_size = tuple(F.get_spatial_size(image)) + batched_target["boxes"] = datapoints.BoundingBox( + batched_target["bbox"], + format=datapoints.BoundingBoxFormat.XYWH, + spatial_size=spatial_size, + ) + batched_target["masks"] = datapoints.Mask( + torch.stack( + [ + segmentation_to_mask(segmentation, spatial_size=spatial_size) + for segmentation in batched_target["segmentation"] + ] + ), + ) + batched_target["labels"] = torch.tensor(batched_target["category_id"]) + + return image, batched_target + + return wrapper + + +WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory) + + +VOC_DETECTION_CATEGORIES = [ + "__background__", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] +VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES)))) + + +@WRAPPER_FACTORIES.register(datasets.VOCDetection) +def voc_detection_wrapper_factory(dataset): + def wrapper(idx, sample): + image, target = sample + + batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"]) + + target["boxes"] = datapoints.BoundingBox( + [ + [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] + for bndbox in batched_instances["bndbox"] + ], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=(image.height, image.width), + ) + target["labels"] = torch.tensor( + [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]] + ) + + return image, target + + return wrapper + + +@WRAPPER_FACTORIES.register(datasets.SBDataset) +def sbd_wrapper(dataset): + if dataset.mode == "boundaries": + raise_not_supported("SBDataset with mode='boundaries'") + + return segmentation_wrapper_factory(dataset) + + +@WRAPPER_FACTORIES.register(datasets.CelebA) +def celeba_wrapper_factory(dataset): + if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]): + raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`") + + def wrapper(idx, sample): + image, target = sample + + target = wrap_target_by_type( + target, + target_types=dataset.target_type, 
+ type_wrappers={ + "bbox": lambda item: datapoints.BoundingBox( + item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + ), + }, + ) + + return image, target + + return wrapper + + +KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"] +KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))) + + +@WRAPPER_FACTORIES.register(datasets.Kitti) +def kitti_wrapper_factory(dataset): + def wrapper(idx, sample): + image, target = sample + + if target is not None: + target = list_of_dicts_to_dict_of_lists(target) + + target["boxes"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width) + ) + target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]]) + + return image, target + + return wrapper + + +@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet) +def oxford_iiit_pet_wrapper_factor(dataset): + def wrapper(idx, sample): + image, target = sample + + if target is not None: + target = wrap_target_by_type( + target, + target_types=dataset._target_types, + type_wrappers={ + "segmentation": pil_image_to_mask, + }, + ) + + return image, target + + return wrapper + + +@WRAPPER_FACTORIES.register(datasets.Cityscapes) +def cityscapes_wrapper_factory(dataset): + if any(target_type in dataset.target_type for target_type in ["polygon", "color"]): + raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`") + + def instance_segmentation_wrapper(mask): + # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21 + data = pil_image_to_mask(mask) + masks = [] + labels = [] + for id in data.unique(): + masks.append(data == id) + label = id + if label >= 1_000: + label //= 1_000 + labels.append(label) + return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels)) + + def wrapper(idx, sample): + image, target = sample + + target = wrap_target_by_type( + target, + target_types=dataset.target_type, + type_wrappers={ + "instance": instance_segmentation_wrapper, + "semantic": pil_image_to_mask, + }, + ) + + return image, target + + return wrapper + + +@WRAPPER_FACTORIES.register(datasets.WIDERFace) +def widerface_wrapper(dataset): + def wrapper(idx, sample): + image, target = sample + + if target is not None: + target["bbox"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + ) + + return image, target + + return wrapper diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py new file mode 100644 index 00000000000..4fc14323abe --- /dev/null +++ b/torchvision/datapoints/_image.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Tuple, Union + +import PIL.Image +import torch +from torchvision.transforms.functional import InterpolationMode + +from ._datapoint import Datapoint, FillTypeJIT + + +class Image(Datapoint): + @classmethod + def _wrap(cls, tensor: torch.Tensor) -> Image: + image = tensor.as_subclass(cls) + return image + + def __new__( + cls, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ) -> Image: + if isinstance(data, PIL.Image.Image): + from 
torchvision.prototype.transforms import functional as F + + data = F.pil_to_tensor(data) + + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + if tensor.ndim < 2: + raise ValueError + elif tensor.ndim == 2: + tensor = tensor.unsqueeze(0) + + return cls._wrap(tensor) + + @classmethod + def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: + return cls._wrap(tensor) + + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr() + + @property + def spatial_size(self) -> Tuple[int, int]: + return tuple(self.shape[-2:]) # type: ignore[return-value] + + @property + def num_channels(self) -> int: + return self.shape[-3] + + def horizontal_flip(self) -> Image: + output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor)) + return Image.wrap_like(self, output) + + def vertical_flip(self) -> Image: + output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor)) + return Image.wrap_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Image: + output = self._F.resize_image_tensor( + self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + return Image.wrap_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> Image: + output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width) + return Image.wrap_like(self, output) + + def center_crop(self, output_size: List[int]) -> Image: + output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size) + return Image.wrap_like(self, output) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Image: + output = self._F.resized_crop_image_tensor( + self.as_subclass(torch.Tensor), + top, + left, + height, + width, + size=list(size), + interpolation=interpolation, + antialias=antialias, + ) + return Image.wrap_like(self, output) + + def pad( + self, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", + ) -> Image: + output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode) + return Image.wrap_like(self, output) + + def rotate( + self, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: FillTypeJIT = None, + ) -> Image: + output = self._F.rotate_image_tensor( + self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center + ) + return Image.wrap_like(self, output) + + def affine( + self, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + center: Optional[List[float]] = None, + ) -> Image: + output = self._F.affine_image_tensor( + self.as_subclass(torch.Tensor), + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + return 
Image.wrap_like(self, output) + + def perspective( + self, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + coefficients: Optional[List[float]] = None, + ) -> Image: + output = self._F.perspective_image_tensor( + self.as_subclass(torch.Tensor), + startpoints, + endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, + ) + return Image.wrap_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + ) -> Image: + output = self._F.elastic_image_tensor( + self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill + ) + return Image.wrap_like(self, output) + + def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image: + output = self._F.rgb_to_grayscale_image_tensor( + self.as_subclass(torch.Tensor), num_output_channels=num_output_channels + ) + return Image.wrap_like(self, output) + + def adjust_brightness(self, brightness_factor: float) -> Image: + output = self._F.adjust_brightness_image_tensor( + self.as_subclass(torch.Tensor), brightness_factor=brightness_factor + ) + return Image.wrap_like(self, output) + + def adjust_saturation(self, saturation_factor: float) -> Image: + output = self._F.adjust_saturation_image_tensor( + self.as_subclass(torch.Tensor), saturation_factor=saturation_factor + ) + return Image.wrap_like(self, output) + + def adjust_contrast(self, contrast_factor: float) -> Image: + output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor) + return Image.wrap_like(self, output) + + def adjust_sharpness(self, sharpness_factor: float) -> Image: + output = self._F.adjust_sharpness_image_tensor( + self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor + ) + return Image.wrap_like(self, output) + + def adjust_hue(self, hue_factor: float) -> Image: + output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor) + return Image.wrap_like(self, output) + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: + output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain) + return Image.wrap_like(self, output) + + def posterize(self, bits: int) -> Image: + output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits) + return Image.wrap_like(self, output) + + def solarize(self, threshold: float) -> Image: + output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold) + return Image.wrap_like(self, output) + + def autocontrast(self) -> Image: + output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor)) + return Image.wrap_like(self, output) + + def equalize(self) -> Image: + output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor)) + return Image.wrap_like(self, output) + + def invert(self) -> Image: + output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor)) + return Image.wrap_like(self, output) + + def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image: + output = self._F.gaussian_blur_image_tensor( + self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma + ) + return Image.wrap_like(self, output) + + def normalize(self, mean: List[float], std: List[float], inplace: bool = False) 
-> Image: + output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) + return Image.wrap_like(self, output) + + +ImageType = Union[torch.Tensor, PIL.Image.Image, Image] +ImageTypeJIT = torch.Tensor +TensorImageType = Union[torch.Tensor, Image] +TensorImageTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_label.py b/torchvision/datapoints/_label.py new file mode 100644 index 00000000000..0ee2eb9f8e1 --- /dev/null +++ b/torchvision/datapoints/_label.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any, Optional, Sequence, Type, TypeVar, Union + +import torch +from torch.utils._pytree import tree_map + +from ._datapoint import Datapoint + + +L = TypeVar("L", bound="_LabelBase") + + +class _LabelBase(Datapoint): + categories: Optional[Sequence[str]] + + @classmethod + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: + label_base = tensor.as_subclass(cls) + label_base.categories = categories + return label_base + + def __new__( + cls: Type[L], + data: Any, + *, + categories: Optional[Sequence[str]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ) -> L: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + return cls._wrap(tensor, categories=categories) + + @classmethod + def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L: + return cls._wrap( + tensor, + categories=categories if categories is not None else other.categories, + ) + + @classmethod + def from_category( + cls: Type[L], + category: str, + *, + categories: Sequence[str], + **kwargs: Any, + ) -> L: + return cls(categories.index(category), categories=categories, **kwargs) + + +class Label(_LabelBase): + def to_categories(self) -> Any: + if self.categories is None: + raise RuntimeError("Label does not have categories") + + return tree_map(lambda idx: self.categories[idx], self.tolist()) + + +class OneHotLabel(_LabelBase): + def __new__( + cls, + data: Any, + *, + categories: Optional[Sequence[str]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: bool = False, + ) -> OneHotLabel: + one_hot_label = super().__new__( + cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad + ) + + if categories is not None and len(categories) != one_hot_label.shape[-1]: + raise ValueError() + + return one_hot_label diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py new file mode 100644 index 00000000000..41dce097c6c --- /dev/null +++ b/torchvision/datapoints/_mask.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Tuple, Union + +import PIL.Image +import torch +from torchvision.transforms import InterpolationMode + +from ._datapoint import Datapoint, FillTypeJIT + + +class Mask(Datapoint): + @classmethod + def _wrap(cls, tensor: torch.Tensor) -> Mask: + return tensor.as_subclass(cls) + + def __new__( + cls, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ) -> Mask: + if isinstance(data, PIL.Image.Image): + from torchvision.prototype.transforms import functional as F + + data = F.pil_to_tensor(data) + + tensor = cls._to_tensor(data, dtype=dtype, 
device=device, requires_grad=requires_grad) + return cls._wrap(tensor) + + @classmethod + def wrap_like( + cls, + other: Mask, + tensor: torch.Tensor, + ) -> Mask: + return cls._wrap(tensor) + + @property + def spatial_size(self) -> Tuple[int, int]: + return tuple(self.shape[-2:]) # type: ignore[return-value] + + def horizontal_flip(self) -> Mask: + output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor)) + return Mask.wrap_like(self, output) + + def vertical_flip(self) -> Mask: + output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor)) + return Mask.wrap_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Mask: + output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size) + return Mask.wrap_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> Mask: + output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width) + return Mask.wrap_like(self, output) + + def center_crop(self, output_size: List[int]) -> Mask: + output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size) + return Mask.wrap_like(self, output) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Mask: + output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size) + return Mask.wrap_like(self, output) + + def pad( + self, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", + ) -> Mask: + output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill) + return Mask.wrap_like(self, output) + + def rotate( + self, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: FillTypeJIT = None, + ) -> Mask: + output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill) + return Mask.wrap_like(self, output) + + def affine( + self, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + center: Optional[List[float]] = None, + ) -> Mask: + output = self._F.affine_mask( + self.as_subclass(torch.Tensor), + angle, + translate=translate, + scale=scale, + shear=shear, + fill=fill, + center=center, + ) + return Mask.wrap_like(self, output) + + def perspective( + self, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + coefficients: Optional[List[float]] = None, + ) -> Mask: + output = self._F.perspective_mask( + self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients + ) + return Mask.wrap_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + ) -> Mask: + output = 
self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill) + return Mask.wrap_like(self, output) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py new file mode 100644 index 00000000000..f62edd68eaf --- /dev/null +++ b/torchvision/datapoints/_video.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Tuple, Union + +import torch +from torchvision.transforms.functional import InterpolationMode + +from ._datapoint import Datapoint, FillTypeJIT + + +class Video(Datapoint): + @classmethod + def _wrap(cls, tensor: torch.Tensor) -> Video: + video = tensor.as_subclass(cls) + return video + + def __new__( + cls, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ) -> Video: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + if data.ndim < 4: + raise ValueError + return cls._wrap(tensor) + + @classmethod + def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video: + return cls._wrap(tensor) + + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr() + + @property + def spatial_size(self) -> Tuple[int, int]: + return tuple(self.shape[-2:]) # type: ignore[return-value] + + @property + def num_channels(self) -> int: + return self.shape[-3] + + @property + def num_frames(self) -> int: + return self.shape[-4] + + def horizontal_flip(self) -> Video: + output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor)) + return Video.wrap_like(self, output) + + def vertical_flip(self) -> Video: + output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor)) + return Video.wrap_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Video: + output = self._F.resize_video( + self.as_subclass(torch.Tensor), + size, + interpolation=interpolation, + max_size=max_size, + antialias=antialias, + ) + return Video.wrap_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> Video: + output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width) + return Video.wrap_like(self, output) + + def center_crop(self, output_size: List[int]) -> Video: + output = self._F.center_crop_video(self.as_subclass(torch.Tensor), output_size=output_size) + return Video.wrap_like(self, output) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> Video: + output = self._F.resized_crop_video( + self.as_subclass(torch.Tensor), + top, + left, + height, + width, + size=list(size), + interpolation=interpolation, + antialias=antialias, + ) + return Video.wrap_like(self, output) + + def pad( + self, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", + ) -> Video: + output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode) + return Video.wrap_like(self, output) + + def rotate( + self, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + 
center: Optional[List[float]] = None, + fill: FillTypeJIT = None, + ) -> Video: + output = self._F.rotate_video( + self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center + ) + return Video.wrap_like(self, output) + + def affine( + self, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: FillTypeJIT = None, + center: Optional[List[float]] = None, + ) -> Video: + output = self._F.affine_video( + self.as_subclass(torch.Tensor), + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + return Video.wrap_like(self, output) + + def perspective( + self, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + coefficients: Optional[List[float]] = None, + ) -> Video: + output = self._F.perspective_video( + self.as_subclass(torch.Tensor), + startpoints, + endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, + ) + return Video.wrap_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: FillTypeJIT = None, + ) -> Video: + output = self._F.elastic_video( + self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill + ) + return Video.wrap_like(self, output) + + def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video: + output = self._F.rgb_to_grayscale_image_tensor( + self.as_subclass(torch.Tensor), num_output_channels=num_output_channels + ) + return Video.wrap_like(self, output) + + def adjust_brightness(self, brightness_factor: float) -> Video: + output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor) + return Video.wrap_like(self, output) + + def adjust_saturation(self, saturation_factor: float) -> Video: + output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor) + return Video.wrap_like(self, output) + + def adjust_contrast(self, contrast_factor: float) -> Video: + output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor) + return Video.wrap_like(self, output) + + def adjust_sharpness(self, sharpness_factor: float) -> Video: + output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor) + return Video.wrap_like(self, output) + + def adjust_hue(self, hue_factor: float) -> Video: + output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor) + return Video.wrap_like(self, output) + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Video: + output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain) + return Video.wrap_like(self, output) + + def posterize(self, bits: int) -> Video: + output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits) + return Video.wrap_like(self, output) + + def solarize(self, threshold: float) -> Video: + output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold) + return Video.wrap_like(self, output) + + def autocontrast(self) -> Video: + output = self._F.autocontrast_video(self.as_subclass(torch.Tensor)) + return 
Video.wrap_like(self, output)
+
+    def equalize(self) -> Video:
+        output = self._F.equalize_video(self.as_subclass(torch.Tensor))
+        return Video.wrap_like(self, output)
+
+    def invert(self) -> Video:
+        output = self._F.invert_video(self.as_subclass(torch.Tensor))
+        return Video.wrap_like(self, output)
+
+    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
+        output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
+        return Video.wrap_like(self, output)
+
+    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
+        output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
+        return Video.wrap_like(self, output)
+
+
+VideoType = Union[torch.Tensor, Video]
+VideoTypeJIT = torch.Tensor
+TensorVideoType = Union[torch.Tensor, Video]
+TensorVideoTypeJIT = torch.Tensor
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
new file mode 100644
index 00000000000..ff3b758454a
--- /dev/null
+++ b/torchvision/transforms/v2/__init__.py
@@ -0,0 +1,59 @@
+from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
+
+from . import functional, utils  # usort: skip
+
+from ._transform import Transform  # usort: skip
+from ._presets import StereoMatching  # usort: skip
+
+from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
+from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
+from ._color import (
+    ColorJitter,
+    Grayscale,
+    RandomAdjustSharpness,
+    RandomAutocontrast,
+    RandomEqualize,
+    RandomGrayscale,
+    RandomInvert,
+    RandomPhotometricDistort,
+    RandomPosterize,
+    RandomSolarize,
+)
+from ._container import Compose, RandomApply, RandomChoice, RandomOrder
+from ._geometry import (
+    CenterCrop,
+    ElasticTransform,
+    FiveCrop,
+    FixedSizeCrop,
+    Pad,
+    RandomAffine,
+    RandomCrop,
+    RandomHorizontalFlip,
+    RandomIoUCrop,
+    RandomPerspective,
+    RandomResize,
+    RandomResizedCrop,
+    RandomRotation,
+    RandomShortestSize,
+    RandomVerticalFlip,
+    RandomZoomOut,
+    Resize,
+    ScaleJitter,
+    TenCrop,
+)
+from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
+from ._misc import (
+    GaussianBlur,
+    Identity,
+    Lambda,
+    LinearTransformation,
+    Normalize,
+    PermuteDimensions,
+    SanitizeBoundingBoxes,
+    ToDtype,
+    TransposeDimensions,
+)
+from ._temporal import UniformTemporalSubsample
+from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
+
+from ._deprecated import ToTensor  # usort: skip
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
new file mode 100644
index 00000000000..3ceabba5e42
--- /dev/null
+++ b/torchvision/transforms/v2/_augment.py
@@ -0,0 +1,395 @@
+import math
+import numbers
+import warnings
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision import datapoints, transforms as _transforms
+from torchvision.ops import masks_to_boxes
+from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
+from torchvision.transforms.v2.functional._geometry import _check_interpolation
+
+from ._transform import _RandomApplyTransform
+from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
+
+
+class
RandomErasing(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomErasing + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return dict( + super()._extract_params_for_v1_transform(), + value="random" if self.value is None else self.value, + ) + + _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) + + def __init__( + self, + p: float = 0.5, + scale: Tuple[float, float] = (0.02, 0.33), + ratio: Tuple[float, float] = (0.3, 3.3), + value: float = 0.0, + inplace: bool = False, + ): + super().__init__(p=p) + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, (tuple, list)): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, (tuple, list)): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + self.scale = scale + self.ratio = ratio + if isinstance(value, (int, float)): + self.value = [float(value)] + elif isinstance(value, str): + self.value = None + elif isinstance(value, (list, tuple)): + self.value = [float(v) for v in value] + else: + self.value = value + self.inplace = inplace + + self._log_ratio = torch.log(torch.tensor(self.ratio)) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + img_c, img_h, img_w = query_chw(flat_inputs) + + if self.value is not None and not (len(self.value) in (1, img_c)): + raise ValueError( + f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" + ) + + area = img_h * img_w + + log_ratio = self._log_ratio + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if self.value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(self.value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() + break + else: + i, j, h, w, v = 0, 0, img_h, img_w, None + + return dict(i=i, j=j, h=h, w=w, v=v) + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + if params["v"] is not None: + inpt = F.erase(inpt, **params, inplace=self.inplace) + + return inpt + + +class _BaseMixupCutmix(_RandomApplyTransform): + def __init__(self, alpha: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.alpha = alpha + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if not ( + has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor) + and has_any(flat_inputs, datapoints.OneHotLabel) + ): + raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") + if 
has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label): + raise TypeError( + f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." + ) + + def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel: + if inpt.ndim < 2: + raise ValueError("Need a batch of one hot labels") + output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) + return datapoints.OneHotLabel.wrap_like(inpt, output) + + +class RandomMixup(_BaseMixupCutmix): + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + lam = params["lam"] + if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): + expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 + if inpt.ndim < expected_ndim: + raise ValueError("The transform expects a batched input") + output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) + + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] + + return output + elif isinstance(inpt, datapoints.OneHotLabel): + return self._mixup_onehotlabel(inpt, lam) + else: + return inpt + + +class RandomCutmix(_BaseMixupCutmix): + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + lam = float(self._dist.sample(())) # type: ignore[arg-type] + + H, W = query_spatial_size(flat_inputs) + + r_x = torch.randint(W, ()) + r_y = torch.randint(H, ()) + + r = 0.5 * math.sqrt(1.0 - lam) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + box = (x1, y1, x2, y2) + + lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + return dict(box=box, lam_adjusted=lam_adjusted) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): + box = params["box"] + expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 + if inpt.ndim < expected_ndim: + raise ValueError("The transform expects a batched input") + x1, y1, x2, y2 = box + rolled = inpt.roll(1, 0) + output = inpt.clone() + output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] + + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + + return output + elif isinstance(inpt, datapoints.OneHotLabel): + lam_adjusted = params["lam_adjusted"] + return self._mixup_onehotlabel(inpt, lam_adjusted) + else: + return inpt + + +class SimpleCopyPaste(Transform): + def __init__( + self, + blending: bool = True, + resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR, + antialias: Optional[bool] = None, + ) -> None: + super().__init__() + self.resize_interpolation = _check_interpolation(resize_interpolation) + self.blending = blending + self.antialias = antialias + + def _copy_paste( + self, + image: datapoints.TensorImageType, + target: Dict[str, Any], + paste_image: datapoints.TensorImageType, + paste_target: Dict[str, Any], + random_selection: torch.Tensor, + blending: bool, + resize_interpolation: F.InterpolationMode, + antialias: Optional[bool], + ) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]: + + 
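+ # Keep only the paste instances picked by `random_selection`; `wrap_like` re-wraps the indexed tensors so they keep their Mask/BoundingBox/Label datapoint types.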
paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) + paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) + paste_labels = paste_target["labels"].wrap_like( + paste_target["labels"], paste_target["labels"][random_selection] + ) + + masks = target["masks"] + + # We resize the source and paste data if they have different sizes. + # This differs from the TF implementation, where the algorithm operates on equal-sized data + # (for example, coming from LSJ data augmentations), so we introduce the resize here. + size1 = cast(List[int], image.shape[-2:]) + size2 = paste_image.shape[-2:] + if size1 != size2: + paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias) + paste_masks = F.resize(paste_masks, size=size1) + paste_boxes = F.resize(paste_boxes, size=size1) + + paste_alpha_mask = paste_masks.sum(dim=0) > 0 + + if blending: + paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) + + inverse_paste_alpha_mask = paste_alpha_mask.logical_not() + # Copy-paste images: + image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask)) + + # Copy-paste masks: + masks = masks * inverse_paste_alpha_mask + non_all_zero_masks = masks.sum((-1, -2)) > 0 + masks = masks[non_all_zero_masks] + + # Do a shallow copy of the target dict + out_target = {k: v for k, v in target.items()} + + out_target["masks"] = torch.cat([masks, paste_masks]) + + # Copy-paste boxes and labels + bbox_format = target["boxes"].format + xyxy_boxes = masks_to_boxes(masks) + # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive + # we need to add +1 to x2y2.
+ # There is a similar +1 in other reference implementations: + # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422 + xyxy_boxes[:, 2:] += 1 + boxes = F.convert_format_bounding_box( + xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True + ) + out_target["boxes"] = torch.cat([boxes, paste_boxes]) + + labels = target["labels"][non_all_zero_masks] + out_target["labels"] = torch.cat([labels, paste_labels]) + + # Check for degenerated boxes and remove them + boxes = F.convert_format_bounding_box( + out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY + ) + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + valid_targets = ~degenerate_boxes.any(dim=1) + + out_target["boxes"] = boxes[valid_targets] + out_target["masks"] = out_target["masks"][valid_targets] + out_target["labels"] = out_target["labels"][valid_targets] + + return image, out_target + + def _extract_image_targets( + self, flat_sample: List[Any] + ) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]: + # fetch all images, bboxes, masks and labels from unstructured input + # with List[image], List[BoundingBox], List[Mask], List[Label] + images, bboxes, masks, labels = [], [], [], [] + for obj in flat_sample: + if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): + images.append(obj) + elif isinstance(obj, PIL.Image.Image): + images.append(F.to_image_tensor(obj)) + elif isinstance(obj, datapoints.BoundingBox): + bboxes.append(obj) + elif isinstance(obj, datapoints.Mask): + masks.append(obj) + elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): + labels.append(obj) + + if not (len(images) == len(bboxes) == len(masks) == len(labels)): + raise TypeError( + f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " + "BoundingBoxes, Masks and Labels or OneHotLabels." 
+ ) + + targets = [] + for bbox, mask, label in zip(bboxes, masks, labels): + targets.append({"boxes": bbox, "masks": mask, "labels": label}) + + return images, targets + + def _insert_outputs( + self, + flat_sample: List[Any], + output_images: List[datapoints.TensorImageType], + output_targets: List[Dict[str, Any]], + ) -> None: + c0, c1, c2, c3 = 0, 0, 0, 0 + for i, obj in enumerate(flat_sample): + if isinstance(obj, datapoints.Image): + flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0]) + c0 += 1 + elif isinstance(obj, PIL.Image.Image): + flat_sample[i] = F.to_image_pil(output_images[c0]) + c0 += 1 + elif is_simple_tensor(obj): + flat_sample[i] = output_images[c0] + c0 += 1 + elif isinstance(obj, datapoints.BoundingBox): + flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) + c1 += 1 + elif isinstance(obj, datapoints.Mask): + flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) + c2 += 1 + elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): + flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] + c3 += 1 + + def forward(self, *inputs: Any) -> Any: + flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) + + images, targets = self._extract_image_targets(flat_inputs) + + # images = [t1, t2, ..., tN] + # Let's define paste_images as a shifted list of input images + # paste_images = [t2, t3, ..., tN, t1] + # FYI: in TF they mix data on the dataset level + images_rolled = images[-1:] + images[:-1] + targets_rolled = targets[-1:] + targets[:-1] + + output_images, output_targets = [], [] + + for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled): + + # Random paste targets selection: + num_masks = len(paste_target["masks"]) + + if num_masks < 1: + # Such a degenerate case with num_masks=0 can happen with LSJ + # Let's just return (image, target) + output_image, output_target = image, target + else: + random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device) + random_selection = torch.unique(random_selection) + + output_image, output_target = self._copy_paste( + image, + target, + paste_image, + paste_target, + random_selection=random_selection, + blending=self.blending, + resize_interpolation=self.resize_interpolation, + antialias=self.antialias, + ) + output_images.append(output_image) + output_targets.append(output_target) + + # Insert updated images and targets into input flat_sample + self._insert_outputs(flat_inputs, output_images, output_targets) + + return tree_unflatten(flat_inputs, spec) diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py new file mode 100644 index 00000000000..67afecf5df1 --- /dev/null +++ b/torchvision/transforms/v2/_auto_augment.py @@ -0,0 +1,536 @@ +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import PIL.Image +import torch + +from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec +from torchvision import transforms as _transforms +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform +from torchvision.prototype.transforms.functional._geometry import _check_interpolation +from torchvision.prototype.transforms.functional._meta import get_spatial_size +from torchvision.transforms import functional_tensor as _FT + +from ._utils import
_setup_fill_arg +from .utils import check_type, is_simple_tensor + + +class _AutoAugmentBase(Transform): + def __init__( + self, + *, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, + ) -> None: + super().__init__() + self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]: + keys = tuple(dct.keys()) + key = keys[int(torch.randint(len(keys), ()))] + return key, dct[key] + + def _flatten_and_extract_image_or_video( + self, + inputs: Any, + unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask), + ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]: + flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) + needs_transform_list = self._needs_transform_list(flat_inputs) + + image_or_videos = [] + for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)): + if needs_transform and check_type( + inpt, + ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ), + ): + image_or_videos.append((idx, inpt)) + elif isinstance(inpt, unsupported_types): + raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()") + + if not image_or_videos: + raise TypeError("Found no image in the sample.") + if len(image_or_videos) > 1: + raise TypeError( + f"Auto augment transformations are only properly defined for a single image or video, " + f"but found {len(image_or_videos)}." + ) + + idx, image_or_video = image_or_videos[0] + return (flat_inputs, spec, idx), image_or_video + + def _unflatten_and_insert_image_or_video( + self, + flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int], + image_or_video: Union[datapoints.ImageType, datapoints.VideoType], + ) -> Any: + flat_inputs, spec, idx = flat_inputs_with_spec + flat_inputs[idx] = image_or_video + return tree_unflatten(flat_inputs, spec) + + def _apply_image_or_video_transform( + self, + image: Union[datapoints.ImageType, datapoints.VideoType], + transform_id: str, + magnitude: float, + interpolation: Union[InterpolationMode, int], + fill: Dict[Type, datapoints.FillTypeJIT], + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + fill_ = fill[type(image)] + + if transform_id == "Identity": + return image + elif transform_id == "ShearX": + # magnitude should be arctan(magnitude) + # official autoaug: (1, level, 0, 0, 1, 0) + # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290 + # compared to + # torchvision: (1, tan(level), 0, 0, 1, 0) + # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976 + return F.affine( + image, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(math.atan(magnitude)), 0.0], + interpolation=interpolation, + fill=fill_, + center=[0, 0], + ) + elif transform_id == "ShearY": + # magnitude should be arctan(magnitude) + # See above + return F.affine( + image, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(math.atan(magnitude))], + interpolation=interpolation, + fill=fill_, + center=[0, 0], + ) + elif transform_id == "TranslateX": + return F.affine( + image, + angle=0.0, + 
translate=[int(magnitude), 0], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill_, + ) + elif transform_id == "TranslateY": + return F.affine( + image, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill_, + ) + elif transform_id == "Rotate": + return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_) + elif transform_id == "Brightness": + return F.adjust_brightness(image, brightness_factor=1.0 + magnitude) + elif transform_id == "Color": + return F.adjust_saturation(image, saturation_factor=1.0 + magnitude) + elif transform_id == "Contrast": + return F.adjust_contrast(image, contrast_factor=1.0 + magnitude) + elif transform_id == "Sharpness": + return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude) + elif transform_id == "Posterize": + return F.posterize(image, bits=int(magnitude)) + elif transform_id == "Solarize": + bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0 + return F.solarize(image, threshold=bound * magnitude) + elif transform_id == "AutoContrast": + return F.autocontrast(image) + elif transform_id == "Equalize": + return F.equalize(image) + elif transform_id == "Invert": + return F.invert(image) + else: + raise ValueError(f"No transform available for {transform_id}") + + +class AutoAugment(_AutoAugmentBase): + _v1_transform_cls = _transforms.AutoAugment + + _AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), + True, + ), + "TranslateY": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), + True, + ), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + "Invert": (lambda num_bins, height, width: None, False), + } + + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, + ) -> None: + super().__init__(interpolation=interpolation, fill=fill) + self.policy = policy + self._policies = self._get_policies(policy) + + def _get_policies( + self, policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 
0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + 
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError(f"The provided policy {policy} is not recognized.") + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_spatial_size(image_or_video) + + policy = self._policies[int(torch.randint(len(self._policies), ()))] + + for transform_id, probability, magnitude_idx in policy: + if not torch.rand(()) <= probability: + continue + + magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] + + magnitudes = magnitudes_fn(10, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[magnitude_idx]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + ) + + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) + + +class RandAugment(_AutoAugmentBase): + _v1_transform_cls = _transforms.RandAugment + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, height, width: None, False), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), + True, + ), + "TranslateY": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), + True, + ), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + } + + def __init__( + self, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, + ) -> None: + super().__init__(interpolation=interpolation, fill=fill) + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_spatial_size(image_or_video) + + for _ in range(self.num_ops): + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[self.magnitude]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, 
interpolation=self.interpolation, fill=self.fill + ) + + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) + + +class TrivialAugmentWide(_AutoAugmentBase): + _v1_transform_cls = _transforms.TrivialAugmentWide + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, height, width: None, False), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + } + + def __init__( + self, + num_magnitude_bins: int = 31, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, + ): + super().__init__(interpolation=interpolation, fill=fill) + self.num_magnitude_bins = num_magnitude_bins + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_spatial_size(image_or_video) + + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + + magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + ) + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) + + +class AugMix(_AutoAugmentBase): + _v1_transform_cls = _transforms.AugMix + + _PARTIAL_AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + } + 
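+ # The full augmentation space below extends the partial one with the color ops (brightness, color, contrast, sharpness); forward() picks one of the two spaces based on `all_ops`.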
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = { + **_PARTIAL_AUGMENTATION_SPACE, + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + } + + def __init__( + self, + severity: int = 3, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = True, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, + ) -> None: + super().__init__(interpolation=interpolation, fill=fill) + self._PARAMETER_MAX = 10 + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + + def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + # Must be on a separate method so that we can overwrite it in tests. + return torch._sample_dirichlet(params) + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_spatial_size(orig_image_or_video) + + if isinstance(orig_image_or_video, torch.Tensor): + image_or_video = orig_image_or_video + else: # isinstance(inpt, PIL.Image.Image): + image_or_video = F.pil_to_tensor(orig_image_or_video) + + augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE + + orig_dims = list(image_or_video.shape) + expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4 + batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims) + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) + + # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a + # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of + # augmented image or video. + m = self._sample_dirichlet( + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + ) + + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos. 
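+ # `combined_weights` has shape (batch_size, mixture_width): a Dirichlet sample over the mixture chains, rescaled by the Beta weight of the augmented part (m[:, 1]), so m[:, 0] plus the row sum of `combined_weights` is 1 for every batch element.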
+ combined_weights = self._sample_dirichlet( + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + ) * m[:, 1].reshape([batch_dims[0], -1]) + + mix = m[:, 0].reshape(batch_dims) * batch + for i in range(self.mixture_width): + aug = batch + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + for _ in range(depth): + transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) + + magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + aug = self._apply_image_or_video_transform( + aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + ) + mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) + mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) + + if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)): + mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type] + elif isinstance(orig_image_or_video, PIL.Image.Image): + mix = F.to_image_pil(mix) + + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py new file mode 100644 index 00000000000..8ac0d857753 --- /dev/null +++ b/torchvision/transforms/v2/_color.py @@ -0,0 +1,259 @@ +import collections.abc +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import PIL.Image +import torch +from torchvision import transforms as _transforms +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import functional as F, Transform + +from ._transform import _RandomApplyTransform +from .utils import is_simple_tensor, query_chw + + +class Grayscale(Transform): + _v1_transform_cls = _transforms.Grayscale + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, num_output_channels: int = 1): + super().__init__() + self.num_output_channels = num_output_channels + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) + + +class RandomGrayscale(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomGrayscale + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, p: float = 0.1) -> None: + super().__init__(p=p) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_input_channels, *_ = query_chw(flat_inputs) + return dict(num_input_channels=num_input_channels) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) + + +class ColorJitter(Transform): + _v1_transform_cls = _transforms.ColorJitter + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()} + + def __init__( + self, + brightness: Optional[Union[float, Sequence[float]]] = None, + contrast: Optional[Union[float, Sequence[float]]] = None, + saturation: Optional[Union[float, Sequence[float]]] = None, + hue: Optional[Union[float, Sequence[float]]] = None, + ) -> None: + 
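+ # `_check_input` normalizes each argument to a (min, max) tuple, or returns None when the corresponding jitter is a no-op (argument omitted or the range collapses onto the identity value).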
super().__init__() + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + def _check_input( + self, + value: Optional[Union[float, Sequence[float]]], + name: str, + center: float = 1.0, + bound: Tuple[float, float] = (0, float("inf")), + clip_first_on_zero: bool = True, + ) -> Optional[Tuple[float, float]]: + if value is None: + return None + + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If {name} is a single number, it must be non negative.") + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, collections.abc.Sequence) and len(value) == 2: + value = [float(v) for v in value] + else: + raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}, but got {value}.") + + return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) + + @staticmethod + def _generate_value(left: float, right: float) -> float: + return torch.empty(1).uniform_(left, right).item() + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + fn_idx = torch.randperm(4) + + b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) + c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) + s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + + return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + output = inpt + brightness_factor = params["brightness_factor"] + contrast_factor = params["contrast_factor"] + saturation_factor = params["saturation_factor"] + hue_factor = params["hue_factor"] + for fn_id in params["fn_idx"]: + if fn_id == 0 and brightness_factor is not None: + output = F.adjust_brightness(output, brightness_factor=brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + output = F.adjust_contrast(output, contrast_factor=contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + output = F.adjust_saturation(output, saturation_factor=saturation_factor) + elif fn_id == 3 and hue_factor is not None: + output = F.adjust_hue(output, hue_factor=hue_factor) + return output + + +# TODO: This class seems to be untested +class RandomPhotometricDistort(Transform): + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__( + self, + contrast: Tuple[float, float] = (0.5, 1.5), + saturation: Tuple[float, float] = (0.5, 1.5), + hue: Tuple[float, float] = (-0.05, 0.05), + brightness: Tuple[float, float] = (0.875, 1.125), + p: float = 0.5, + ): + super().__init__() + self.brightness = brightness + self.contrast = contrast + self.hue = hue + self.saturation = saturation + self.p = p + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_channels, *_ = query_chw(flat_inputs) + return dict( + zip( + ["brightness", "contrast1", "saturation", "hue", "contrast2"], + 
(torch.rand(5) < self.p).tolist(), + ), + contrast_before=bool(torch.rand(()) < 0.5), + channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, + ) + + def _permute_channels( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + + orig_inpt = inpt + if isinstance(orig_inpt, PIL.Image.Image): + inpt = F.pil_to_tensor(inpt) + + # TODO: Find a better fix than as_subclass??? + output = inpt[..., permutation, :, :].as_subclass(type(inpt)) + + if isinstance(orig_inpt, PIL.Image.Image): + output = F.to_image_pil(output) + + return output + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + if params["brightness"]: + inpt = F.adjust_brightness( + inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) + ) + if params["contrast1"] and params["contrast_before"]: + inpt = F.adjust_contrast( + inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1]) + ) + if params["saturation"]: + inpt = F.adjust_saturation( + inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1]) + ) + if params["hue"]: + inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1])) + if params["contrast2"] and not params["contrast_before"]: + inpt = F.adjust_contrast( + inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1]) + ) + if params["channel_permutation"] is not None: + inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) + return inpt + + +class RandomEqualize(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomEqualize + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.equalize(inpt) + + +class RandomInvert(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomInvert + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.invert(inpt) + + +class RandomPosterize(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomPosterize + + def __init__(self, bits: int, p: float = 0.5) -> None: + super().__init__(p=p) + self.bits = bits + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.posterize(inpt, bits=self.bits) + + +class RandomSolarize(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomSolarize + + def __init__(self, threshold: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.threshold = threshold + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.solarize(inpt, threshold=self.threshold) + + +class RandomAutocontrast(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomAutocontrast + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.autocontrast(inpt) + + +class RandomAdjustSharpness(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomAdjustSharpness + + def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.sharpness_factor = sharpness_factor + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor) diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py new file mode 100644 index 00000000000..42c73a2c11e --- /dev/null +++ 
b/torchvision/transforms/v2/_container.py @@ -0,0 +1,113 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import torch + +from torch import nn +from torchvision import transforms as _transforms +from torchvision.prototype.transforms import Transform + + +class Compose(Transform): + def __init__(self, transforms: Sequence[Callable]) -> None: + super().__init__() + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + self.transforms = transforms + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + for transform in self.transforms: + sample = transform(sample) + return sample + + def extra_repr(self) -> str: + format_string = [] + for t in self.transforms: + format_string.append(f" {t}") + return "\n".join(format_string) + + +class RandomApply(Transform): + _v1_transform_cls = _transforms.RandomApply + + def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None: + super().__init__() + + if not isinstance(transforms, (Sequence, nn.ModuleList)): + raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`") + self.transforms = transforms + + if not (0.0 <= p <= 1.0): + raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") + self.p = p + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return {"transforms": self.transforms, "p": self.p} + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + + if torch.rand(1) >= self.p: + return sample + + for transform in self.transforms: + sample = transform(sample) + return sample + + def extra_repr(self) -> str: + format_string = [] + for t in self.transforms: + format_string.append(f" {t}") + return "\n".join(format_string) + + +class RandomChoice(Transform): + def __init__( + self, + transforms: Sequence[Callable], + probabilities: Optional[List[float]] = None, + p: Optional[List[float]] = None, + ) -> None: + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + if p is not None: + warnings.warn( + "Argument p is deprecated and will be removed in a future release. " + "Please use probabilities argument instead." 
+ ) + probabilities = p + + if probabilities is None: + probabilities = [1] * len(transforms) + elif len(probabilities) != len(transforms): + raise ValueError( + f"The number of probabilities doesn't match the number of transforms: " + f"{len(probabilities)} != {len(transforms)}" + ) + + super().__init__() + + self.transforms = transforms + total = sum(probabilities) + self.probabilities = [prob / total for prob in probabilities] + + def forward(self, *inputs: Any) -> Any: + idx = int(torch.multinomial(torch.tensor(self.probabilities), 1)) + transform = self.transforms[idx] + return transform(*inputs) + + +class RandomOrder(Transform): + def __init__(self, transforms: Sequence[Callable]) -> None: + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + super().__init__() + self.transforms = transforms + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + for idx in torch.randperm(len(self.transforms)): + transform = self.transforms[idx] + sample = transform(sample) + return sample diff --git a/torchvision/transforms/v2/_deprecated.py b/torchvision/transforms/v2/_deprecated.py new file mode 100644 index 00000000000..cd37f4d73d0 --- /dev/null +++ b/torchvision/transforms/v2/_deprecated.py @@ -0,0 +1,23 @@ +import warnings +from typing import Any, Dict, Union + +import numpy as np +import PIL.Image +import torch + +from torchvision.prototype.transforms import Transform +from torchvision.transforms import functional as _F + + +class ToTensor(Transform): + _transformed_types = (PIL.Image.Image, np.ndarray) + + def __init__(self) -> None: + warnings.warn( + "The transform `ToTensor()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`." 
+ ) + super().__init__() + + def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: + return _F.to_tensor(inpt) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py new file mode 100644 index 00000000000..69238760be5 --- /dev/null +++ b/torchvision/transforms/v2/_geometry.py @@ -0,0 +1,964 @@ +import math +import numbers +import warnings +from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union + +import PIL.Image +import torch + +from torchvision import transforms as _transforms +from torchvision.ops.boxes import box_iou +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform +from torchvision.prototype.transforms.functional._geometry import _check_interpolation +from torchvision.transforms.functional import _get_perspective_coeffs + +from ._transform import _RandomApplyTransform +from ._utils import ( + _check_padding_arg, + _check_padding_mode_arg, + _check_sequence_input, + _setup_angle, + _setup_fill_arg, + _setup_float_or_seq, + _setup_size, +) +from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size + + +class RandomHorizontalFlip(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomHorizontalFlip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.horizontal_flip(inpt) + + +class RandomVerticalFlip(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomVerticalFlip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.vertical_flip(inpt) + + +class Resize(Transform): + _v1_transform_cls = _transforms.Resize + + def __init__( + self, + size: Union[int, Sequence[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", + ) -> None: + super().__init__() + + if isinstance(size, int): + size = [size] + elif isinstance(size, (list, tuple)) and len(size) in {1, 2}: + size = list(size) + else: + raise ValueError( + f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead." 
+ ) + self.size = size + + self.interpolation = _check_interpolation(interpolation) + self.max_size = max_size + self.antialias = antialias + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize( + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) + + +class CenterCrop(Transform): + _v1_transform_cls = _transforms.CenterCrop + + def __init__(self, size: Union[int, Sequence[int]]): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.center_crop(inpt, output_size=self.size) + + +class RandomResizedCrop(Transform): + _v1_transform_cls = _transforms.RandomResizedCrop + + def __init__( + self, + size: Union[int, Sequence[int]], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + scale = cast(Tuple[float, float], scale) + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + ratio = cast(Tuple[float, float], ratio) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + + self.scale = scale + self.ratio = ratio + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + self._log_ratio = torch.log(torch.tensor(self.ratio)) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + area = height * width + + log_ratio = self._log_ratio + for _ in range(10): + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + break + else: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + + return dict(top=i, left=j, height=h, width=w) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resized_crop( + inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + ) + + +ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] + + +class FiveCrop(Transform): + """ + Example: + >>> class BatchMultiCrop(transforms.Transform): + ... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]): + ... images_or_videos, labels = sample + ... batch_size = len(images_or_videos) + ... image_or_video = images_or_videos[0] + ... 
images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos)) + ... labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size)) + ... return images_or_videos, labels + ... + >>> image = datapoints.Image(torch.rand(3, 256, 256)) + >>> label = datapoints.Label(0) + >>> transform = transforms.Compose([transforms.FiveCrop(size=(224, 224)), BatchMultiCrop()]) + >>> images, labels = transform(image, label) + >>> images.shape + torch.Size([5, 3, 224, 224]) + >>> labels.shape + torch.Size([5]) + """ + + _v1_transform_cls = _transforms.FiveCrop + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, size: Union[int, Sequence[int]]) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform( + self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any] + ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: + return F.five_crop(inpt, self.size) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): + raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") + + +class TenCrop(Transform): + """ + See :class:`~torchvision.prototype.transforms.FiveCrop` for an example. + """ + + _v1_transform_cls = _transforms.TenCrop + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): + raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Tuple[ + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ]: + return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) + + +class Pad(Transform): + _v1_transform_cls = _transforms.Pad + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError( + f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+ ) + + return params + + def __init__( + self, + padding: Union[int, Sequence[int]], + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + _check_padding_arg(padding) + _check_padding_mode_arg(padding_mode) + + # This cast does Sequence[int] -> List[int] and is required to make mypy happy + if not isinstance(padding, int): + padding = list(padding) + self.padding = padding + self.fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + + +class RandomZoomOut(_RandomApplyTransform): + def __init__( + self, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + side_range: Sequence[float] = (1.0, 4.0), + p: float = 0.5, + ) -> None: + super().__init__(p=p) + + self.fill = _setup_fill_arg(fill) + + _check_sequence_input(side_range, "side_range", req_sizes=(2,)) + + self.side_range = side_range + if side_range[0] < 1.0 or side_range[0] > side_range[1]: + raise ValueError(f"Invalid canvas side range provided {side_range}.") + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_h, orig_w = query_spatial_size(flat_inputs) + + r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + canvas_width = int(orig_w * r) + canvas_height = int(orig_h * r) + + r = torch.rand(2) + left = int((canvas_width - orig_w) * r[0]) + top = int((canvas_height - orig_h) * r[1]) + right = canvas_width - (left + orig_w) + bottom = canvas_height - (top + orig_h) + padding = [left, top, right, bottom] + + return dict(padding=padding) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.pad(inpt, **params, fill=fill) + + +class RandomRotation(Transform): + _v1_transform_cls = _transforms.RandomRotation + + def __init__( + self, + degrees: Union[numbers.Number, Sequence], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + center: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.interpolation = _check_interpolation(interpolation) + self.expand = expand + + self.fill = _setup_fill_arg(fill) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + return dict(angle=angle) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.rotate( + inpt, + **params, + interpolation=self.interpolation, + expand=self.expand, + center=self.center, + fill=fill, + ) + + +class RandomAffine(Transform): + _v1_transform_cls = _transforms.RandomAffine + + def __init__( + self, + degrees: Union[numbers.Number, Sequence], + translate: Optional[Sequence[float]] = None, + scale: Optional[Sequence[float]] = None, + shear: Optional[Union[int, float, Sequence[float]]] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] 
= 0, + center: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2,)) + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2,)) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + if self.translate is not None: + max_dx = float(self.translate[0] * width) + max_dy = float(self.translate[1] * height) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + translate = (tx, ty) + else: + translate = (0, 0) + + if self.scale is not None: + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if self.shear is not None: + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() + if len(self.shear) == 4: + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() + + shear = (shear_x, shear_y) + return dict(angle=angle, translate=translate, scale=scale, shear=shear) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.affine( + inpt, + **params, + interpolation=self.interpolation, + fill=fill, + center=self.center, + ) + + +class RandomCrop(Transform): + _v1_transform_cls = _transforms.RandomCrop + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError( + f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+ ) + + padding = self.padding + if padding is not None: + pad_left, pad_right, pad_top, pad_bottom = padding + padding = [pad_left, pad_top, pad_right, pad_bottom] + params["padding"] = padding + + return params + + def __init__( + self, + size: Union[int, Sequence[int]], + padding: Optional[Union[int, Sequence[int]]] = None, + pad_if_needed: bool = False, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if pad_if_needed or padding is not None: + if padding is not None: + _check_padding_arg(padding) + _check_padding_mode_arg(padding_mode) + + self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] + self.pad_if_needed = pad_if_needed + self.fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + padded_height, padded_width = query_spatial_size(flat_inputs) + + if self.padding is not None: + pad_left, pad_right, pad_top, pad_bottom = self.padding + padded_height += pad_top + pad_bottom + padded_width += pad_left + pad_right + else: + pad_left = pad_right = pad_top = pad_bottom = 0 + + cropped_height, cropped_width = self.size + + if self.pad_if_needed: + if padded_height < cropped_height: + diff = cropped_height - padded_height + + pad_top += diff + pad_bottom += diff + padded_height += 2 * diff + + if padded_width < cropped_width: + diff = cropped_width - padded_width + + pad_left += diff + pad_right += diff + padded_width += 2 * diff + + if padded_height < cropped_height or padded_width < cropped_width: + raise ValueError( + f"Required crop size {(cropped_height, cropped_width)} is larger than " + f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}." 
+ ) + + # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad` + padding = [pad_left, pad_top, pad_right, pad_bottom] + needs_pad = any(padding) + + needs_vert_crop, top = ( + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + if padded_height > cropped_height + else (False, 0) + ) + needs_horz_crop, left = ( + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + if padded_width > cropped_width + else (False, 0) + ) + + return dict( + needs_crop=needs_vert_crop or needs_horz_crop, + top=top, + left=left, + height=cropped_height, + width=cropped_width, + needs_pad=needs_pad, + padding=padding, + ) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if params["needs_pad"]: + fill = self.fill[type(inpt)] + inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) + + if params["needs_crop"]: + inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + + return inpt + + +class RandomPerspective(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomPerspective + + def __init__( + self, + distortion_scale: float = 0.5, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + p: float = 0.5, + ) -> None: + super().__init__(p=p) + + if not (0 <= distortion_scale <= 1): + raise ValueError("Argument distortion_scale value should be between 0 and 1") + + self.distortion_scale = distortion_scale + self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + + distortion_scale = self.distortion_scale + + half_height = height // 2 + half_width = width // 2 + bound_height = int(distortion_scale * half_height) + 1 + bound_width = int(distortion_scale * half_width) + 1 + topleft = [ + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), + ] + topright = [ + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), + ] + botright = [ + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), + ] + botleft = [ + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) + return dict(coefficients=perspective_coeffs) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.perspective( + inpt, + None, + None, + fill=fill, + interpolation=self.interpolation, + **params, + ) + + +class ElasticTransform(Transform): + _v1_transform_cls = _transforms.ElasticTransform + + def __init__( + self, + alpha: Union[float, Sequence[float]] = 50.0, + sigma: Union[float, Sequence[float]] = 5.0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.alpha = _setup_float_or_seq(alpha, "alpha", 2) + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + 
self.interpolation = _check_interpolation(interpolation) + self.fill = _setup_fill_arg(fill) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + size = list(query_spatial_size(flat_inputs)) + + dx = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[0] > 0.0: + kx = int(8 * self.sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = dx * self.alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[1] > 0.0: + ky = int(8 * self.sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = dy * self.alpha[1] / size[1] + displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + return dict(displacement=displacement) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] + return F.elastic( + inpt, + **params, + fill=fill, + interpolation=self.interpolation, + ) + + +class RandomIoUCrop(Transform): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): + super().__init__() + # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 + self.min_scale = min_scale + self.max_scale = max_scale + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + if sampler_options is None: + sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] + self.options = sampler_options + self.trials = trials + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if not ( + has_all(flat_inputs, datapoints.BoundingBox) + and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor) + and has_any(flat_inputs, datapoints.Label, datapoints.OneHotLabel) + ): + raise TypeError( + f"{type(self).__name__}() requires input sample to contain Images or PIL Images, " + "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks." 
+ ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_h, orig_w = query_spatial_size(flat_inputs) + bboxes = query_bounding_box(flat_inputs) + + while True: + # sample an option + idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + min_jaccard_overlap = self.options[idx] + if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option + return dict() + + for _ in range(self.trials): + # check the aspect ratio limitations + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + new_w = int(orig_w * r[0]) + new_h = int(orig_h * r[1]) + aspect_ratio = new_w / new_h + if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): + continue + + # check for 0 area crops + r = torch.rand(2) + left = int((orig_w - new_w) * r[0]) + top = int((orig_h - new_h) * r[1]) + right = left + new_w + bottom = top + new_h + if left == right or top == bottom: + continue + + # check for any valid boxes with centers within the crop area + xyxy_bboxes = F.convert_format_bounding_box( + bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY + ) + cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) + cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) + is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) + if not is_within_crop_area.any(): + continue + + # check at least 1 box with jaccard limitations + xyxy_bboxes = xyxy_bboxes[is_within_crop_area] + ious = box_iou( + xyxy_bboxes, + torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device), + ) + if ious.max() < min_jaccard_overlap: + continue + + return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if len(params) < 1: + return inpt + + is_within_crop_area = params["is_within_crop_area"] + + if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel)): + return inpt.wrap_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type] + + output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + + if isinstance(output, datapoints.BoundingBox): + bboxes = output[is_within_crop_area] + bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size) + output = datapoints.BoundingBox.wrap_like(output, bboxes) + elif isinstance(output, datapoints.Mask): + # apply is_within_crop_area if mask is one-hot encoded + masks = output[is_within_crop_area] + output = datapoints.Mask.wrap_like(output, masks) + + return output + + +class ScaleJitter(Transform): + def __init__( + self, + target_size: Tuple[int, int], + scale_range: Tuple[float, float] = (0.1, 2.0), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ): + super().__init__() + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_height, orig_width = query_spatial_size(flat_inputs) + + scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + return dict(size=(new_height, new_width)) + + def _transform(self, inpt: 
Any, params: Dict[str, Any]) -> Any: + return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + + +class RandomShortestSize(Transform): + def __init__( + self, + min_size: Union[List[int], Tuple[int], int], + max_size: Optional[int] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ): + super().__init__() + self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) + self.max_size = max_size + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_height, orig_width = query_spatial_size(flat_inputs) + + min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + r = min_size / min(orig_height, orig_width) + if self.max_size is not None: + r = min(r, self.max_size / max(orig_height, orig_width)) + + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + return dict(size=(new_height, new_width)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + + +class FixedSizeCrop(Transform): + def __init__( + self, + size: Union[int, Sequence[int]], + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, + padding_mode: str = "constant", + ) -> None: + super().__init__() + size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.crop_height = size[0] + self.crop_width = size[1] + + self.fill = _setup_fill_arg(fill) + + self.padding_mode = padding_mode + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if not has_any( + flat_inputs, + PIL.Image.Image, + datapoints.Image, + is_simple_tensor, + datapoints.Video, + ): + raise TypeError( + f"{type(self).__name__}() requires input sample to contain a tensor or PIL image or a Video." + ) + + if has_any(flat_inputs, datapoints.BoundingBox) and not has_any( + flat_inputs, datapoints.Label, datapoints.OneHotLabel + ): + raise TypeError( + f"If a BoundingBox is contained in the input sample, " + f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
+ ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_spatial_size(flat_inputs) + new_height = min(height, self.crop_height) + new_width = min(width, self.crop_width) + + needs_crop = new_height != height or new_width != width + + offset_height = max(height - self.crop_height, 0) + offset_width = max(width - self.crop_width, 0) + + r = torch.rand(1) + top = int(offset_height * r) + left = int(offset_width * r) + + bounding_boxes: Optional[torch.Tensor] + try: + bounding_boxes = query_bounding_box(flat_inputs) + except ValueError: + bounding_boxes = None + + if needs_crop and bounding_boxes is not None: + format = bounding_boxes.format + bounding_boxes, spatial_size = F.crop_bounding_box( + bounding_boxes.as_subclass(torch.Tensor), + format=format, + top=top, + left=left, + height=new_height, + width=new_width, + ) + bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size) + height_and_width = F.convert_format_bounding_box( + bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH + )[..., 2:] + is_valid = torch.all(height_and_width > 0, dim=-1) + else: + is_valid = None + + pad_bottom = max(self.crop_height - new_height, 0) + pad_right = max(self.crop_width - new_width, 0) + + needs_pad = pad_bottom != 0 or pad_right != 0 + + return dict( + needs_crop=needs_crop, + top=top, + left=left, + height=new_height, + width=new_width, + is_valid=is_valid, + padding=[0, 0, pad_right, pad_bottom], + needs_pad=needs_pad, + ) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if params["needs_crop"]: + inpt = F.crop( + inpt, + top=params["top"], + left=params["left"], + height=params["height"], + width=params["width"], + ) + + if params["is_valid"] is not None: + if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel, datapoints.Mask)): + inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] + elif isinstance(inpt, datapoints.BoundingBox): + inpt = datapoints.BoundingBox.wrap_like( + inpt, + F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size), + ) + + if params["needs_pad"]: + fill = self.fill[type(inpt)] + inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) + + return inpt + + +class RandomResize(Transform): + def __init__( + self, + min_size: int, + max_size: int, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", + ) -> None: + super().__init__() + self.min_size = min_size + self.max_size = max_size + self.interpolation = _check_interpolation(interpolation) + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + size = int(torch.randint(self.min_size, self.max_size, ())) + return dict(size=[size]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py new file mode 100644 index 00000000000..79bd5549b2e --- /dev/null +++ b/torchvision/transforms/v2/_meta.py @@ -0,0 +1,49 @@ +from typing import Any, Dict, Union + +import torch + +from torchvision import transforms as _transforms +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import functional as F, Transform + +from .utils import is_simple_tensor + + +class 
ConvertBoundingBoxFormat(Transform): + _transformed_types = (datapoints.BoundingBox,) + + def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None: + super().__init__() + if isinstance(format, str): + format = datapoints.BoundingBoxFormat[format] + self.format = format + + def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: + return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value] + + +class ConvertDtype(Transform): + _v1_transform_cls = _transforms.ConvertImageDtype + + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + super().__init__() + self.dtype = dtype + + def _transform( + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] + ) -> Union[datapoints.TensorImageType, datapoints.TensorVideoType]: + return F.convert_dtype(inpt, self.dtype) + + +# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. +ConvertImageDtype = ConvertDtype + + +class ClampBoundingBox(Transform): + _transformed_types = (datapoints.BoundingBox,) + + def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: + return F.clamp_bounding_box(inpt) # type: ignore[return-value] diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py new file mode 100644 index 00000000000..caed3eec904 --- /dev/null +++ b/torchvision/transforms/v2/_misc.py @@ -0,0 +1,339 @@ +import collections +import warnings +from contextlib import suppress +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union + +import PIL.Image + +import torch +from torch.utils._pytree import tree_flatten, tree_unflatten + +from torchvision import transforms as _transforms +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import functional as F, Transform + +from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size +from .utils import has_any, is_simple_tensor, query_bounding_box + + +class Identity(Transform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + +class Lambda(Transform): + def __init__(self, lambd: Callable[[Any], Any], *types: Type): + super().__init__() + self.lambd = lambd + self.types = types or (object,) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, self.types): + return self.lambd(inpt) + else: + return inpt + + def extra_repr(self) -> str: + extras = [] + name = getattr(self.lambd, "__name__", None) + if name: + extras.append(name) + extras.append(f"types={[type.__name__ for type in self.types]}") + return ", ".join(extras) + + +class LinearTransformation(Transform): + _v1_transform_cls = _transforms.LinearTransformation + + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + + def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): + super().__init__() + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError( + "transformation_matrix should be square. Got " + f"{tuple(transformation_matrix.size())} rectangular matrix." 
+ ) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError( + f"mean_vector should have the same length {mean_vector.size(0)}" + f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]" + ) + + if transformation_matrix.device != mean_vector.device: + raise ValueError( + f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" + ) + + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def _check_inputs(self, sample: Any) -> Any: + if has_any(sample, PIL.Image.Image): + raise TypeError("LinearTransformation does not work on PIL Images") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + shape = inpt.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != " + + f"{self.transformation_matrix.shape[0]}" + ) + + if inpt.device.type != self.mean_vector.device.type: + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + f"Got {inpt.device} vs {self.mean_vector.device}" + ) + + flat_inpt = inpt.reshape(-1, n) - self.mean_vector + + transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype) + output = torch.mm(flat_inpt, transformation_matrix) + output = output.reshape(shape) + + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] + return output + + +class Normalize(Transform): + _v1_transform_cls = _transforms.Normalize + _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video) + + def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): + super().__init__() + self.mean = list(mean) + self.std = list(std) + self.inplace = inplace + + def _check_inputs(self, sample: Any) -> Any: + if has_any(sample, PIL.Image.Image): + raise TypeError(f"{type(self).__name__}() does not support PIL images.") + + def _transform( + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] + ) -> Any: + return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) + + +class GaussianBlur(Transform): + _v1_transform_cls = _transforms.GaussianBlur + + def __init__( + self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) + ) -> None: + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, (int, float)): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = float(sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0.0 < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.") + + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + def _get_params(self, 
flat_inputs: List[Any]) -> Dict[str, Any]: + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + return dict(sigma=[sigma, sigma]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.gaussian_blur(inpt, self.kernel_size, **params) + + +class ToDtype(Transform): + _transformed_types = (torch.Tensor,) + + def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: + super().__init__() + if not isinstance(dtype, dict): + dtype = _get_defaultdict(dtype) + if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) + self.dtype = dtype + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + dtype = self.dtype[type(inpt)] + if dtype is None: + return inpt + return inpt.to(dtype=dtype) + + +class PermuteDimensions(Transform): + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + + def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: + super().__init__() + if not isinstance(dims, dict): + dims = _get_defaultdict(dims) + if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) + self.dims = dims + + def _transform( + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] + ) -> torch.Tensor: + dims = self.dims[type(inpt)] + if dims is None: + return inpt.as_subclass(torch.Tensor) + return inpt.permute(*dims) + + +class TransposeDimensions(Transform): + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + + def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: + super().__init__() + if not isinstance(dims, dict): + dims = _get_defaultdict(dims) + if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." 
+ ) + self.dims = dims + + def _transform( + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] + ) -> torch.Tensor: + dims = self.dims[type(inpt)] + if dims is None: + return inpt.as_subclass(torch.Tensor) + return inpt.transpose(*dims) + + +class SanitizeBoundingBoxes(Transform): + # This removes boxes and their corresponding labels: + # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) + # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) + + def __init__( + self, + min_size: float = 1.0, + labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", + ) -> None: + super().__init__() + + if min_size < 1: + raise ValueError(f"min_size must be >= 1, got {min_size}.") + self.min_size = min_size + + self.labels_getter = labels_getter + self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] + if labels_getter == "default": + self._labels_getter = self._find_labels_default_heuristic + elif callable(labels_getter): + self._labels_getter = labels_getter + elif isinstance(labels_getter, str): + self._labels_getter = lambda inputs: inputs[labels_getter] + elif labels_getter is None: + self._labels_getter = None + else: + raise ValueError( + "labels_getter should either be a str, callable, or 'default'. " + f"Got {labels_getter} of type {type(labels_getter)}." + ) + + @staticmethod + def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: + # Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive + # Raises a ValueError if nothing is found + candidate_key = None + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") + if candidate_key is None: + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) + if candidate_key is None: + raise ValueError( + "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter? " + "If there are no labels and it is by design, pass labels_getter=None." + ) + return inputs[candidate_key] + + def forward(self, *inputs: Any) -> Any: + inputs = inputs if len(inputs) > 1 else inputs[0] + + if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping): + raise ValueError( + f"If labels_getter is a str or 'default' (got {self.labels_getter}), " + f"then the input to forward() must be a dict. Got {type(inputs)} instead." + ) + + if self._labels_getter is None: + labels = None + else: + labels = self._labels_getter(inputs) + if labels is not None and not isinstance(labels, torch.Tensor): + raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") + + flat_inputs, spec = tree_flatten(inputs) + # TODO: this enforces one single BoundingBox entry. + # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... + # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? + boxes = query_bounding_box(flat_inputs) + + if boxes.ndim != 2: + raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") + + if labels is not None and boxes.shape[0] != labels.shape[0]: + raise ValueError( + f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
+ ) + + boxes = cast( + datapoints.BoundingBox, + F.convert_format_bounding_box( + boxes, + new_format=datapoints.BoundingBoxFormat.XYXY, + ), + ) + ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] + mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) + # TODO: Do we really need to check for out of bounds here? All + # transforms should be clamping anyway, so this should never happen? + image_h, image_w = boxes.spatial_size + mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) + mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) + + params = dict(mask=mask, labels=labels) + flat_outputs = [ + # Even-though it may look like we're transforming all inputs, we don't: + # _transform() will only care about BoundingBoxes and the labels + self._transform(inpt, params) + for inpt in flat_inputs + ] + + return tree_unflatten(flat_outputs, spec) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + + if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox): + inpt = inpt[params["mask"]] + + return inpt diff --git a/torchvision/transforms/v2/_presets.py b/torchvision/transforms/v2/_presets.py new file mode 100644 index 00000000000..7f18e885c39 --- /dev/null +++ b/torchvision/transforms/v2/_presets.py @@ -0,0 +1,80 @@ +""" +This file is part of the private API. Please do not use directly these classes as they will be modified on +future versions without warning. The classes should be accessed only via the transforms argument of Weights. +""" +from typing import List, Optional, Tuple, Union + +import PIL.Image + +import torch +from torch import Tensor + +from torchvision.prototype.transforms.functional._geometry import _check_interpolation + +from . import functional as F, InterpolationMode + +__all__ = ["StereoMatching"] + + +class StereoMatching(torch.nn.Module): + def __init__( + self, + *, + use_gray_scale: bool = False, + resize_size: Optional[Tuple[int, ...]], + mean: Tuple[float, ...] = (0.5, 0.5, 0.5), + std: Tuple[float, ...] = (0.5, 0.5, 0.5), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + + # pacify mypy + self.resize_size: Union[None, List] + + if resize_size is not None: + self.resize_size = list(resize_size) + else: + self.resize_size = None + + self.mean = list(mean) + self.std = list(std) + self.interpolation = _check_interpolation(interpolation) + self.use_gray_scale = use_gray_scale + + def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]: + def _process_image(img: PIL.Image.Image) -> Tensor: + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + if self.resize_size is not None: + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the stereo models with antialias=True? 
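+            # The remaining steps below run in order: optional resize, optional grayscale conversion, + # conversion to float (integer inputs are rescaled to [0, 1]), and normalization with self.mean and self.std.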
+ img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=False) + if self.use_gray_scale is True: + img = F.rgb_to_grayscale(img) + img = F.convert_image_dtype(img, torch.float) + img = F.normalize(img, mean=self.mean, std=self.std) + img = img.contiguous() + return img + + left_image = _process_image(left_image) + right_image = _process_image(right_image) + return left_image, right_image + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. " + f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. " + f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and " + f"``std={self.std}``." + ) diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py new file mode 100644 index 00000000000..62fe7f4edf5 --- /dev/null +++ b/torchvision/transforms/v2/_temporal.py @@ -0,0 +1,18 @@ +from typing import Any, Dict + +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import functional as F, Transform + +from torchvision.prototype.transforms.utils import is_simple_tensor + + +class UniformTemporalSubsample(Transform): + _transformed_types = (is_simple_tensor, datapoints.Video) + + def __init__(self, num_samples: int, temporal_dim: int = -4): + super().__init__() + self.num_samples = num_samples + self.temporal_dim = temporal_dim + + def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType: + return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py new file mode 100644 index 00000000000..7f3c03d5e67 --- /dev/null +++ b/torchvision/transforms/v2/_transform.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import PIL.Image +import torch +from torch import nn +from torch.utils._pytree import tree_flatten, tree_unflatten +from torchvision.prototype import datapoints +from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor +from torchvision.utils import _log_api_usage_once + + +class Transform(nn.Module): + + # Class attribute defining transformed types. Other types are passed-through without any transformation + # We support both Types and callables that are able to do further checks on the type of the input. + _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] 
= (torch.Tensor, PIL.Image.Image) + + def __init__(self) -> None: + super().__init__() + _log_api_usage_once(self) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + pass + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + return dict() + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + raise NotImplementedError + + def forward(self, *inputs: Any) -> Any: + flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) + + self._check_inputs(flat_inputs) + + needs_transform_list = self._needs_transform_list(flat_inputs) + params = self._get_params( + [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] + ) + + flat_outputs = [ + self._transform(inpt, params) if needs_transform else inpt + for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) + ] + + return tree_unflatten(flat_outputs, spec) + + def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]: + # Below is a heuristic on how to deal with simple tensor inputs: + # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image + # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. + # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is + # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs` + # of `tree_flatten`, which recurses depth-first through the input. + # + # This heuristic stems from two requirements: + # 1. We need to keep BC for single input simple tensors and treat them as images. + # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface` + # return supplemental numerical data as tensors that cannot be transformed as images. + # + # The heuristic should work well for most people in practice. The only case where it doesn't is if someone + # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. + # However, this case wasn't supported by transforms v1 either, so there is no BC concern. + + needs_transform_list = [] + transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) + for inpt in flat_inputs: + needs_transform = True + + if not check_type(inpt, self._transformed_types): + needs_transform = False + elif is_simple_tensor(inpt): + if transform_simple_tensor: + transform_simple_tensor = False + else: + needs_transform = False + needs_transform_list.append(needs_transform) + return needs_transform_list + + def extra_repr(self) -> str: + extra = [] + for name, value in self.__dict__.items(): + if name.startswith("_") or name == "training": + continue + + if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)): + continue + + extra.append(f"{name}={value}") + + return ", ".join(extra) + + # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things: + # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on + # the v2 transform. See `__init_subclass__` for details. + # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__` + # for details. 
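+    # For example, `RandomRotation` above sets `_v1_transform_cls = _transforms.RandomRotation`, which + # re-exposes the static `get_params` of the v1 class on the v2 transform and makes it JIT scriptable.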
+ _v1_transform_cls: Optional[Type[nn.Module]] = None + + def __init_subclass__(cls) -> None: + # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance. + # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`. + if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"): + cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined] + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current + # v2 transform instance. It does two things: + # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general + # 2. If available handle the `fill` attribute for v1 compatibility (see below for details) + # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen + # if the v2 transform introduced new parameters that are not support by the v1 transform. + common_attrs = nn.Module().__dict__.keys() + params = { + attr: value + for attr, value in self.__dict__.items() + if not attr.startswith("_") and attr not in common_attrs + } + + # transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed + # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value + # for the different datapoint types. Below we extract the value for tensors and return that together with the + # other params. + # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and + # `RandomRotation` + if "fill" in params: + fill_type_defaultdict = params.pop("fill") + params["fill"] = fill_type_defaultdict[torch.Tensor] + + return params + + def __prepare_scriptable__(self) -> nn.Module: + # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return + # value is used for scripting over the original object that should have been scripted. Since the v1 transforms + # are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the + # equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1 + # is around. + if self._v1_transform_cls is None: + raise RuntimeError( + f"Transform {type(self).__name__} cannot be JIT scripted. " + "torchscript is only supported for backward compatibility with transforms " + "which are already in torchvision.transforms. " + "For torchscript support (on tensors only), you can use the functional API instead." + ) + + return self._v1_transform_cls(**self._extract_params_for_v1_transform()) + + +class _RandomApplyTransform(Transform): + def __init__(self, p: float = 0.5) -> None: + if not (0.0 <= p <= 1.0): + raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") + + super().__init__() + self.p = p + + def forward(self, *inputs: Any) -> Any: + # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return + # early afterwards in case the random check triggers. The same result could be achieved by calling + # `super().forward()` after the random check, but that would call `self._check_inputs` twice. 
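+        # The transform fires with probability `p`: when `torch.rand(1) >= self.p` below, the inputs are + # returned untouched.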
+ + inputs = inputs if len(inputs) > 1 else inputs[0] + flat_inputs, spec = tree_flatten(inputs) + + self._check_inputs(flat_inputs) + + if torch.rand(1) >= self.p: + return inputs + + needs_transform_list = self._needs_transform_list(flat_inputs) + params = self._get_params( + [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] + ) + + flat_outputs = [ + self._transform(inpt, params) if needs_transform else inpt + for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) + ] + + return tree_unflatten(flat_outputs, spec) diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py new file mode 100644 index 00000000000..c84aee62afe --- /dev/null +++ b/torchvision/transforms/v2/_type_conversion.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import PIL.Image +import torch + +from torch.nn.functional import one_hot + +from torchvision.prototype import datapoints +from torchvision.prototype.transforms import functional as F, Transform + +from torchvision.prototype.transforms.utils import is_simple_tensor + + +class LabelToOneHot(Transform): + _transformed_types = (datapoints.Label,) + + def __init__(self, num_categories: int = -1): + super().__init__() + self.num_categories = num_categories + + def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel: + num_categories = self.num_categories + if num_categories == -1 and inpt.categories is not None: + num_categories = len(inpt.categories) + output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) + return datapoints.OneHotLabel(output, categories=inpt.categories) + + def extra_repr(self) -> str: + if self.num_categories == -1: + return "" + + return f"num_categories={self.num_categories}" + + +class PILToTensor(Transform): + _transformed_types = (PIL.Image.Image,) + + def _transform(self, inpt: Union[PIL.Image.Image], params: Dict[str, Any]) -> torch.Tensor: + return F.pil_to_tensor(inpt) + + +class ToImageTensor(Transform): + _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) + + def _transform( + self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] + ) -> datapoints.Image: + return F.to_image_tensor(inpt) + + +class ToImagePIL(Transform): + _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray) + + def __init__(self, mode: Optional[str] = None) -> None: + super().__init__() + self.mode = mode + + def _transform( + self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] + ) -> PIL.Image.Image: + return F.to_image_pil(inpt, mode=self.mode) + + +# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
+ToPILImage = ToImagePIL diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py new file mode 100644 index 00000000000..f2d818b1326 --- /dev/null +++ b/torchvision/transforms/v2/_utils.py @@ -0,0 +1,95 @@ +import functools +import numbers +from collections import defaultdict +from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union + +from torchvision.prototype import datapoints +from torchvision.prototype.datapoints._datapoint import FillType, FillTypeJIT + +from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 + + +def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: + if not isinstance(arg, (float, Sequence)): + raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) != req_size: + raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") + if isinstance(arg, Sequence): + for element in arg: + if not isinstance(element, float): + raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") + + if isinstance(arg, float): + arg = [float(arg), float(arg)] + if isinstance(arg, (list, tuple)) and len(arg) == 1: + arg = [arg[0], arg[0]] + return arg + + +def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: + if isinstance(fill, dict): + for key, value in fill.items(): + # Check key for type + _check_fill_arg(value) + if isinstance(fill, defaultdict) and callable(fill.default_factory): + default_value = fill.default_factory() + _check_fill_arg(default_value) + else: + if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") + + +T = TypeVar("T") + + +def _default_arg(value: T) -> T: + return value + + +def _get_defaultdict(default: T) -> Dict[Any, T]: + # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. 
+ # If it were possible, we could replace this with `defaultdict(lambda: default)` + return defaultdict(functools.partial(_default_arg, default)) + + +def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT: + # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 + # So, we can't reassign fill to 0 + # if fill is None: + # fill = 0 + if fill is None: + return fill + + if not isinstance(fill, (int, float)): + fill = [float(v) for v in list(fill)] + return fill # type: ignore[return-value] + + +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]: + _check_fill_arg(fill) + + if isinstance(fill, dict): + for k, v in fill.items(): + fill[k] = _convert_fill_arg(v) + if isinstance(fill, defaultdict) and callable(fill.default_factory): + default_value = fill.default_factory() + sanitized_default = _convert_fill_arg(default_value) + fill.default_factory = functools.partial(_default_arg, sanitized_default) + return fill # type: ignore[return-value] + + return _get_defaultdict(_convert_fill_arg(fill)) + + +def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + +# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) +# https://github.com/pytorch/vision/issues/6250 +def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py new file mode 100644 index 00000000000..0909b763441 --- /dev/null +++ b/torchvision/transforms/v2/functional/__init__.py @@ -0,0 +1,173 @@ +# TODO: Add _log_api_usage_once() in all mid-level kernels. 
If they remain not jit-scriptable we can use decorators + +from torchvision.transforms import InterpolationMode # usort: skip + +from ._utils import is_simple_tensor # usort: skip + +from ._meta import ( + clamp_bounding_box, + convert_format_bounding_box, + convert_dtype_image_tensor, + convert_dtype, + convert_dtype_video, + convert_image_dtype, + get_dimensions_image_tensor, + get_dimensions_image_pil, + get_dimensions, + get_num_frames_video, + get_num_frames, + get_image_num_channels, + get_num_channels_image_tensor, + get_num_channels_image_pil, + get_num_channels_video, + get_num_channels, + get_spatial_size_bounding_box, + get_spatial_size_image_tensor, + get_spatial_size_image_pil, + get_spatial_size_mask, + get_spatial_size_video, + get_spatial_size, +) # usort: skip + +from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video +from ._color import ( + adjust_brightness, + adjust_brightness_image_pil, + adjust_brightness_image_tensor, + adjust_brightness_video, + adjust_contrast, + adjust_contrast_image_pil, + adjust_contrast_image_tensor, + adjust_contrast_video, + adjust_gamma, + adjust_gamma_image_pil, + adjust_gamma_image_tensor, + adjust_gamma_video, + adjust_hue, + adjust_hue_image_pil, + adjust_hue_image_tensor, + adjust_hue_video, + adjust_saturation, + adjust_saturation_image_pil, + adjust_saturation_image_tensor, + adjust_saturation_video, + adjust_sharpness, + adjust_sharpness_image_pil, + adjust_sharpness_image_tensor, + adjust_sharpness_video, + autocontrast, + autocontrast_image_pil, + autocontrast_image_tensor, + autocontrast_video, + equalize, + equalize_image_pil, + equalize_image_tensor, + equalize_video, + invert, + invert_image_pil, + invert_image_tensor, + invert_video, + posterize, + posterize_image_pil, + posterize_image_tensor, + posterize_video, + rgb_to_grayscale, + rgb_to_grayscale_image_pil, + rgb_to_grayscale_image_tensor, + solarize, + solarize_image_pil, + solarize_image_tensor, + solarize_video, +) +from ._geometry import ( + affine, + affine_bounding_box, + affine_image_pil, + affine_image_tensor, + affine_mask, + affine_video, + center_crop, + center_crop_bounding_box, + center_crop_image_pil, + center_crop_image_tensor, + center_crop_mask, + center_crop_video, + crop, + crop_bounding_box, + crop_image_pil, + crop_image_tensor, + crop_mask, + crop_video, + elastic, + elastic_bounding_box, + elastic_image_pil, + elastic_image_tensor, + elastic_mask, + elastic_transform, + elastic_video, + five_crop, + five_crop_image_pil, + five_crop_image_tensor, + five_crop_video, + hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file + horizontal_flip, + horizontal_flip_bounding_box, + horizontal_flip_image_pil, + horizontal_flip_image_tensor, + horizontal_flip_mask, + horizontal_flip_video, + pad, + pad_bounding_box, + pad_image_pil, + pad_image_tensor, + pad_mask, + pad_video, + perspective, + perspective_bounding_box, + perspective_image_pil, + perspective_image_tensor, + perspective_mask, + perspective_video, + resize, + resize_bounding_box, + resize_image_pil, + resize_image_tensor, + resize_mask, + resize_video, + resized_crop, + resized_crop_bounding_box, + resized_crop_image_pil, + resized_crop_image_tensor, + resized_crop_mask, + resized_crop_video, + rotate, + rotate_bounding_box, + rotate_image_pil, + rotate_image_tensor, + rotate_mask, + rotate_video, + ten_crop, + ten_crop_image_pil, + ten_crop_image_tensor, + ten_crop_video, + vertical_flip, + vertical_flip_bounding_box, + vertical_flip_image_pil, 
+ vertical_flip_image_tensor, + vertical_flip_mask, + vertical_flip_video, + vflip, +) +from ._misc import ( + gaussian_blur, + gaussian_blur_image_pil, + gaussian_blur_image_tensor, + gaussian_blur_video, + normalize, + normalize_image_tensor, + normalize_video, +) +from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video +from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image + +from ._deprecated import get_image_size, to_grayscale, to_tensor # usort: skip diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py new file mode 100644 index 00000000000..0164a0b5b9b --- /dev/null +++ b/torchvision/transforms/v2/functional/_augment.py @@ -0,0 +1,64 @@ +from typing import Union + +import PIL.Image + +import torch +from torchvision.prototype import datapoints +from torchvision.transforms.functional import pil_to_tensor, to_pil_image +from torchvision.utils import _log_api_usage_once + +from ._utils import is_simple_tensor + + +def erase_image_tensor( + image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> torch.Tensor: + if not inplace: + image = image.clone() + + image[..., i : i + h, j : j + w] = v + return image + + +@torch.jit.unused +def erase_image_pil( + image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> PIL.Image.Image: + t_img = pil_to_tensor(image) + output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return to_pil_image(output, mode=image.mode) + + +def erase_video( + video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> torch.Tensor: + return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + + +def erase( + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], + i: int, + j: int, + h: int, + w: int, + v: torch.Tensor, + inplace: bool = False, +) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: + if not torch.jit.is_scripting(): + _log_api_usage_once(erase) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + elif isinstance(inpt, datapoints.Image): + output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return datapoints.Image.wrap_like(inpt, output) + elif isinstance(inpt, datapoints.Video): + output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return datapoints.Video.wrap_like(inpt, output) + elif isinstance(inpt, PIL.Image.Image): + return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + else: + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py new file mode 100644 index 00000000000..e1c8bb87cdd --- /dev/null +++ b/torchvision/transforms/v2/functional/_color.py @@ -0,0 +1,672 @@ +from typing import Union + +import PIL.Image +import torch +from torch.nn.functional import conv2d +from torchvision.prototype import datapoints +from torchvision.transforms import functional_pil as _FP +from torchvision.transforms.functional_tensor import _max_value + +from torchvision.utils import _log_api_usage_once + +from ._meta import _num_value_bits, convert_dtype_image_tensor +from ._utils import is_simple_tensor + + +def _rgb_to_grayscale_image_tensor( + image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True +) -> torch.Tensor: + if image.shape[-3] == 1: + return image.clone() + + r, g, b = image.unbind(dim=-3) + l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.unsqueeze(dim=-3) + if preserve_dtype: + l_img = l_img.to(image.dtype) + if num_output_channels == 3: + l_img = l_img.expand(image.shape) + return l_img + + +def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) + + +rgb_to_grayscale_image_pil = _FP.to_grayscale + + +def rgb_to_grayscale( + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 +) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: + if not torch.jit.is_scripting(): + _log_api_usage_once(rgb_to_grayscale) + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.rgb_to_grayscale(num_output_channels=num_output_channels) + elif isinstance(inpt, PIL.Image.Image): + return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: + ratio = float(ratio) + fp = image1.is_floating_point() + bound = _max_value(image1.dtype) + output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound) + return output if fp else output.to(image1.dtype) + + +def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: + if brightness_factor < 0: + raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + fp = image.is_floating_point() + bound = _max_value(image.dtype) + output = image.mul(brightness_factor).clamp_(0, bound) + return output if fp else output.to(image.dtype) + + +adjust_brightness_image_pil = _FP.adjust_brightness + + +def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: + return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) + + +def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_brightness) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_brightness(brightness_factor=brightness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) + if not image.is_floating_point(): + grayscale_image = grayscale_image.floor_() + + return _blend(image, grayscale_image, saturation_factor) + + +adjust_saturation_image_pil = _FP.adjust_saturation + + +def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: + return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) + + +def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_saturation) + + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): + return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_saturation(saturation_factor=saturation_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + fp = image.is_floating_point() + if c == 3: + grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) + if not fp: + grayscale_image = grayscale_image.floor_() + else: + grayscale_image = image if fp else image.to(torch.float32) + mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True) + return _blend(image, mean, contrast_factor) + + +adjust_contrast_image_pil = _FP.adjust_contrast + + +def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: + return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) + + +def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_contrast) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_contrast(contrast_factor=contrast_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + num_channels, height, width = image.shape[-3:] + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + if sharpness_factor < 0: + raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") + + if image.numel() == 0 or height <= 2 or width <= 2: + return image + + bound = _max_value(image.dtype) + fp = image.is_floating_point() + shape = image.shape + + if image.ndim > 4: + image = image.reshape(-1, num_channels, height, width) + needs_unsquash = True + else: + needs_unsquash = False + + # The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle. + kernel_dtype = image.dtype if fp else torch.float32 + a, b = 1.0 / 13.0, 5.0 / 13.0 + kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device) + kernel = kernel.expand(num_channels, 1, 3, 3) + + # We copy and cast at the same time to avoid modifications on the original data + output = image.to(dtype=kernel_dtype, copy=True) + blurred_degenerate = conv2d(output, kernel, groups=num_channels) + if not fp: + # it is better to round before cast + blurred_degenerate = blurred_degenerate.round_() + + # Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice. + view = output[..., 1:-1, 1:-1] + + # We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent: + # x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r) + view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor)) + + # The actual data of output have been modified by the above. We only need to clamp and cast now. 
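Illustration (not part of the patch): a minimal check of the blend identity used by the in-place add_ above. Here x stands in for the original pixels (`view`), y for the blurred copy, and r for sharpness_factor; the names are only for this sketch.

import torch

x = torch.rand(3, 8, 8)   # plays the role of `view` (original pixels)
y = torch.rand(3, 8, 8)   # plays the role of `blurred_degenerate`
r = 0.7                   # sharpness_factor

naive = x * r + y * (1.0 - r)        # textbook blend x*r + y*(1-r)
fused = x + (1.0 - r) * (y - x)      # form computed by add_(y - x, alpha=1 - r)
print(torch.allclose(naive, fused, atol=1e-6))  # True, up to float rounding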
+ output = output.clamp_(0, bound) + if not fp: + output = output.to(image.dtype) + + if needs_unsquash: + output = output.reshape(shape) + + return output + + +adjust_sharpness_image_pil = _FP.adjust_sharpness + + +def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) + + +def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_sharpness) + + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): + return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: + r, g, _ = image.unbind(dim=-3) + + # Implementation is based on + # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330 + minc, maxc = torch.aminmax(image, dim=-3) + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occurring so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + channels_range = maxc - minc + # Since `eqc => channels_range = 0`, replacing denominator with 1 when `eqc` is fine. + ones = torch.ones_like(maxc) + s = channels_range / torch.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. 
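Illustration (not part of the patch): a tiny example of the `eqc` trick described above. For a black pixel, maxc == minc == 0, so the naive saturation 0/0 would be NaN; swapping the denominator for 1 where `eqc` holds yields the intended S = 0.

import torch

pixel = torch.zeros(3, 1, 1)                 # a black RGB pixel: maxc == minc == 0
minc, maxc = torch.aminmax(pixel, dim=-3)

eqc = maxc == minc
channels_range = maxc - minc
ones = torch.ones_like(maxc)

print(channels_range / maxc)                          # tensor([[nan]])
print(channels_range / torch.where(eqc, ones, maxc))  # tensor([[0.]])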
+ channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3) + rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3) + + mask_maxc_neq_r = maxc != r + mask_maxc_eq_g = maxc == g + + hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r) + hr = bc.sub_(gc).mul_(~mask_maxc_neq_r) + hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_())) + + h = hr.add_(hg).add_(hb) + h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0) + return torch.stack((h, s, maxc), dim=-3) + + +def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: + h, s, v = img.unbind(dim=-3) + h6 = h.mul(6) + i = torch.floor(h6) + f = h6.sub_(i) + i = i.to(dtype=torch.int32) + + sxf = s * f + one_minus_s = 1.0 - s + q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0) + t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0) + p = one_minus_s.mul_(v).clamp_(0.0, 1.0) + i.remainder_(6) + + mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) + + a1 = torch.stack((v, q, p, p, t, v), dim=-3) + a2 = torch.stack((t, v, v, q, p, p), dim=-3) + a3 = torch.stack((p, p, t, v, v, q), dim=-3) + a4 = torch.stack((a1, a2, a3), dim=-4) + + return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3) + + +def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + if image.numel() == 0: + # exit earlier on empty images + return image + + orig_dtype = image.dtype + image = convert_dtype_image_tensor(image, torch.float32) + + image = _rgb_to_hsv(image) + h, s, v = image.unbind(dim=-3) + h.add_(hue_factor).remainder_(1.0) + image = torch.stack((h, s, v), dim=-3) + image_hue_adj = _hsv_to_rgb(image) + + return convert_dtype_image_tensor(image_hue_adj, orig_dtype) + + +adjust_hue_image_pil = _FP.adjust_hue + + +def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: + return adjust_hue_image_tensor(video, hue_factor=hue_factor) + + +def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_hue) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_hue(hue_factor=hue_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_hue_image_pil(inpt, hue_factor=hue_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). + # Since the gamma is non-negative, the output remains at [0, 1] scale. + if not torch.is_floating_point(image): + output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma) + else: + output = image.pow(gamma) + + if gain != 1.0: + # The clamp operation is needed only if multiplication is performed. 
It's only when gain != 1, that the scale + # of the output can go beyond [0, 1]. + output = output.mul_(gain).clamp_(0.0, 1.0) + + return convert_dtype_image_tensor(output, image.dtype) + + +adjust_gamma_image_pil = _FP.adjust_gamma + + +def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: + return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) + + +def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_gamma) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_gamma(gamma=gamma, gain=gain) + elif isinstance(inpt, PIL.Image.Image): + return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: + if image.is_floating_point(): + levels = 1 << bits + return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels) + else: + num_value_bits = _num_value_bits(image.dtype) + if bits >= num_value_bits: + return image + + mask = ((1 << bits) - 1) << (num_value_bits - bits) + return image & mask + + +posterize_image_pil = _FP.posterize + + +def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: + return posterize_image_tensor(video, bits=bits) + + +def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(posterize) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return posterize_image_tensor(inpt, bits=bits) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.posterize(bits=bits) + elif isinstance(inpt, PIL.Image.Image): + return posterize_image_pil(inpt, bits=bits) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: + if threshold > _max_value(image.dtype): + raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") + + return torch.where(image >= threshold, invert_image_tensor(image), image) + + +solarize_image_pil = _FP.solarize + + +def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: + return solarize_image_tensor(video, threshold=threshold) + + +def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(solarize) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return solarize_image_tensor(inpt, threshold=threshold) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.solarize(threshold=threshold) + elif isinstance(inpt, PIL.Image.Image): + return solarize_image_pil(inpt, threshold=threshold) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if image.numel() == 0: + # exit earlier on empty images + return image + + bound = _max_value(image.dtype) + fp = image.is_floating_point() + float_image = image if fp else image.to(torch.float32) + + minimum = float_image.amin(dim=(-2, -1), keepdim=True) + maximum = float_image.amax(dim=(-2, -1), keepdim=True) + + eq_idxs = maximum == minimum + inv_scale = maximum.sub_(minimum).mul_(1.0 / bound) + minimum[eq_idxs] = 0.0 + inv_scale[eq_idxs] = 1.0 + + if fp: + diff = float_image.sub(minimum) + else: + diff = float_image.sub_(minimum) + + return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) + + +autocontrast_image_pil = _FP.autocontrast + + +def autocontrast_video(video: torch.Tensor) -> torch.Tensor: + return autocontrast_image_tensor(video) + + +def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(autocontrast) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return autocontrast_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.autocontrast() + elif isinstance(inpt, PIL.Image.Image): + return autocontrast_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: + if image.numel() == 0: + return image + + # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that + # would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for + # `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely + # unfeasible for `torch.int64`. + # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we + # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition + # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower + # and more complicated to implement than a simple conversion and a fast histogram implementation for integers. + # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is + # by far the most common, we choose it as base. + output_dtype = image.dtype + image = convert_dtype_image_tensor(image, torch.uint8) + + # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image + # corresponds to adding 1 to index 127 in the histogram. 
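Illustration (not part of the patch): a toy version of the indexing scheme described above, on a made-up 2x2 uint8 image. Each pixel value is used directly as an index into a 256-bin histogram, in the same way as the scatter_add_ call that follows.

import torch

img = torch.tensor([[127, 127], [0, 255]], dtype=torch.uint8)
flat = img.flatten().to(torch.long)      # pixel values double as histogram indices
hist = torch.zeros(256, dtype=torch.int32)
hist.scatter_add_(dim=0, index=flat, src=hist.new_ones(1).expand_as(flat))
print(hist[0].item(), hist[127].item(), hist[255].item())  # 1 2 1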
+ batch_shape = image.shape[:-2] + flat_image = image.flatten(start_dim=-2).to(torch.long) + hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32) + hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image)) + cum_hist = hist.cumsum(dim=-1) + + # The simplest form of lookup-table (LUT) that also achieves histogram equalization is + # `lut = cum_hist / flat_image.shape[-1] * 255` + # However, PIL uses a more elaborate scheme: + # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385 + # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255` + + # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum + # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but + # rather the maximum value in the image, which might be or not be 255. + index = cum_hist.argmax(dim=-1) + num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1)) + + # This is performance optimization that saves us one multiplication later. With this, the LUT computation simplifies + # to `lut = (cum_hist + step // 2) // step` and thus saving the final multiplication by 255 while keeping the + # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison. + step = num_non_max_pixels.div_(255, rounding_mode="floor") + + # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as + # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't, + # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to + # pay the runtime cost for checking it every time. + valid_equalization = step.ne(0).unsqueeze_(-1) + + # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the + # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards. + cum_hist = cum_hist[..., :-1] + ( + cum_hist.add_(step // 2) + # We need the `clamp_`(min=1) call here to avoid zero division since they fail for integer dtypes. This has no + # effect on the returned result of this kernel since images inside the batch with `step == 0` are returned as is + # instead of equalized version. 
+ .div_(step.clamp_(min=1), rounding_mode="floor") + # We need the `clamp_` call here since PILs LUT computation scheme can produce values outside the valid value + # range of uint8 images + .clamp_(0, 255) + ) + lut = cum_hist.to(torch.uint8) + lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1) + equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) + + output = torch.where(valid_equalization, equalized_image, image) + return convert_dtype_image_tensor(output, output_dtype) + + +equalize_image_pil = _FP.equalize + + +def equalize_video(video: torch.Tensor) -> torch.Tensor: + return equalize_image_tensor(video) + + +def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(equalize) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return equalize_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.equalize() + elif isinstance(inpt, PIL.Image.Image): + return equalize_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: + if image.is_floating_point(): + return 1.0 - image + elif image.dtype == torch.uint8: + return image.bitwise_not() + else: # signed integer dtypes + # We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign + return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1) + + +invert_image_pil = _FP.invert + + +def invert_video(video: torch.Tensor) -> torch.Tensor: + return invert_image_tensor(video) + + +def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(invert) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return invert_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.invert() + elif isinstance(inpt, PIL.Image.Image): + return invert_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) diff --git a/torchvision/transforms/v2/functional/_deprecated.py b/torchvision/transforms/v2/functional/_deprecated.py new file mode 100644 index 00000000000..09870216059 --- /dev/null +++ b/torchvision/transforms/v2/functional/_deprecated.py @@ -0,0 +1,39 @@ +import warnings +from typing import Any, List, Union + +import PIL.Image +import torch + +from torchvision.prototype import datapoints +from torchvision.transforms import functional as _F + + +@torch.jit.unused +def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: + call = ", num_output_channels=3" if num_output_channels == 3 else "" + replacement = "convert_color_space(..., color_space=datapoints.ColorSpace.GRAY)" + if num_output_channels == 3: + replacement = f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB)" + warnings.warn( + f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. " + f"Instead, please use `{replacement}`.", + ) + + return _F.to_grayscale(inpt, num_output_channels=num_output_channels) + + +@torch.jit.unused +def to_tensor(inpt: Any) -> torch.Tensor: + warnings.warn( + "The function `to_tensor(...)` is deprecated and will be removed in a future release. 
" + "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`." + ) + return _F.to_tensor(inpt) + + +def get_image_size(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]: + warnings.warn( + "The function `get_image_size(...)` is deprecated and will be removed in a future release. " + "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." + ) + return _F.get_image_size(inpt) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py new file mode 100644 index 00000000000..840223908ac --- /dev/null +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -0,0 +1,2102 @@ +import math +import numbers +import warnings +from typing import List, Optional, Sequence, Tuple, Union + +import PIL.Image +import torch +from torch.nn.functional import grid_sample, interpolate, pad as torch_pad + +from torchvision.prototype import datapoints +from torchvision.transforms import functional_pil as _FP +from torchvision.transforms.functional import ( + _check_antialias, + _compute_resized_output_size as __compute_resized_output_size, + _get_perspective_coeffs, + _interpolation_modes_from_int, + InterpolationMode, + pil_modes_mapping, + pil_to_tensor, + to_pil_image, +) +from torchvision.transforms.functional_tensor import _pad_symmetric + +from torchvision.utils import _log_api_usage_once + +from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil + +from ._utils import is_simple_tensor + + +def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise ValueError( + f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, " + f"but got {interpolation}." 
+ ) + return interpolation + + +def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: + return image.flip(-1) + + +horizontal_flip_image_pil = _FP.hflip + + +def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: + return horizontal_flip_image_tensor(mask) + + +def horizontal_flip_bounding_box( + bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = bounding_box.clone().reshape(-1, 4) + + if format == datapoints.BoundingBoxFormat.XYXY: + bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_() + elif format == datapoints.BoundingBoxFormat.XYWH: + bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_() + else: # format == datapoints.BoundingBoxFormat.CXCYWH: + bounding_box[:, 0].sub_(spatial_size[1]).neg_() + + return bounding_box.reshape(shape) + + +def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: + return horizontal_flip_image_tensor(video) + + +def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(horizontal_flip) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return horizontal_flip_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.horizontal_flip() + elif isinstance(inpt, PIL.Image.Image): + return horizontal_flip_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: + return image.flip(-2) + + +vertical_flip_image_pil = _FP.vflip + + +def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: + return vertical_flip_image_tensor(mask) + + +def vertical_flip_bounding_box( + bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = bounding_box.clone().reshape(-1, 4) + + if format == datapoints.BoundingBoxFormat.XYXY: + bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_() + elif format == datapoints.BoundingBoxFormat.XYWH: + bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_() + else: # format == datapoints.BoundingBoxFormat.CXCYWH: + bounding_box[:, 1].sub_(spatial_size[0]).neg_() + + return bounding_box.reshape(shape) + + +def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: + return vertical_flip_image_tensor(video) + + +def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(vertical_flip) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return vertical_flip_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.vertical_flip() + elif isinstance(inpt, PIL.Image.Image): + return vertical_flip_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are +# prevalent and well understood. Thus, we just alias them without deprecating the old names. 
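Illustration (not part of the patch): a worked example of the XYXY branch of horizontal_flip_bounding_box above, with a made-up box on an image of width 100. The x coordinates are mirrored as width - x and swapped so that x1 <= x2 still holds.

import torch

width = 100
box = torch.tensor([[10.0, 20.0, 30.0, 40.0]])        # XYXY: x1, y1, x2, y2

flipped = box.clone()
flipped[:, [2, 0]] = box[:, [0, 2]].sub(width).neg()  # same arithmetic as the kernel
print(flipped)                                        # tensor([[70., 20., 90., 40.]])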
+hflip = horizontal_flip +vflip = vertical_flip + + +def _compute_resized_output_size( + spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None +) -> List[int]: + if isinstance(size, int): + size = [size] + elif max_size is not None and len(size) != 1: + raise ValueError( + "max_size should only be passed if size specifies the length of the smaller edge, " + "i.e. size should be an int or a sequence of length 1 in torchscript mode." + ) + return __compute_resized_output_size(spatial_size, size=size, max_size=max_size) + + +def resize_image_tensor( + image: torch.Tensor, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", +) -> torch.Tensor: + interpolation = _check_interpolation(interpolation) + antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation) + assert not isinstance(antialias, str) + antialias = False if antialias is None else antialias + align_corners: Optional[bool] = None + if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC: + align_corners = False + else: + # The default of antialias should be True from 0.17, so we don't warn or + # error if other interpolation modes are used. This is documented. + antialias = False + + shape = image.shape + num_channels, old_height, old_width = shape[-3:] + new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) + + if image.numel() > 0: + image = image.reshape(-1, num_channels, old_height, old_width) + + dtype = image.dtype + need_cast = dtype not in (torch.float32, torch.float64) + if need_cast: + image = image.to(dtype=torch.float32) + + image = interpolate( + image, + size=[new_height, new_width], + mode=interpolation.value, + align_corners=align_corners, + antialias=antialias, + ) + + if need_cast: + if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8: + image = image.clamp_(min=0, max=255) + image = image.round_().to(dtype=dtype) + + return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) + + +@torch.jit.unused +def resize_image_pil( + image: PIL.Image.Image, + size: Union[Sequence[int], int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, +) -> PIL.Image.Image: + interpolation = _check_interpolation(interpolation) + size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type] + return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation]) + + +def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def resize_bounding_box( + bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None +) -> Tuple[torch.Tensor, Tuple[int, int]]: + old_height, old_width = spatial_size + new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size) + w_ratio = new_width / old_width + h_ratio = new_height / old_height + ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device) + return 
( + bounding_box.mul(ratios).to(bounding_box.dtype), + (new_height, new_width), + ) + + +def resize_video( + video: torch.Tensor, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", +) -> torch.Tensor: + return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + +def resize( + inpt: datapoints.InputTypeJIT, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(resize) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, PIL.Image.Image): + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") + return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def _affine_parse_args( + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + center: Optional[List[float]] = None, +) -> Tuple[float, List[float], List[float], Optional[List[float]]]: + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError("Shear should be either a single value or a sequence of two values") + + if not isinstance(interpolation, InterpolationMode): + raise TypeError("Argument interpolation should be a InterpolationMode") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError(f"Shear should be a sequence containing two values. 
Got {shear}") + + if center is not None: + if not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + else: + center = [float(c) for c in center] + + return angle, translate, shear, center + + +def _get_inverse_affine_matrix( + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True +) -> List[float]: + # Helper method to compute inverse matrix for affine transformation + + # Pillow requires inverse affine transformation matrix: + # Affine matrix is : M = T * C * RotateScaleShear * C^-1 + # + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RotateScaleShear is rotation with scale and shear matrix + # + # RotateScaleShear(a, s, (sx, sy)) = + # = R(a) * S(s) * SHy(sy) * SHx(sx) + # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] + # [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] + # [ 0 , 0 , 1 ] + # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: + # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] + # [0, 1 ] [-tan(s), 1] + # + # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1 + + rot = math.radians(angle) + sx = math.radians(shear[0]) + sy = math.radians(shear[1]) + + cx, cy = center + tx, ty = translate + + # Cached results + cos_sy = math.cos(sy) + tan_sx = math.tan(sx) + rot_minus_sy = rot - sy + cx_plus_tx = cx + tx + cy_plus_ty = cy + ty + + # Rotate Scale Shear (RSS) without scaling + a = math.cos(rot_minus_sy) / cos_sy + b = -(a * tan_sx + math.sin(rot)) + c = math.sin(rot_minus_sy) / cos_sy + d = math.cos(rot) - c * tan_sx + + if inverted: + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0] + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + # and then apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty + matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty + else: + matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0] + # Apply inverse of center translation: RSS * C^-1 + # and then apply translation and center : T * C * RSS * C^-1 + matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy + matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy + + return matrix + + +def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: + # Inspired of PIL implementation: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 + + # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + # Points are shifted due to affine matrix torch convention about + # the center point. 
Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5) + half_w = 0.5 * w + half_h = 0.5 * h + pts = torch.tensor( + [ + [-half_w, -half_h, 1.0], + [-half_w, half_h, 1.0], + [half_w, half_h, 1.0], + [half_w, -half_h, 1.0], + ] + ) + theta = torch.tensor(matrix, dtype=torch.float).view(2, 3) + new_pts = torch.matmul(pts, theta.T) + min_vals, max_vals = new_pts.aminmax(dim=0) + + # shift points to [0, w] and [0, h] interval to match PIL results + halfs = torch.tensor((half_w, half_h)) + min_vals.add_(halfs) + max_vals.add_(halfs) + + # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 + tol = 1e-4 + inv_tol = 1.0 / tol + cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_() + cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_() + size = cmax.sub_(cmin) + return int(size[0]), int(size[1]) # w, h + + +def _apply_grid_transform( + img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT +) -> torch.Tensor: + + # We are using context knowledge that grid should have float dtype + fp = img.dtype == grid.dtype + float_img = img if fp else img.to(grid.dtype) + + shape = float_img.shape + if shape[0] > 1: + # Apply same grid to a batch of images + grid = grid.expand(shape[0], -1, -1, -1) + + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice + if fill is not None: + mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device) + float_img = torch.cat((float_img, mask), dim=1) + + float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False) + + # Fill with required color + if fill is not None: + float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3) + mask = mask.expand_as(float_img) + fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type] + fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1) + if mode == "nearest": + bool_mask = mask < 0.5 + float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask] + else: # 'bilinear' + # The following is mathematically equivalent to: + # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill + float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img) + + img = float_img.round_().to(img.dtype) if not fp else float_img + + return img + + +def _assert_grid_transform_inputs( + image: torch.Tensor, + matrix: Optional[List[float]], + interpolation: str, + fill: datapoints.FillTypeJIT, + supported_interpolation_modes: List[str], + coeffs: Optional[List[float]] = None, +) -> None: + if matrix is not None: + if not isinstance(matrix, list): + raise TypeError("Argument matrix should be a list") + elif len(matrix) != 6: + raise ValueError("Argument matrix should have 6 float values") + + if coeffs is not None and len(coeffs) != 8: + raise ValueError("Argument coeffs should have 8 float values") + + if fill is not None: + if isinstance(fill, (tuple, list)): + length = len(fill) + num_channels = image.shape[-3] + if length > 1 and length != num_channels: + raise ValueError( + "The number of elements in 'fill' cannot broadcast to match the number of " + f"channels of the image ({length} != {num_channels})" + ) + elif not isinstance(fill, (int, float)): + raise ValueError("Argument fill should be either int, float, tuple or list") + + if interpolation not in supported_interpolation_modes: + raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with 
Tensor input") + + +def _affine_grid( + theta: torch.Tensor, + w: int, + h: int, + ow: int, + oh: int, +) -> torch.Tensor: + # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ + # AffineGridGenerator.cpp#L18 + # Difference with AffineGridGenerator is that: + # 1) we normalize grid values after applying theta + # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate + dtype = theta.dtype + device = theta.device + + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device)) + output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta) + return output_grid.view(1, oh, ow, 2) + + +def affine_image_tensor( + image: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: datapoints.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + interpolation = _check_interpolation(interpolation) + + if image.numel() == 0: + return image + + shape = image.shape + ndim = image.ndim + + if ndim > 4: + image = image.reshape((-1,) + shape[-3:]) + needs_unsquash = True + elif ndim == 3: + image = image.unsqueeze(0) + needs_unsquash = True + else: + needs_unsquash = False + + height, width = shape[-2:] + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + center_f = [0.0, 0.0] + if center is not None: + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
+ center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])] + + translate_f = [float(t) for t in translate] + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) + + _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) + + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) + grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + if needs_unsquash: + output = output.reshape(shape) + + return output + + +@torch.jit.unused +def affine_image_pil( + image: PIL.Image.Image, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: datapoints.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> PIL.Image.Image: + interpolation = _check_interpolation(interpolation) + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # it is visually better to estimate the center without 0.5 offset + # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine + if center is None: + height, width = get_spatial_size_image_pil(image) + center = [width * 0.5, height * 0.5] + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + + return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) + + +def _affine_bounding_box_with_expand( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, + expand: bool = False, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + if bounding_box.numel() == 0: + return bounding_box, spatial_size + + original_shape = bounding_box.shape + original_dtype = bounding_box.dtype + bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float() + dtype = bounding_box.dtype + device = bounding_box.device + bounding_box = ( + convert_format_bounding_box( + bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True + ) + ).reshape(-1, 4) + + angle, translate, shear, center = _affine_parse_args( + angle, translate, scale, shear, InterpolationMode.NEAREST, center + ) + + if center is None: + height, width = spatial_size + center = [width * 0.5, height * 0.5] + + affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False) + transposed_affine_matrix = ( + torch.tensor( + affine_vector, + dtype=dtype, + device=device, + ) + .reshape(2, 3) + .T + ) + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). 
+ # Tensor of points has shape (N * 4, 3), where N is the number of bboxes + # Single point structure is similar to + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1) + # 2) Now let's transform the points using affine matrix + transformed_points = torch.matmul(points, transposed_affine_matrix) + # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] + # and compute bounding box from 4 transformed points: + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + + if expand: + # Compute minimum point for transformed image frame: + # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + height, width = spatial_size + points = torch.tensor( + [ + [0.0, 0.0, 1.0], + [0.0, float(height), 1.0], + [float(width), float(height), 1.0], + [float(width), 0.0, 1.0], + ], + dtype=dtype, + device=device, + ) + new_points = torch.matmul(points, transposed_affine_matrix) + tr = torch.amin(new_points, dim=0, keepdim=True) + # Translate bounding boxes + out_bboxes.sub_(tr.repeat((1, 2))) + # Estimate meta-data for image with inverted=True and with center=[0,0] + affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear) + new_width, new_height = _compute_affine_output_size(affine_vector, width, height) + spatial_size = (new_height, new_width) + + out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size) + out_bboxes = convert_format_bounding_box( + out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True + ).reshape(original_shape) + + out_bboxes = out_bboxes.to(original_dtype) + return out_bboxes, spatial_size + + +def affine_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, +) -> torch.Tensor: + out_box, _ = _affine_bounding_box_with_expand( + bounding_box, + format=format, + spatial_size=spatial_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + expand=False, + ) + return out_box + + +def affine_mask( + mask: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + fill: datapoints.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = affine_image_tensor( + mask, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=InterpolationMode.NEAREST, + fill=fill, + center=center, + ) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def affine_video( + video: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: datapoints.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return affine_image_tensor( + video, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + 
interpolation=interpolation, + fill=fill, + center=center, + ) + + +def affine( + inpt: datapoints.InputTypeJIT, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: datapoints.FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(affine) + + # TODO: consider deprecating integers from angle and shear on the future + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return affine_image_tensor( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.affine( + angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center + ) + elif isinstance(inpt, PIL.Image.Image): + return affine_image_pil( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def rotate_image_tensor( + image: torch.Tensor, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: datapoints.FillTypeJIT = None, +) -> torch.Tensor: + interpolation = _check_interpolation(interpolation) + + shape = image.shape + num_channels, height, width = shape[-3:] + + center_f = [0.0, 0.0] + if center is not None: + if expand: + warnings.warn("The provided center argument has no effect on the result if expand is True") + else: + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])] + + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. 
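Illustration (not part of the patch): a minimal usage sketch of the rotate kernel defined here. The torchvision.transforms.v2.functional import path is the one introduced by this PR; with expand=True the output grows to contain the whole rotated image.

import torch
from torchvision.transforms.v2 import functional as F  # path added by this PR

img = torch.randint(0, 256, (3, 100, 200), dtype=torch.uint8)

same_size = F.rotate_image_tensor(img, angle=45.0, expand=False)
expanded = F.rotate_image_tensor(img, angle=45.0, expand=True)
print(same_size.shape)  # torch.Size([3, 100, 200])
print(expanded.shape)   # larger H and W, enough to fit the rotated content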
+ matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + + if image.numel() > 0: + image = image.reshape(-1, num_channels, height, width) + + _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) + + ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height) + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) + grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + new_height, new_width = output.shape[-2:] + else: + output = image + new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height) + + return output.reshape(shape[:-3] + (num_channels, new_height, new_width)) + + +@torch.jit.unused +def rotate_image_pil( + image: PIL.Image.Image, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: datapoints.FillTypeJIT = None, +) -> PIL.Image.Image: + interpolation = _check_interpolation(interpolation) + + if center is not None and expand: + warnings.warn("The provided center argument has no effect on the result if expand is True") + center = None + + return _FP.rotate( + image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center + ) + + +def rotate_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + if center is not None and expand: + warnings.warn("The provided center argument has no effect on the result if expand is True") + center = None + + return _affine_bounding_box_with_expand( + bounding_box, + format=format, + spatial_size=spatial_size, + angle=-angle, + translate=[0.0, 0.0], + scale=1.0, + shear=[0.0, 0.0], + center=center, + expand=expand, + ) + + +def rotate_mask( + mask: torch.Tensor, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, + fill: datapoints.FillTypeJIT = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = rotate_image_tensor( + mask, + angle=angle, + expand=expand, + interpolation=InterpolationMode.NEAREST, + fill=fill, + center=center, + ) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def rotate_video( + video: torch.Tensor, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: datapoints.FillTypeJIT = None, +) -> torch.Tensor: + return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +def rotate( + inpt: datapoints.InputTypeJIT, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: datapoints.FillTypeJIT = None, +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(rotate) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + 
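Illustration (not part of the patch): a sketch of the dispatch pattern used by rotate and the other public functionals. The datapoints import below mirrors the `from torchvision.prototype import datapoints` used by these modules at this point in the PR; the same call handles plain tensors, datapoint subclasses, and PIL images.

import torch
import PIL.Image
from torchvision.prototype import datapoints             # namespace these kernels import from here
from torchvision.transforms.v2 import functional as F    # path added by this PR

plain = torch.rand(3, 64, 64)
image = datapoints.Image(plain)
pil = PIL.Image.new("RGB", (64, 64))

print(type(F.rotate(plain, angle=30.0)))  # plain tensor in, the tensor kernel is called directly
print(type(F.rotate(image, angle=30.0)))  # datapoints.Image in, dispatched to Image.rotate
print(type(F.rotate(pil, angle=30.0)))    # PIL.Image.Image in, dispatched to the _pil kernel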
elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    elif isinstance(inpt, PIL.Image.Image):
+        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
+    if isinstance(padding, int):
+        pad_left = pad_right = pad_top = pad_bottom = padding
+    elif isinstance(padding, (tuple, list)):
+        if len(padding) == 1:
+            pad_left = pad_right = pad_top = pad_bottom = padding[0]
+        elif len(padding) == 2:
+            pad_left = pad_right = padding[0]
+            pad_top = pad_bottom = padding[1]
+        elif len(padding) == 4:
+            pad_left = padding[0]
+            pad_top = padding[1]
+            pad_right = padding[2]
+            pad_bottom = padding[3]
+        else:
+            raise ValueError(
+                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
+            )
+    else:
+        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
+
+    return [pad_left, pad_right, pad_top, pad_bottom]
+
+
+def pad_image_tensor(
+    image: torch.Tensor,
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
+    padding_mode: str = "constant",
+) -> torch.Tensor:
+    # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
+    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
+    # `torch_pad` internally.
+    torch_padding = _parse_pad_padding(padding)
+
+    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
+        raise ValueError(
+            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
+            f"but got `'{padding_mode}'`."
+        )
+
+    if fill is None:
+        fill = 0
+
+    if isinstance(fill, (int, float)):
+        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
+    elif len(fill) == 1:
+        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
+    else:
+        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
+
+
+def _pad_with_scalar_fill(
+    image: torch.Tensor,
+    torch_padding: List[int],
+    fill: Union[int, float],
+    padding_mode: str,
+) -> torch.Tensor:
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
+
+    batch_size = 1
+    for s in shape[:-3]:
+        batch_size *= s
+
+    image = image.reshape(batch_size, num_channels, height, width)
+
+    if padding_mode == "edge":
+        # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
+        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
+        # name.
+        padding_mode = "replicate"
+
+    if padding_mode == "constant":
+        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
+    elif padding_mode in ("reflect", "replicate"):
+        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
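+        # Integer inputs are therefore temporarily cast to float32 below and cast back to their original dtype afterwards.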
+ # TODO: See https://github.com/pytorch/pytorch/issues/40763 + dtype = image.dtype + if not image.is_floating_point(): + needs_cast = True + image = image.to(torch.float32) + else: + needs_cast = False + + image = torch_pad(image, torch_padding, mode=padding_mode) + + if needs_cast: + image = image.to(dtype) + else: # padding_mode == "symmetric" + image = _pad_symmetric(image, torch_padding) + + new_height, new_width = image.shape[-2:] + + return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) + + +# TODO: This should be removed once torch_pad supports non-scalar padding values +def _pad_with_vector_fill( + image: torch.Tensor, + torch_padding: List[int], + fill: List[float], + padding_mode: str, +) -> torch.Tensor: + if padding_mode != "constant": + raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") + + output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + left, right, top, bottom = torch_padding + fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1) + + if top > 0: + output[..., :top, :] = fill + if left > 0: + output[..., :, :left] = fill + if bottom > 0: + output[..., -bottom:, :] = fill + if right > 0: + output[..., :, -right:] = fill + return output + + +pad_image_pil = _FP.pad + + +def pad_mask( + mask: torch.Tensor, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + if fill is None: + fill = 0 + + if isinstance(fill, (tuple, list)): + raise ValueError("Non-scalar fill value is not supported") + + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def pad_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + padding: List[int], + padding_mode: str = "constant", +) -> Tuple[torch.Tensor, Tuple[int, int]]: + if padding_mode not in ["constant"]: + # TODO: add support of other padding modes + raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") + + left, right, top, bottom = _parse_pad_padding(padding) + + if format == datapoints.BoundingBoxFormat.XYXY: + pad = [left, top, left, top] + else: + pad = [left, top, 0, 0] + bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device) + + height, width = spatial_size + height += top + bottom + width += left + right + spatial_size = (height, width) + + return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size + + +def pad_video( + video: torch.Tensor, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode) + + +def pad( + inpt: datapoints.InputTypeJIT, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(pad) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.pad(padding, 
fill=fill, padding_mode=padding_mode) + elif isinstance(inpt, PIL.Image.Image): + return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + h, w = image.shape[-2:] + + right = left + width + bottom = top + height + + if left < 0 or top < 0 or right > w or bottom > h: + image = image[..., max(top, 0) : bottom, max(left, 0) : right] + torch_padding = [ + max(min(right, 0) - left, 0), + max(right - max(w, left), 0), + max(min(bottom, 0) - top, 0), + max(bottom - max(h, top), 0), + ] + return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + return image[..., top:bottom, left:right] + + +crop_image_pil = _FP.crop + + +def crop_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + top: int, + left: int, + height: int, + width: int, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + + # Crop or implicit pad if left and/or top have negative values: + if format == datapoints.BoundingBoxFormat.XYXY: + sub = [left, top, left, top] + else: + sub = [left, top, 0, 0] + + bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device) + spatial_size = (height, width) + + return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size + + +def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = crop_image_tensor(mask, top, left, height, width) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + return crop_image_tensor(video, top, left, height, width) + + +def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(crop) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return crop_image_tensor(inpt, top, left, height, width) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.crop(top, left, height, width) + elif isinstance(inpt, PIL.Image.Image): + return crop_image_pil(inpt, top, left, height, width) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)) + shape = (1, oh * ow, 3) + output_grid1 = base_grid.view(shape).bmm(rescaled_theta1) + output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1.div_(output_grid2).sub_(1.0) + return output_grid.view(1, oh, ow, 2) + + +def _perspective_coefficients( + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + coefficients: Optional[List[float]], +) -> List[float]: + if coefficients is not None: + if startpoints is not None and endpoints is not None: + raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.") + elif len(coefficients) != 8: + raise ValueError("Argument coefficients should have 8 float values") + return coefficients + elif startpoints is not None and endpoints is not None: + return _get_perspective_coeffs(startpoints, endpoints) + else: + raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.") + + +def perspective_image_tensor( + image: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + interpolation = _check_interpolation(interpolation) + + if image.numel() == 0: + return image + + shape = image.shape + ndim = image.ndim + + if ndim > 4: + image = image.reshape((-1,) + shape[-3:]) + needs_unsquash = True + elif ndim == 3: + image = image.unsqueeze(0) + needs_unsquash = True + else: + needs_unsquash = False + + _assert_grid_transform_inputs( + image, + matrix=None, + interpolation=interpolation.value, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], + coeffs=perspective_coeffs, + ) + + oh, ow = shape[-2:] + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + if needs_unsquash: + output = output.reshape(shape) + + return output + + +@torch.jit.unused +def perspective_image_pil( + image: PIL.Image.Image, + startpoints: 
Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC, + fill: datapoints.FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> PIL.Image.Image: + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + interpolation = _check_interpolation(interpolation) + return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) + + +def perspective_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + if bounding_box.numel() == 0: + return bounding_box + + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + + original_shape = bounding_box.shape + # TODO: first cast to float if bbox is int64 before convert_format_bounding_box + bounding_box = ( + convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) + ).reshape(-1, 4) + + dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + device = bounding_box.device + + # perspective_coeffs are computed as endpoint -> start point + # We have to invert perspective_coeffs for bboxes: + # (x, y) - end point and (x_out, y_out) - start point + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # and we would like to get: + # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2]) + # / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1) + # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5]) + # / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1) + # and compute inv_coeffs in terms of coeffs + + denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3] + if denom == 0: + raise RuntimeError( + f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. " + f"Denominator is zero, denom={denom}" + ) + + inv_coeffs = [ + (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom, + (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom, + (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom, + (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom, + (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, + ] + + theta1 = torch.tensor( + [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], + dtype=dtype, + device=device, + ) + + theta2 = torch.tensor( + [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device + ) + + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). 
+ # Tensor of points has shape (N * 4, 3), where N is the number of bboxes + # Single point structure is similar to + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) + # 2) Now let's transform the points using perspective matrices + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + + numer_points = torch.matmul(points, theta1.T) + denom_points = torch.matmul(points, theta2.T) + transformed_points = numer_points.div_(denom_points) + + # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] + # and compute bounding box from 4 transformed points: + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + + out_bboxes = clamp_bounding_box( + torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype), + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + ) + + # out_bboxes should be of shape [N boxes, 4] + + return convert_format_bounding_box( + out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True + ).reshape(original_shape) + + +def perspective_mask( + mask: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + fill: datapoints.FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = perspective_image_tensor( + mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients + ) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def perspective_video( + video: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + return perspective_image_tensor( + video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) + + +def perspective( + inpt: datapoints.InputTypeJIT, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(perspective) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return perspective_image_tensor( + inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.perspective( + startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) + elif isinstance(inpt, PIL.Image.Image): + return perspective_image_pil( + inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +def elastic_image_tensor( + image: torch.Tensor, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, +) -> torch.Tensor: + interpolation = _check_interpolation(interpolation) + + if image.numel() == 0: + return image + + shape = image.shape + ndim = image.ndim + + device = image.device + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + + # Patch: elastic transform should support (cpu,f16) input + is_cpu_half = device.type == "cpu" and dtype == torch.float16 + if is_cpu_half: + image = image.to(torch.float32) + dtype = torch.float32 + + # We are aware that if input image dtype is uint8 and displacement is float64 then + # displacement will be casted to float32 and all computations will be done with float32 + # We can fix this later if needed + + expected_shape = (1,) + shape[-2:] + (2,) + if expected_shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + + if ndim > 4: + image = image.reshape((-1,) + shape[-3:]) + needs_unsquash = True + elif ndim == 3: + image = image.unsqueeze(0) + needs_unsquash = True + else: + needs_unsquash = False + + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) + + image_height, image_width = shape[-2:] + grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + if needs_unsquash: + output = output.reshape(shape) + + if is_cpu_half: + output = output.to(torch.float16) + + return output + + +@torch.jit.unused +def elastic_image_pil( + image: PIL.Image.Image, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, +) -> PIL.Image.Image: + t_img = pil_to_tensor(image) + output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) + return to_pil_image(output, mode=image.mode) + + +def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor: + sy, sx = size + base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype) + x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype) + base_grid[..., 0].copy_(x_grid) + + y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + + return base_grid + + +def elastic_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + displacement: torch.Tensor, +) -> torch.Tensor: + if bounding_box.numel() == 0: + return bounding_box + + # TODO: add in docstring about approximation we are doing for grid inversion + device = bounding_box.device + dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) + + original_shape = bounding_box.shape + # TODO: first cast to float if bbox is int64 before convert_format_bounding_box + bounding_box = ( + convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) + ).reshape(-1, 4) + + id_grid = _create_identity_grid(spatial_size, device=device, 
dtype=dtype) + # We construct an approximation of inverse grid as inv_grid = id_grid - displacement + # This is not an exact inverse of the grid + inv_grid = id_grid.sub_(displacement) + + # Get points from bboxes + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + if points.is_floating_point(): + points = points.ceil_() + index_xy = points.to(dtype=torch.long) + index_x, index_y = index_xy[:, 0], index_xy[:, 1] + + # Transform points: + t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) + + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = clamp_bounding_box( + torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype), + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + ) + + return convert_format_bounding_box( + out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True + ).reshape(original_shape) + + +def elastic_mask( + mask: torch.Tensor, + displacement: torch.Tensor, + fill: datapoints.FillTypeJIT = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def elastic_video( + video: torch.Tensor, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, +) -> torch.Tensor: + return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) + + +def elastic( + inpt: datapoints.InputTypeJIT, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints.FillTypeJIT = None, +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(elastic) + + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.elastic(displacement, interpolation=interpolation, fill=fill) + elif isinstance(inpt, PIL.Image.Image): + return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +elastic_transform = elastic + + +def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: + if isinstance(output_size, numbers.Number): + s = int(output_size) + return [s, s] + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + return [output_size[0], output_size[0]] + else: + return list(output_size) + + +def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]: + return [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + + +def _center_crop_compute_crop_anchor( + crop_height: int, crop_width: int, image_height: int, image_width: int +) -> Tuple[int, int]: + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop_top, crop_left + + +def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + shape = image.shape + if image.numel() == 0: + return image.reshape(shape[:-2] + (crop_height, crop_width)) + image_height, image_width = shape[-2:] + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0) + + image_height, image_width = image.shape[-2:] + if crop_width == image_width and crop_height == image_height: + return image + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)] + + +@torch.jit.unused +def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + image_height, image_width = get_spatial_size_image_pil(image) + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + image = pad_image_pil(image, padding_ltrb, fill=0) + + image_height, image_width = get_spatial_size_image_pil(image) + if crop_width == image_width and crop_height == image_height: + return image + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) + + +def center_crop_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + spatial_size: Tuple[int, int], + output_size: List[int], +) -> Tuple[torch.Tensor, Tuple[int, int]]: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size) + return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width) + + +def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = center_crop_image_tensor(image=mask, output_size=output_size) + + if needs_squeeze: + output = 
output.squeeze(0) + + return output + + +def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor: + return center_crop_image_tensor(video, output_size) + + +def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(center_crop) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return center_crop_image_tensor(inpt, output_size) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.center_crop(output_size) + elif isinstance(inpt, PIL.Image.Image): + return center_crop_image_pil(inpt, output_size) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def resized_crop_image_tensor( + image: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", +) -> torch.Tensor: + image = crop_image_tensor(image, top, left, height, width) + return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias) + + +@torch.jit.unused +def resized_crop_image_pil( + image: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, +) -> PIL.Image.Image: + image = crop_image_pil(image, top, left, height, width) + return resize_image_pil(image, size, interpolation=interpolation) + + +def resized_crop_bounding_box( + bounding_box: torch.Tensor, + format: datapoints.BoundingBoxFormat, + top: int, + left: int, + height: int, + width: int, + size: List[int], +) -> Tuple[torch.Tensor, Tuple[int, int]]: + bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width) + return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size) + + +def resized_crop_mask( + mask: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], +) -> torch.Tensor: + mask = crop_mask(mask, top, left, height, width) + return resize_mask(mask, size) + + +def resized_crop_video( + video: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", +) -> torch.Tensor: + return resized_crop_image_tensor( + video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) + + +def resized_crop( + inpt: datapoints.InputTypeJIT, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(resized_crop) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return resized_crop_image_tensor( + inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) + elif isinstance(inpt, PIL.Image.Image): + return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation) + else: + raise TypeError( + f"Input can 
either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def _parse_five_crop_size(size: List[int]) -> List[int]: + if isinstance(size, numbers.Number): + s = int(size) + size = [s, s] + elif isinstance(size, (tuple, list)) and len(size) == 1: + s = size[0] + size = [s, s] + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + return size + + +def five_crop_image_tensor( + image: torch.Tensor, size: List[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + crop_height, crop_width = _parse_five_crop_size(size) + image_height, image_width = image.shape[-2:] + + if crop_width > image_width or crop_height > image_height: + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") + + tl = crop_image_tensor(image, 0, 0, crop_height, crop_width) + tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width) + bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width) + br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = center_crop_image_tensor(image, [crop_height, crop_width]) + + return tl, tr, bl, br, center + + +@torch.jit.unused +def five_crop_image_pil( + image: PIL.Image.Image, size: List[int] +) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: + crop_height, crop_width = _parse_five_crop_size(size) + image_height, image_width = get_spatial_size_image_pil(image) + + if crop_width > image_width or crop_height > image_height: + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") + + tl = crop_image_pil(image, 0, 0, crop_height, crop_width) + tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width) + bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width) + br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = center_crop_image_pil(image, [crop_height, crop_width]) + + return tl, tr, bl, br, center + + +def five_crop_video( + video: torch.Tensor, size: List[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return five_crop_image_tensor(video, size) + + +ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] + + +def five_crop( + inpt: ImageOrVideoTypeJIT, size: List[int] +) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: + if not torch.jit.is_scripting(): + _log_api_usage_once(five_crop) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return five_crop_image_tensor(inpt, size) + elif isinstance(inpt, datapoints.Image): + output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size) + return tuple(datapoints.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value] + elif isinstance(inpt, datapoints.Video): + output = five_crop_video(inpt.as_subclass(torch.Tensor), size) + return tuple(datapoints.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value] + elif isinstance(inpt, PIL.Image.Image): + return five_crop_image_pil(inpt, size) + else: + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +def ten_crop_image_tensor( + image: torch.Tensor, size: List[int], vertical_flip: bool = False +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + non_flipped = five_crop_image_tensor(image, size) + + if vertical_flip: + image = vertical_flip_image_tensor(image) + else: + image = horizontal_flip_image_tensor(image) + + flipped = five_crop_image_tensor(image, size) + + return non_flipped + flipped + + +@torch.jit.unused +def ten_crop_image_pil( + image: PIL.Image.Image, size: List[int], vertical_flip: bool = False +) -> Tuple[ + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, +]: + non_flipped = five_crop_image_pil(image, size) + + if vertical_flip: + image = vertical_flip_image_pil(image) + else: + image = horizontal_flip_image_pil(image) + + flipped = five_crop_image_pil(image, size) + + return non_flipped + flipped + + +def ten_crop_video( + video: torch.Tensor, size: List[int], vertical_flip: bool = False +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip) + + +def ten_crop( + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False +) -> Tuple[ + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, +]: + if not torch.jit.is_scripting(): + _log_api_usage_once(ten_crop) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) + elif isinstance(inpt, datapoints.Image): + output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip) + return [datapoints.Image.wrap_like(inpt, item) for item in output] + elif isinstance(inpt, datapoints.Video): + output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip) + return [datapoints.Video.wrap_like(inpt, item) for item in output] + elif isinstance(inpt, PIL.Image.Image): + return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip) + else: + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py new file mode 100644 index 00000000000..5e32516fb8a --- /dev/null +++ b/torchvision/transforms/v2/functional/_meta.py @@ -0,0 +1,374 @@ +from typing import List, Optional, Tuple, Union + +import PIL.Image +import torch +from torchvision.prototype import datapoints +from torchvision.prototype.datapoints import BoundingBoxFormat +from torchvision.transforms import functional_pil as _FP +from torchvision.transforms.functional_tensor import _max_value + +from torchvision.utils import _log_api_usage_once + +from ._utils import is_simple_tensor + + +def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: + chw = list(image.shape[-3:]) + ndims = len(chw) + if ndims == 3: + return chw + elif ndims == 2: + chw.insert(0, 1) + return chw + else: + raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") + + +get_dimensions_image_pil = _FP.get_dimensions + + +def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]: + if not torch.jit.is_scripting(): + _log_api_usage_once(get_dimensions) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return get_dimensions_image_tensor(inpt) + elif isinstance(inpt, (datapoints.Image, datapoints.Video)): + channels = inpt.num_channels + height, width = inpt.spatial_size + return [channels, height, width] + elif isinstance(inpt, PIL.Image.Image): + return get_dimensions_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def get_num_channels_image_tensor(image: torch.Tensor) -> int: + chw = image.shape[-3:] + ndims = len(chw) + if ndims == 3: + return chw[0] + elif ndims == 2: + return 1 + else: + raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") + + +get_num_channels_image_pil = _FP.get_image_num_channels + + +def get_num_channels_video(video: torch.Tensor) -> int: + return get_num_channels_image_tensor(video) + + +def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int: + if not torch.jit.is_scripting(): + _log_api_usage_once(get_num_channels) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return get_num_channels_image_tensor(inpt) + elif isinstance(inpt, (datapoints.Image, datapoints.Video)): + return inpt.num_channels + elif isinstance(inpt, PIL.Image.Image): + return get_num_channels_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without +# deprecating the old names. 
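+# Both names dispatch to the same implementation, e.g. `get_image_num_channels(video)` equals `video.shape[-3]`
+# for tensor inputs with at least three dimensions.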
+get_image_num_channels = get_num_channels + + +def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]: + hw = list(image.shape[-2:]) + ndims = len(hw) + if ndims == 2: + return hw + else: + raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") + + +@torch.jit.unused +def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]: + width, height = _FP.get_image_size(image) + return [height, width] + + +def get_spatial_size_video(video: torch.Tensor) -> List[int]: + return get_spatial_size_image_tensor(video) + + +def get_spatial_size_mask(mask: torch.Tensor) -> List[int]: + return get_spatial_size_image_tensor(mask) + + +@torch.jit.unused +def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]: + return list(bounding_box.spatial_size) + + +def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]: + if not torch.jit.is_scripting(): + _log_api_usage_once(get_spatial_size) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return get_spatial_size_image_tensor(inpt) + elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)): + return list(inpt.spatial_size) + elif isinstance(inpt, PIL.Image.Image): + return get_spatial_size_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +def get_num_frames_video(video: torch.Tensor) -> int: + return video.shape[-4] + + +def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int: + if not torch.jit.is_scripting(): + _log_api_usage_once(get_num_frames) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return get_num_frames_video(inpt) + elif isinstance(inpt, datapoints.Video): + return inpt.num_frames + else: + raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") + + +def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor: + xyxy = xywh if inplace else xywh.clone() + xyxy[..., 2:] += xyxy[..., :2] + return xyxy + + +def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: + xywh = xyxy if inplace else xyxy.clone() + xywh[..., 2:] -= xywh[..., :2] + return xywh + + +def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor: + if not inplace: + cxcywh = cxcywh.clone() + + # Trick to do fast division by 2 and ceil, without casting. It produces the same result as + # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`. 
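+    # For an integer width/height w, `div(-2, rounding_mode="floor").abs_()` yields ceil(w / 2), e.g. 5 -> floor(-2.5) = -3 -> 3;
+    # for floating point boxes the plain division by -2 followed by abs_() gives exactly w / 2.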
+ half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_() + # (cx - width / 2) = x1, same for y1 + cxcywh[..., :2].sub_(half_wh) + # (x1 + width) = x2, same for y2 + cxcywh[..., 2:].add_(cxcywh[..., :2]) + + return cxcywh + + +def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: + if not inplace: + xyxy = xyxy.clone() + + # (x2 - x1) = width, same for height + xyxy[..., 2:].sub_(xyxy[..., :2]) + # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy + xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor") + + return xyxy + + +def _convert_format_bounding_box( + bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False +) -> torch.Tensor: + + if new_format == old_format: + return bounding_box + + # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance + if old_format == BoundingBoxFormat.XYWH: + bounding_box = _xywh_to_xyxy(bounding_box, inplace) + elif old_format == BoundingBoxFormat.CXCYWH: + bounding_box = _cxcywh_to_xyxy(bounding_box, inplace) + + if new_format == BoundingBoxFormat.XYWH: + bounding_box = _xyxy_to_xywh(bounding_box, inplace) + elif new_format == BoundingBoxFormat.CXCYWH: + bounding_box = _xyxy_to_cxcywh(bounding_box, inplace) + + return bounding_box + + +def convert_format_bounding_box( + inpt: datapoints.InputTypeJIT, + old_format: Optional[BoundingBoxFormat] = None, + new_format: Optional[BoundingBoxFormat] = None, + inplace: bool = False, +) -> datapoints.InputTypeJIT: + # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor + # inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on + # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the + # default error that would be thrown if `new_format` had no default value. + if new_format is None: + raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'") + + if not torch.jit.is_scripting(): + _log_api_usage_once(convert_format_bounding_box) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + if old_format is None: + raise ValueError("For simple tensor inputs, `old_format` has to be passed.") + return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace) + elif isinstance(inpt, datapoints.BoundingBox): + if old_format is not None: + raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.") + output = _convert_format_bounding_box( + inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace + ) + return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format) + else: + raise TypeError( + f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." 
+ ) + + +def _clamp_bounding_box( + bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int] +) -> torch.Tensor: + # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every + # BoundingBoxFormat instead of converting back and forth + in_dtype = bounding_box.dtype + bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float() + xyxy_boxes = convert_format_bounding_box( + bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True + ) + xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1]) + xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0]) + out_boxes = convert_format_bounding_box( + xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True + ) + return out_boxes.to(in_dtype) + + +def clamp_bounding_box( + inpt: datapoints.InputTypeJIT, + format: Optional[BoundingBoxFormat] = None, + spatial_size: Optional[Tuple[int, int]] = None, +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(clamp_bounding_box) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + if format is None or spatial_size is None: + raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.") + return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size) + elif isinstance(inpt, datapoints.BoundingBox): + if format is not None or spatial_size is not None: + raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.") + output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size) + return datapoints.BoundingBox.wrap_like(inpt, output) + else: + raise TypeError( + f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." + ) + + +def _num_value_bits(dtype: torch.dtype) -> int: + if dtype == torch.uint8: + return 8 + elif dtype == torch.int8: + return 7 + elif dtype == torch.int16: + return 15 + elif dtype == torch.int32: + return 31 + elif dtype == torch.int64: + return 63 + else: + raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") + + +def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + if image.dtype == dtype: + return image + + float_input = image.is_floating_point() + if torch.jit.is_scripting(): + # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT + float_output = torch.tensor(0, dtype=dtype).is_floating_point() + else: + float_output = dtype.is_floating_point + + if float_input: + # float to float + if float_output: + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.") + + # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting + # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only + # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # for a detailed analysis. + # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation. 
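+        # (A `round_()` would map e.g. 0.999 * 255 = 254.745 up to 255 instead of truncating to 254, but it adds one
+        # more elementwise pass over the image.)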
+ # Instead, we can also multiply by the maximum value plus something close to `1`. See + # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details. + eps = 1e-3 + max_value = float(_max_value(dtype)) + # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the + # discrete set `{0, 1}`. + return image.mul(max_value + 1.0 - eps).to(dtype) + else: + # int to float + if float_output: + return image.to(dtype).mul_(1.0 / _max_value(image.dtype)) + + # int to int + num_value_bits_input = _num_value_bits(image.dtype) + num_value_bits_output = _num_value_bits(dtype) + + if num_value_bits_input > num_value_bits_output: + return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) + else: + return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) + + +# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. +convert_image_dtype = convert_dtype_image_tensor + + +def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + return convert_dtype_image_tensor(video, dtype) + + +def convert_dtype( + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float +) -> torch.Tensor: + if not torch.jit.is_scripting(): + _log_api_usage_once(convert_dtype) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return convert_dtype_image_tensor(inpt, dtype) + elif isinstance(inpt, datapoints.Image): + output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype) + return datapoints.Image.wrap_like(inpt, output) + elif isinstance(inpt, datapoints.Video): + output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype) + return datapoints.Video.wrap_like(inpt, output) + else: + raise TypeError( + f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." + ) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py new file mode 100644 index 00000000000..9d0a00f88c3 --- /dev/null +++ b/torchvision/transforms/v2/functional/_misc.py @@ -0,0 +1,184 @@ +import math +from typing import List, Optional, Union + +import PIL.Image +import torch +from torch.nn.functional import conv2d, pad as torch_pad + +from torchvision.prototype import datapoints +from torchvision.transforms.functional import pil_to_tensor, to_pil_image + +from torchvision.utils import _log_api_usage_once + +from ._utils import is_simple_tensor + + +def normalize_image_tensor( + image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False +) -> torch.Tensor: + if not image.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.") + + if image.ndim < 3: + raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). 
Got {image.shape}.") + + if isinstance(std, (tuple, list)): + divzero = not all(std) + elif isinstance(std, (int, float)): + divzero = std == 0 + else: + divzero = False + if divzero: + raise ValueError("std evaluated to zero, leading to division by zero.") + + dtype = image.dtype + device = image.device + mean = torch.as_tensor(mean, dtype=dtype, device=device) + std = torch.as_tensor(std, dtype=dtype, device=device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + + if inplace: + image = image.sub_(mean) + else: + image = image.sub(mean) + + return image.div_(std) + + +def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: + return normalize_image_tensor(video, mean, std, inplace=inplace) + + +def normalize( + inpt: Union[datapoints.TensorImageTypeJIT, datapoints.TensorVideoTypeJIT], + mean: List[float], + std: List[float], + inplace: bool = False, +) -> torch.Tensor: + if not torch.jit.is_scripting(): + _log_api_usage_once(normalize) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + elif isinstance(inpt, (datapoints.Image, datapoints.Video)): + return inpt.normalize(mean=mean, std=std, inplace=inplace) + else: + raise TypeError( + f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." + ) + + +def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma) + x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) + kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0) + return kernel1d + + +def _get_gaussian_kernel2d( + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device) + kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device) + kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x + return kernel2d + + +def gaussian_blur_image_tensor( + image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + # TODO: consider deprecating integers from sigma on the future + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + elif len(kernel_size) != 2: + raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") + for ksize in kernel_size: + if ksize % 2 == 0 or ksize < 0: + raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}") + + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + else: + if isinstance(sigma, (list, tuple)): + length = len(sigma) + if length == 1: + s = float(sigma[0]) + sigma = [s, s] + elif length != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}") + elif isinstance(sigma, (int, float)): + s = float(sigma) + sigma = [s, s] + else: + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") + for s in sigma: + if s <= 0.0: + raise ValueError(f"sigma should have positive values. 
Got {sigma}") + + if image.numel() == 0: + return image + + dtype = image.dtype + shape = image.shape + ndim = image.ndim + if ndim == 3: + image = image.unsqueeze(dim=0) + elif ndim > 4: + image = image.reshape((-1,) + shape[-3:]) + + fp = torch.is_floating_point(image) + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device) + kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + output = image if fp else image.to(dtype=torch.float32) + + # padding = (left, right, top, bottom) + padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] + output = torch_pad(output, padding, mode="reflect") + output = conv2d(output, kernel, groups=shape[-3]) + + if ndim == 3: + output = output.squeeze(dim=0) + elif ndim > 4: + output = output.reshape(shape) + + if not fp: + output = output.round_().to(dtype=dtype) + + return output + + +@torch.jit.unused +def gaussian_blur_image_pil( + image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> PIL.Image.Image: + t_img = pil_to_tensor(image) + output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma) + return to_pil_image(output, mode=image.mode) + + +def gaussian_blur_video( + video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + return gaussian_blur_image_tensor(video, kernel_size, sigma) + + +def gaussian_blur( + inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> datapoints.InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(gaussian_blur) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) + elif isinstance(inpt, PIL.Image.Image): + return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py new file mode 100644 index 00000000000..d39a64534ca --- /dev/null +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -0,0 +1,33 @@ +import torch + +from torchvision.prototype import datapoints + +from torchvision.utils import _log_api_usage_once + +from ._utils import is_simple_tensor + + +def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor: + # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 + t_max = video.shape[temporal_dim] - 1 + indices = torch.linspace(0, t_max, num_samples, device=video.device).long() + return torch.index_select(video, temporal_dim, indices) + + +def uniform_temporal_subsample( + inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4 +) -> datapoints.VideoTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(uniform_temporal_subsample) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim) + elif isinstance(inpt, datapoints.Video): + if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim: + raise ValueError("Video inputs must have temporal_dim equivalent to -4") + output = uniform_temporal_subsample_video( + inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim + ) + return datapoints.Video.wrap_like(inpt, output) + else: + raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") diff --git a/torchvision/transforms/v2/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py new file mode 100644 index 00000000000..286aa7485da --- /dev/null +++ b/torchvision/transforms/v2/functional/_type_conversion.py @@ -0,0 +1,28 @@ +from typing import Union + +import numpy as np +import PIL.Image +import torch +from torchvision.prototype import datapoints +from torchvision.transforms import functional as _F + + +@torch.jit.unused +def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image: + if isinstance(inpt, np.ndarray): + output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous() + elif isinstance(inpt, PIL.Image.Image): + output = pil_to_tensor(inpt) + elif isinstance(inpt, torch.Tensor): + output = inpt + else: + raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.") + return datapoints.Image(output) + + +to_image_pil = _F.to_pil_image +pil_to_tensor = _F.pil_to_tensor + +# We changed the names to align them with the new naming scheme. Still, `to_pil_image` is +# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
+to_pil_image = to_image_pil diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py new file mode 100644 index 00000000000..e4efeb6016f --- /dev/null +++ b/torchvision/transforms/v2/functional/_utils.py @@ -0,0 +1,8 @@ +from typing import Any + +import torch +from torchvision.prototype.datapoints._datapoint import Datapoint + + +def is_simple_tensor(inpt: Any) -> bool: + return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) diff --git a/torchvision/transforms/v2/utils.py b/torchvision/transforms/v2/utils.py new file mode 100644 index 00000000000..ff7fff50ced --- /dev/null +++ b/torchvision/transforms/v2/utils.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Callable, List, Tuple, Type, Union + +import PIL.Image + +from torchvision._utils import sequence_to_str +from torchvision.prototype import datapoints +from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor + + +def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox: + bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)] + if not bounding_boxes: + raise TypeError("No bounding box was found in the sample") + elif len(bounding_boxes) > 1: + raise ValueError("Found multiple bounding boxes in the sample") + return bounding_boxes.pop() + + +def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: + chws = { + tuple(get_dimensions(inpt)) + for inpt in flat_inputs + if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt) + } + if not chws: + raise TypeError("No image or video was found in the sample") + elif len(chws) > 1: + raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") + c, h, w = chws.pop() + return c, h, w + + +def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]: + sizes = { + tuple(get_spatial_size(inpt)) + for inpt in flat_inputs + if isinstance( + inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox) + ) + or is_simple_tensor(inpt) + } + if not sizes: + raise TypeError("No image, video, mask or bounding box was found in the sample") + elif len(sizes) > 1: + raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}") + h, w = sizes.pop() + return h, w + + +def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool: + for type_or_check in types_or_checks: + if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): + return True + return False + + +def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: + for inpt in flat_inputs: + if check_type(inpt, types_or_checks): + return True + return False + + +def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: + for type_or_check in types_or_checks: + for inpt in flat_inputs: + if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt): + break + else: + return False + return True
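
A minimal usage sketch for the dtype conversion helpers added above. The import path is an assumption: it presumes `convert_dtype` is re-exported from `torchvision.transforms.v2.functional`, which this diff does not show.

    import torch
    from torchvision.transforms.v2 import functional as F  # assumed re-export location

    # float -> uint8: values in [0.0, 1.0] are spread over the full [0, 255] range.
    img_float = torch.rand(3, 32, 32)
    img_uint8 = F.convert_dtype(img_float, torch.uint8)
    assert img_uint8.dtype == torch.uint8

    # uint8 -> float32: values are rescaled back into [0.0, 1.0].
    img_back = F.convert_dtype(img_uint8, torch.float32)
    assert 0.0 <= float(img_back.min()) and float(img_back.max()) <= 1.0

    # Why the eps in the float -> int branch? Scaling by exactly 256 would map 1.0 to 256,
    # which does not fit into uint8; scaling by 255 would assign the top bucket only to
    # inputs exactly equal to 1.0. Scaling by 255 + 1 - 1e-3 spreads [0.0, 1.0] evenly over
    # all 256 buckets while still mapping 1.0 to 255.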
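
A sketch of how `normalize` dispatches on the input type: plain tensors go through `normalize_image_tensor`, while `Image`/`Video` datapoints are assumed to expose a `.normalize` method, as the dispatcher above relies on. Import paths mirror the ones used in this diff and are otherwise assumptions.

    import torch
    from torchvision.prototype import datapoints  # import path mirrors this diff
    from torchvision.transforms.v2 import functional as F  # assumed re-export location

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    # Plain float tensor: handled by normalize_image_tensor.
    out_plain = F.normalize(torch.rand(3, 224, 224), mean=mean, std=std)

    # Image datapoint: handled by the datapoint's own normalize method.
    image = datapoints.Image(torch.rand(3, 224, 224))
    out_image = F.normalize(image, mean=mean, std=std)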
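
The 1-D kernel built by `_get_gaussian_kernel1d` is a normalized Gaussian in disguise: with x_i = t_i / (sqrt(2) * sigma) and t_i = i - (kernel_size - 1) / 2, softmax(-x_i**2) equals exp(-t_i**2 / (2 * sigma**2)) divided by its sum. The sketch below, with a hypothetical reference helper for comparison, makes that explicit; the `F.gaussian_blur` import path is an assumption.

    import torch
    from torchvision.transforms.v2 import functional as F  # assumed re-export location

    def reference_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
        # Hypothetical reference: a sampled Gaussian normalized to sum to 1,
        # which is exactly what the softmax formulation above computes.
        t = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
        kernel = torch.exp(-(t ** 2) / (2 * sigma ** 2))
        return kernel / kernel.sum()

    # The same entry point handles plain tensors, datapoints and PIL images.
    img = torch.rand(3, 64, 64)
    blurred = F.gaussian_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])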
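
A sketch for `uniform_temporal_subsample`: the indices come from `torch.linspace(0, T - 1, num_samples)`, so the first and last frames are always kept and the rest are spread evenly. The import paths are assumptions consistent with this diff.

    import torch
    from torchvision.prototype import datapoints  # import path mirrors this diff
    from torchvision.transforms.v2 import functional as F  # assumed re-export location

    # A clip of 32 frames, 3 channels, 8x8 pixels; the temporal dimension is -4 by default.
    clip = datapoints.Video(torch.rand(32, 3, 8, 8))
    subsampled = F.uniform_temporal_subsample(clip, num_samples=8)
    assert subsampled.shape[-4] == 8  # 8 frames, including the first and the last one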
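
A sketch for the type-conversion helpers: `to_image_tensor` turns an HWC numpy array (or a PIL image) into a CHW `Image` datapoint, and `to_image_pil` / the `to_pil_image` alias go back to PIL. Again, the functional re-export path is an assumption.

    import numpy as np
    import torch
    from torchvision.transforms.v2 import functional as F  # assumed re-export location

    # HWC uint8 numpy array -> CHW Image datapoint.
    array = np.zeros((16, 16, 3), dtype=np.uint8)
    image = F.to_image_tensor(array)
    assert tuple(image.shape) == (3, 16, 16)

    # Back to PIL; `to_pil_image` stays available as the familiar alias.
    pil_image = F.to_image_pil(image)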
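
Finally, a sketch of the query/check helpers in `torchvision/transforms/v2/utils.py`: `has_any` succeeds if at least one input matches any of the given types or checks, `has_all` requires every type or check to be matched by some input, and `query_spatial_size` asserts a single consistent HxW across all size-bearing inputs. The datapoints import path mirrors this diff.

    import torch
    from torchvision.prototype import datapoints  # import path mirrors this diff
    from torchvision.transforms.v2.utils import has_all, has_any, query_spatial_size

    image = datapoints.Image(torch.rand(3, 32, 32))
    boxes = datapoints.BoundingBox(
        torch.tensor([[0, 0, 10, 10]]), format="XYXY", spatial_size=(32, 32)
    )
    flat_inputs = [image, boxes]

    assert has_any(flat_inputs, datapoints.Image, datapoints.Video)
    assert has_all(flat_inputs, datapoints.Image, datapoints.BoundingBox)
    assert query_spatial_size(flat_inputs) == (32, 32)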