From f5a97764cdee21c2adc069b16c18c64756abb117 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 3 Jul 2020 14:31:10 +0200
Subject: [PATCH 1/3] [WIP] F.resize with tensor

---
 test/test_functional_tensor.py              | 35 +++++++++
 torchvision/transforms/functional.py        | 34 +++------
 torchvision/transforms/functional_pil.py    | 40 ++++++++++-
 torchvision/transforms/functional_tensor.py | 79 +++++++++++++++++++++
 4 files changed, 163 insertions(+), 25 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 89af6dce5d7..d36609468b5 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -282,6 +282,41 @@ def test_pad(self):
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
 
+    def test_resize(self):
+        script_fn = torch.jit.script(F_t.resize)
+        tensor, pil_img = self._create_data(26, 31)
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                # This is a trivial cast to float of uint8 data to test all cases
+                tensor = tensor.to(dt)
+            for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
+                configs = [
+                    {"padding_mode": "constant", "fill": 0},
+                    {"padding_mode": "constant", "fill": 10},
+                    {"padding_mode": "constant", "fill": 20},
+                    {"padding_mode": "edge"},
+                    {"padding_mode": "reflect"},
+                    {"padding_mode": "symmetric"},
+                ]
+                for kwargs in configs:
+                    pad_tensor = F_t.pad(tensor, pad, **kwargs)
+                    pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
+
+                    pad_tensor_8b = pad_tensor
+                    # we need to cast to uint8 to compare with PIL image
+                    if pad_tensor_8b.dtype != torch.uint8:
+                        pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
+
+                    self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
+
+                    if isinstance(pad, int):
+                        script_pad = [pad, ]
+                    else:
+                        script_pad = pad
+                    pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
+                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 81a601a8e20..1779fda561b 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -311,41 +311,27 @@ def normalize(tensor, mean, std, inplace=False):
     return tensor
 
 
-def resize(img, size, interpolation=Image.BILINEAR):
-    r"""Resize the input PIL Image to the given size.
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     Args:
-        img (PIL Image): Image to be resized.
+        img (PIL Image or Tensor): Image to be resized.
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), the output size will be matched to this. If size is an int,
             the smaller edge of the image will be matched to this number maintaining
             the aspect ratio. i.e, if height > width, then image will be rescaled to
             :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+        interpolation (int, optional): Desired interpolation. Default is bilinear.
 
     Returns:
-        PIL Image: Resized image.
+        PIL Image or Tensor: Resized image.
""" - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)): - raise TypeError('Got inappropriate size arg: {}'.format(size)) + if not isinstance(img, torch.Tensor): + return F_pil.resize(img, size=size, interpolation=interpolation) - if isinstance(size, int): - w, h = img.size - if (w <= h and w == size) or (h <= w and h == size): - return img - if w < h: - ow = size - oh = int(size * h / w) - return img.resize((ow, oh), interpolation) - else: - oh = size - ow = int(size * w / h) - return img.resize((ow, oh), interpolation) - else: - return img.resize(size[::-1], interpolation) + return F_t.resize(img, size=size, interpolation=interpolation) def scale(*args, **kwargs): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index f1bcda113aa..0a7f5a14a5d 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,5 +1,5 @@ import numbers -from typing import Any, List +from typing import Any, List, Iterable import torch try: @@ -286,3 +286,41 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return img.crop((left, top, left + width, top + height)) + + +@torch.jit.unused +def resize(img, size, interpolation=Image.BILINEAR): + r"""Resize the input PIL Image to the given size. + + Args: + img (PIL Image): Image to be resized. + size (sequence or int): Desired output size. If size is a sequence like + (h, w), the output size will be matched to this. If size is an int, + the smaller edge of the image will be matched to this number maintaining + the aspect ratio. i.e, if height > width, then image will be rescaled to + :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)` + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + + Returns: + PIL Image: Resized image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
+    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
+        raise TypeError('Got inappropriate size arg: {}'.format(size))
+
+    if isinstance(size, int):
+        w, h = img.size
+        if (w <= h and w == size) or (h <= w and h == size):
+            return img
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+            return img.resize((ow, oh), interpolation)
+        else:
+            oh = size
+            ow = int(size * w / h)
+            return img.resize((ow, oh), interpolation)
+    else:
+        return img.resize(size[::-1], interpolation)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 8b64abe9f9c..056077ba114 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1,3 +1,5 @@
+from PIL.Image import NEAREST, BOX, BILINEAR, HAMMING, BICUBIC, LANCZOS
+
 import torch
 from torch import Tensor
 from torch.jit.annotations import List, BroadcastingList2
@@ -8,6 +10,7 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
 
 
 def _get_image_size(img: Tensor) -> List[int]:
+    """Returns (w, h) of tensor image"""
     if _is_tensor_a_torch_image(img):
         return [img.shape[-1], img.shape[-2]]
     raise TypeError("Unexpected type {}".format(type(img)))
@@ -480,3 +483,79 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         img = img.to(out_dtype)
 
     return img
+
+
+_interpolation_modes = {
+    NEAREST: "nearest",
+    BOX: "linear",
+    BILINEAR: "bilinear",
+    BICUBIC: "bicubic",
+}
+
+
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input Tensor to the given size.
+
+    Args:
+        img (Tensor): Image to be resized.
+        size (int or tuple or list): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e, if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is bilinear.
+
+    Returns:
+        Tensor: Resized image.
+ """ + if not _is_tensor_a_torch_image(img): + raise TypeError("tensor is not a torch image.") + + if not isinstance(size, (int, tuple, list)): + raise TypeError("Got inappropriate size arg") + if not isinstance(interpolation, int): + raise TypeError("Got inappropriate interpolation arg") + + if isinstance(size, tuple): + size = list(size) + + if isinstance(size, list) and len(size) not in [1, 2]: + raise ValueError("Padding must be an int or a 1 or 2 element tuple/list, not a " + + "{} element tuple/list".format(len(size))) + + if interpolation not in [0, 1, 2, 3, 4]: + raise ValueError("Interpolation mode should be either constant, edge, reflect or symmetric") + + w, h = _get_image_size(img) + if isinstance(size, int) or len(size) == 1: + if isinstance(size, list): + size = size[0] + if w < h: + size_w = size + size_h = int(size * h / w) + else: + size_h = size + size_w = int(size * w / h) + else: + size_w = size_h = size[0], size[1] + + if (w <= h and w == size_w) or (h <= w and h == size_h): + return img + + # make image NCHW + need_squeeze = False + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + # get mode from interpolation + mode = "nearest" + + img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode) + + if need_squeeze: + img = img.squeeze(dim=0) + + return img From 965ad6b277ad872660364d5be9c67af2b21da53c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 6 Jul 2020 10:41:28 +0200 Subject: [PATCH 2/3] Adapted T.Resize and F.resize with a test --- test/test_transforms_tensor.py | 28 ++++++++++++++++++++++++++++ torchvision/transforms/functional.py | 4 +++- torchvision/transforms/transforms.py | 27 +++++++++++++++++---------- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 6a8d9930754..9d70744dfc1 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -2,6 +2,7 @@ from torchvision import transforms as T from torchvision.transforms import functional as F from PIL import Image +from PIL.Image import NEAREST, BILINEAR, BICUBIC import numpy as np @@ -217,6 +218,33 @@ def test_ten_crop(self): "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_resize(self): + tensor, _ = self._create_data(height=34, width=36) + script_fn = torch.jit.script(F.resize) + + for dt in [None, torch.float32, torch.float64]: + if dt is not None: + # This is a trivial cast to float of uint8 data to test all cases + tensor = tensor.to(dt) + for size in [32, [32, ], [32, 32], (32, 32), ]: + for interpolation in [BILINEAR, BICUBIC, NEAREST]: + + resized_tensor = F.resize(tensor, size=size, interpolation=interpolation) + + if isinstance(size, int): + script_size = [size, ] + else: + script_size = size + + s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation) + self.assertTrue(s_resized_tensor.equal(resized_tensor)) + + transform = T.Resize(size=script_size, interpolation=interpolation) + resized_tensor = transform(tensor) + script_transform = torch.jit.script(transform) + s_resized_tensor = script_transform(tensor) + self.assertTrue(s_resized_tensor.equal(resized_tensor)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 3da121b6ec9..72ca54d7260 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -322,7 +322,9 @@ def resize(img: Tensor, size: List[int], 
         (h, w), the output size will be matched to this. If size is an int,
             the smaller edge of the image will be matched to this number maintaining
             the aspect ratio. i.e, if height > width, then image will be rescaled to
-            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
         interpolation (int, optional): Desired interpolation. Default is bilinear.
 
     Returns:
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 6bc9e7cbc4d..9f4ad8175c6 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -2,7 +2,7 @@
 import numbers
 import random
 import warnings
-from collections.abc import Sequence, Iterable
+from collections.abc import Sequence
 from typing import Tuple, List, Optional
 
 import numpy as np
@@ -209,31 +209,38 @@ def __repr__(self):
         return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
 
 
-class Resize(object):
-    """Resize the input PIL Image to the given size.
+class Resize(torch.nn.Module):
+    """Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
-            (size * height / width, size)
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            (size * height / width, size).
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``
     """
 
     def __init__(self, size, interpolation=Image.BILINEAR):
-        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        super().__init__()
+        if not isinstance(size, (int, Sequence)):
+            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+        if isinstance(size, Sequence) and len(size) not in (1, 2):
+            raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
         self.interpolation = interpolation
 
-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be scaled.
+            img (PIL Image or Tensor): Image to be scaled.
 
         Returns:
-            PIL Image: Rescaled image.
+            PIL Image or Tensor: Rescaled image.
""" return F.resize(img, self.size, self.interpolation) From 8f48a02673bba908408a328ee2573404389bf6c8 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 6 Jul 2020 14:25:10 +0200 Subject: [PATCH 3/3] According to the review, fixed copy-pasted messages and unused imports --- torchvision/transforms/functional_tensor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 5c3fe270b45..be0b7b3a622 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,3 @@ -from PIL.Image import NEAREST, BILINEAR, BICUBIC - import torch from torch import Tensor from torch.jit.annotations import List, BroadcastingList2 @@ -524,12 +522,9 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: size = list(size) if isinstance(size, list) and len(size) not in [1, 2]: - raise ValueError("Padding must be an int or a 1 or 2 element tuple/list, not a " + + raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a " "{} element tuple/list".format(len(size))) - if interpolation not in [0, 1, 2, 3, 4]: - raise ValueError("Interpolation mode should be either constant, edge, reflect or symmetric") - w, h = _get_image_size(img) if isinstance(size, int):