From 91c4ae483f78f315daabc5aa029ab278ac317fe5 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 8 Jul 2020 10:29:17 +0200
Subject: [PATCH 1/5] [WIP] F.affine

---
 torchvision/transforms/functional.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 4b38c7bb92e..714f4162727 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -880,11 +880,13 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
     return M
 
 
-def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
-    """Apply affine transformation on the image keeping image center invariant
+def affine(img: Tensor, angle: int, translate: List[int], scale: float, shear: List, resample=0, fillcolor=None):
+    """Apply affine transformation on the image keeping image center invariant.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
 
     Args:
-        img (PIL Image): PIL Image to be rotated.
+        img (PIL Image or Tensor): image to be rotated.
         angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
         translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
         scale (float): overall scale

From 37edd9426002befa7112f3d72ed0e7d1f615124e Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 8 Jul 2020 22:42:46 +0200
Subject: [PATCH 2/5] [WIP] F.affine + tests

---
 test/test_functional_tensor.py              |  43 +++++++
 test/test_transforms.py                     |  10 +-
 torchvision/transforms/functional.py        | 125 +++++++++-----------
 torchvision/transforms/functional_pil.py    |  67 ++++++++++-
 torchvision/transforms/functional_tensor.py |  67 ++++++++++-
 5 files changed, 232 insertions(+), 80 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 95f7383a4f7..47d6a58f82e 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -348,6 +348,49 @@ def test_resized_crop(self):
                 msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
             )
 
+    def test_affine(self):
+        # Let's do some tests on square image at first
+        tensor, pil_img = self._create_data(26, 26)
+        # 1) identity map
+        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+        self.assertTrue(
+            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+        )
+        # 2) Test rotation
+        test_configs = [
+            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
+            (45, None),
+            (30, None),
+            (-30, None),
+            (-45, None),
+            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
+            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
+        ]
+        for a, true_tensor in test_configs:
+
+            out_tensor = F.affine(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+            if true_tensor is not None:
+                self.assertTrue(
+                    true_tensor.equal(out_tensor),
+                    msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                )
+            else:
+                true_tensor = out_tensor
+
+            out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+            num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
+            ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
+            # Tolerance : 6% of different pixels
+            self.assertLess(
+                ratio_diff_pixels,
+                0.06,
+                msg="{}\n{} vs \n{}".format(
+                    ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                )
+            )
+
 
 if __name__ == '__main__':
     unittest.main()

diff --git a/test/test_transforms.py b/test/test_transforms.py
index b0eb844fcf8..e41132afcf8 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1311,17 +1311,17 @@ def test_rotate_fill(self):
 
     def test_affine(self):
         input_img = np.zeros((40, 40, 3), dtype=np.uint8)
-        pts = []
+        # pts = []
         cnt = [20, 20]
         for pt in [(16, 16), (20, 16), (20, 20)]:
             for i in range(-5, 5):
                 for j in range(-5, 5):
                     input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
-                    pts.append((pt[0] + i, pt[1] + j))
-        pts = list(set(pts))
+                    # pts.append((pt[0] + i, pt[1] + j))
+        # pts = list(set(pts))
 
-        with self.assertRaises(TypeError):
-            F.affine(input_img, 10)
+        with self.assertRaises(TypeError, msg="Argument translate should be a sequence"):
+            F.affine(input_img, 10, translate=0, scale=1, shear=1)
 
         pil_img = F.to_pil_image(input_img)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 45bc81255f2..e59356c6666 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1,12 +1,11 @@
 import math
 import numbers
 import warnings
-from collections.abc import Iterable
-from typing import Any
+from typing import Any, Optional, Sequence
 
 import numpy as np
 from numpy import sin, cos, tan
-from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
+from PIL import Image, ImageOps, ImageEnhance
 
 import torch
 from torch import Tensor
@@ -22,6 +21,7 @@
 
 
 _is_pil_image = F_pil._is_pil_image
+_parse_fill = F_pil._parse_fill
 
 
 def _get_image_size(img: Tensor) -> List[int]:
@@ -486,43 +486,6 @@ def hflip(img: Tensor) -> Tensor:
     return F_t.hflip(img)
 
 
-def _parse_fill(fill, img, min_pil_version):
-    """Helper function to get the fill color for rotate and perspective transforms.
-
-    Args:
-        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
-            image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands.
-        img (PIL Image): Image to be filled.
-        min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
-            was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)
-
-    Returns:
-        dict: kwarg for ``fillcolor``
-    """
-    major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2])
-    major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2])
-    if major_found < major_required or (major_found == major_required and minor_found < minor_required):
-        if fill is None:
-            return {}
-        else:
-            msg = ("The option to fill background area of the transformed image, "
-                   "requires pillow>={}")
-            raise RuntimeError(msg.format(min_pil_version))
-
-    num_bands = len(img.getbands())
-    if fill is None:
-        fill = 0
-    if isinstance(fill, (int, float)) and num_bands > 1:
-        fill = tuple([fill] * num_bands)
-    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
-        msg = ("The number of elements in 'fill' does not match the number of "
-               "bands of the image ({} != {})")
-        raise ValueError(msg.format(len(fill), num_bands))
-
-    return {"fillcolor": fill}
-
-
 def _get_perspective_coeffs(startpoints, endpoints):
     """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
@@ -848,14 +811,6 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
     #
     # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
 
-    if isinstance(shear, numbers.Number):
-        shear = [shear, 0]
-
-    if not isinstance(shear, (tuple, list)) and len(shear) == 2:
-        raise ValueError(
-            "Shear should be a single value or a tuple/list containing " +
-            "two values. Got {}".format(shear))
-
     rot = math.radians(angle)
     sx, sy = [math.radians(s) for s in shear]
 
@@ -870,21 +825,23 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
 
     # Inverted rotation matrix with scale and shear
     # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
-    M = [d, -b, 0,
-         -c, a, 0]
-    M = [x / scale for x in M]
+    matrix = [d, -b, 0, -c, a, 0]
+    matrix = [x / scale for x in matrix]
 
     # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
-    M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
-    M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
+    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
 
     # Apply center translation: C * RSS^-1 * C^-1 * T^-1
-    M[2] += cx
-    M[5] += cy
-    return M
+    matrix[2] += cx
+    matrix[5] += cy
+    return matrix
 
 
-def affine(img: Tensor, angle: int, translate: List[int], scale: float, shear: List, resample=0, fillcolor=None):
+def affine(
+    img: Tensor, angle: int, translate: List[int], scale: float, shear: List[float],
+    resample: int = 0, fillcolor: Optional[int] = None
+) -> Tensor:
     """Apply affine transformation on the image keeping image center invariant.
     The image can be a PIL Image or a Tensor, in which case it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
@@ -895,27 +852,51 @@ def affine(img: Tensor, angle: int, translate: List[int], scale: float, shear: Li
         translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
         scale (float): overall scale
         shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
-        If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
-        the second value corresponds to a shear parallel to the y axis.
+            If a tuple or list is specified, the first value corresponds to a shear parallel to the x axis, while
+            the second value corresponds to a shear parallel to the y axis.
         resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
-            An optional resampling filter.
-            See `filters`_ for more information.
-            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+            If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
         fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
+
+    Returns:
+        PIL Image or Tensor: Transformed image.
     """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not isinstance(translate, Sequence):
+        raise TypeError("Argument translate should be a sequence")
+
+    if len(translate) != 2:
+        raise ValueError("Argument translate should be a sequence of length 2")
 
-    assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
-        "Argument translate should be a list or tuple of length 2"
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
 
-    assert scale > 0.0, "Argument scale should be positive"
+    if not isinstance(shear, (numbers.Number, Sequence)):
+        raise TypeError("Shear should be either a single value or a sequence of two values")
 
-    output_size = img.size
-    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
-    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
-    kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {}
-    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
+    if isinstance(shear, numbers.Number):
+        shear = [shear, 0]
+
+    if len(shear) != 2:
+        raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))
+
+    if not isinstance(img, torch.Tensor):
+        img_size = _get_image_size(img)
+        # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
+        # it is visually better to estimate the center without 0.5 offset
+        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
+        center = (img_size[0] * 0.5, img_size[1] * 0.5)
+        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
+
+        return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
+
+    # compute affine matrix (not inverted)
+    # matrix = _get_inverse_affine_matrix(
+    #     (0, 0), -angle, [-t for t in translate], 1.0 / scale, [-s for s in shear]
+    # )
+    matrix = _get_inverse_affine_matrix((0, 0), angle, translate, scale, shear)
+    return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
 
 
 def to_grayscale(img, num_output_channels=1):

diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 994988ce1f6..f165b65f8d8 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -1,13 +1,14 @@
 import numbers
 from typing import Any, List, Sequence
 
+import numpy as np
 import torch
+from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
+
 try:
     import accimage
 except ImportError:
     accimage = None
-from PIL import Image, ImageOps, ImageEnhance
-import numpy as np
 
 
 @torch.jit.unused
@@ -327,3 +328,65 @@ def resize(img, size, interpolation=Image.BILINEAR):
         return img.resize((ow, oh), interpolation)
     else:
         return img.resize(size[::-1], interpolation)
+
+
+@torch.jit.unused
+def _parse_fill(fill, img, min_pil_version):
+    """Helper function to get the fill color for rotate and perspective transforms.
+
+    Args:
+        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            Defaults to 0 for all bands.
+        img (PIL Image): Image to be filled.
+        min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
+            was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)
+
+    Returns:
+        dict: kwarg for ``fillcolor``
+    """
+    major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2])
+    major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2])
+    if major_found < major_required or (major_found == major_required and minor_found < minor_required):
+        if fill is None:
+            return {}
+        else:
+            msg = ("The option to fill background area of the transformed image, "
+                   "requires pillow>={}")
+            raise RuntimeError(msg.format(min_pil_version))
+
+    num_bands = len(img.getbands())
+    if fill is None:
+        fill = 0
+    if isinstance(fill, (int, float)) and num_bands > 1:
+        fill = tuple([fill] * num_bands)
+    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
+        msg = ("The number of elements in 'fill' does not match the number of "
+               "bands of the image ({} != {})")
+        raise ValueError(msg.format(len(fill), num_bands))
+
+    return {"fillcolor": fill}
+
+
+@torch.jit.unused
+def affine(img, matrix, resample=0, fillcolor=None):
+    """Apply affine transformation on the PIL Image keeping image center invariant.
+
+    Args:
+        img (PIL Image): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
+        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
+            An optional resampling filter.
+            See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+        fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
+
+    Returns:
+        PIL Image: Transformed image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    output_size = img.size
+    opts = _parse_fill(fillcolor, img, '5.0.0')
+    return img.transform(output_size, Image.AFFINE, matrix, resample, **opts)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 59cf6bc2764..a6ee64c3c1a 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1,5 +1,9 @@
+import warnings
+from typing import Optional
+
 import torch
 from torch import Tensor
+from torch.nn.functional import affine_grid, grid_sample
 from torch.jit.annotations import List, BroadcastingList2
 
 
@@ -496,7 +500,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
             :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
             In torchscript mode padding as a single int is not supported, use a tuple or
             list of length 1: ``[size, ]``.
-        interpolation (int, optional): Desired interpolation. Default is bilinear.
+        interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values:
+            nearest(=0) and bicubic(=3).
 
     Returns:
         Tensor: Resized image.
@@ -571,3 +576,63 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
         img = img.to(out_dtype)
 
     return img
+
+
+def affine(
+    img: Tensor, matrix: List[int], resample: int = 0, fillcolor: Optional[int] = None
+) -> Tensor:
+    """Apply affine transformation on the Tensor image keeping image center invariant.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear(=2).
+        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
+            transform in the output image is always 0.
+
+    Returns:
+        Tensor: Transformed image.
+    """
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+
+    if fillcolor is not None:
+        warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero")
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    if resample not in _interpolation_modes:
+        raise ValueError("This resampling mode is unsupported with Tensor input")
+
+    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
+    shape = img.shape
+    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+
+    # make image NCHW
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    mode = _interpolation_modes[resample]
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
+    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        # it is better to round before cast
+        img = torch.round(img).to(out_dtype)
+
+    return img

From 483980418ee83c9980458ac3da0ab5dbfc68ed8e Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 9 Jul 2020 15:20:05 +0200
Subject: [PATCH 3/5] Unified input for F.affine

---
 test/test_functional_tensor.py              | 92 +++++++++++++------
 test/test_transforms.py                     | 11 +--
 torchvision/transforms/functional.py        | 51 ++++++++----
 torchvision/transforms/functional_tensor.py |  2 +-
 4 files changed, 110 insertions(+), 46 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 47d6a58f82e..f246588c439 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -349,13 +349,20 @@ def test_resized_crop(self):
         )
 
     def test_affine(self):
-        # Let's do some tests on square image at first
+        # Tests on square image
        tensor, pil_img = self._create_data(26, 26)
+
+        scripted_affine = torch.jit.script(F.affine)
         # 1) identity map
         out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
         self.assertTrue(
             tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
         )
+        out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+        self.assertTrue(
+            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+        )
+
         # 2) Test rotation
         test_configs = [
             (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
@@ -367,29 +374,68 @@ def test_affine(self):
             (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
         ]
         for a, true_tensor in test_configs:
-
-            out_tensor = F.affine(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-            if true_tensor is not None:
-                self.assertTrue(
-                    true_tensor.equal(out_tensor),
-                    msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
-                )
-            else:
-                true_tensor = out_tensor
-
-            out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
-            num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
-            ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
-            # Tolerance : 6% of different pixels
-            self.assertLess(
-                ratio_diff_pixels,
-                0.06,
-                msg="{}\n{} vs \n{}".format(
-                    ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
-                )
-            )
+            for fn in [F.affine, scripted_affine]:
+                out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                if true_tensor is not None:
+                    self.assertTrue(
+                        true_tensor.equal(out_tensor),
+                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                    )
+                else:
+                    true_tensor = out_tensor
+
+                out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
+                ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
+                # Tolerance : less than 6% of different pixels
+                self.assertLess(
+                    ratio_diff_pixels,
+                    0.06,
+                    msg="{}\n{} vs \n{}".format(
+                        ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                    )
+                )
+
+        # 3) Test translation
+        test_configs = [
+            [10, 12], (12, 13)
+        ]
+        for t in test_configs:
+            for fn in [F.affine, scripted_affine]:
+                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                self.compareTensorToPIL(out_tensor, out_pil_img)
+
+        # 4) Test rotation + translation + scale + shear
+        test_configs = [
+            (45, [5, 6], 1.0, [0.0, 0.0]),
+            (33, (5, -4), 1.0, [0.0, 0.0]),
+            (45, [5, 4], 1.2, [0.0, 0.0]),
+            (33, (4, 8), 2.0, [0.0, 0.0]),
+            (85, (10, -10), 0.7, [0.0, 0.0]),
+            (0, [0, 0], 1.0, [35.0, ]),
+            (25, [0, 0], 1.2, [0.0, 15.0]),
+            (45, [10, 0], 0.7, [2.0, 5.0]),
+            (45, [10, -10], 1.2, [4.0, 5.0]),
+        ]
+        for r in [0, ]:
+            for a, t, s, sh in test_configs:
+                for fn in [F.affine, scripted_affine]:
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                    out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                    # Tolerance : less than 5% of different pixels
+                    self.assertLess(
+                        ratio_diff_pixels,
+                        0.05,
+                        msg="{}: {}\n{} vs \n{}".format(
+                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                        )
+                    )
 
 
 if __name__ == '__main__':

diff --git a/test/test_transforms.py b/test/test_transforms.py
index e41132afcf8..390d372a4b3 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1373,9 +1373,10 @@ def _test_transformation(a, t, s, sh):
             inv_true_matrix = np.linalg.inv(true_matrix)
             for y in range(true_result.shape[0]):
                 for x in range(true_result.shape[1]):
-                    res = np.dot(inv_true_matrix, [x, y, 1])
-                    _x = int(res[0] + 0.5)
-                    _y = int(res[1] + 0.5)
+                    # transform pixel's center instead of pixel's TL corner
+                    res = np.dot(inv_true_matrix, [x + 0.5, y + 0.5, 1])
+                    _x = int(res[0])
+                    _y = int(res[1])
                     if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                         true_result[y, x, :] = input_img[_y, _x, :]
 
@@ -1384,8 +1385,8 @@ def _test_transformation(a, t, s, sh):
             # Compute number of different pixels:
             np_result = np.array(result)
             n_diff_pixels = np.sum(np_result != true_result) / 3
-            # Accept 3 wrong pixels
-            self.assertLess(n_diff_pixels, 3,
+            # Accept 7 wrong pixels
+            self.assertLess(n_diff_pixels, 7,
                             "a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +
                             "n diff pixels={}\n".format(np.sum(np.array(result)[:, :, 0] != true_result[:, :, 0])))

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index e59356c6666..9893ca7efff 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -791,7 +791,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
     return img.rotate(angle, resample, expand, center, **opts)
 
 
-def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
+def _get_inverse_affine_matrix(
+    center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
+) -> List[float]:
     # Helper method to compute inverse matrix for affine transformation
 
     # As it is explained in PIL.Image.rotate
@@ -818,14 +820,14 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
     tx, ty = translate
 
     # RSS without scaling
-    a = cos(rot - sy) / cos(sy)
-    b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
-    c = sin(rot - sy) / cos(sy)
-    d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+    a = math.cos(rot - sy) / math.cos(sy)
+    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
+    c = math.sin(rot - sy) / math.cos(sy)
+    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
 
     # Inverted rotation matrix with scale and shear
     # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
-    matrix = [d, -b, 0, -c, a, 0]
+    matrix = [d, -b, 0.0, -c, a, 0.0]
     matrix = [x / scale for x in matrix]
 
     # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
@@ -835,11 +837,12 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
     # Apply center translation: C * RSS^-1 * C^-1 * T^-1
     matrix[2] += cx
     matrix[5] += cy
+
     return matrix
 
 
 def affine(
-    img: Tensor, angle: int, translate: List[int], scale: float, shear: List[float],
+    img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
     resample: int = 0, fillcolor: Optional[int] = None
 ) -> Tensor:
     """Apply affine transformation on the image keeping image center invariant.
@@ -863,7 +866,10 @@ def affine(
     Returns:
         PIL Image or Tensor: Transformed image.
     """
-    if not isinstance(translate, Sequence):
+    if not isinstance(angle, (int, float)):
+        raise TypeError("Argument angle should be int or float")
+
+    if not isinstance(translate, (list, tuple)):
         raise TypeError("Argument translate should be a sequence")
 
     if len(translate) != 2:
@@ -872,30 +878,41 @@ def affine(
     if scale <= 0.0:
         raise ValueError("Argument scale should be positive")
 
-    if not isinstance(shear, (numbers.Number, Sequence)):
+    if not isinstance(shear, (numbers.Number, (list, tuple))):
         raise TypeError("Shear should be either a single value or a sequence of two values")
 
+    if isinstance(angle, int):
+        angle = float(angle)
+
+    if isinstance(translate, tuple):
+        translate = list(translate)
+
     if isinstance(shear, numbers.Number):
-        shear = [shear, 0]
+        shear = [shear, 0.0]
+
+    if isinstance(shear, tuple):
+        shear = list(shear)
+
+    if len(shear) == 1:
+        shear = [shear[0], shear[0]]
 
     if len(shear) != 2:
Got {}".format(shear)) + img_size = _get_image_size(img) if not isinstance(img, torch.Tensor): - img_size = _get_image_size(img) # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine - center = (img_size[0] * 0.5, img_size[1] * 0.5) + center = [img_size[0] * 0.5, img_size[1] * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) - # compute affine matrix (not inversed) - # matrix = _get_inverse_affine_matrix( - # (0, 0), -angle, [-t for t in translate], 1.0 / scale, [-s for s in shear] - # ) - matrix = _get_inverse_affine_matrix((0, 0), angle, translate, scale, shear) + # we need to rescale translate by image size / 2 as its values can be between -1 and 1 + translate = [2.0 * t / s for s, t in zip(img_size, translate)] + + matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear) return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a6ee64c3c1a..2bd4549059e 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -579,7 +579,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: def affine( - img: Tensor, matrix: List[int], resample: int = 0, fillcolor: Optional[int] = None + img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None ) -> Tensor: """Apply affine transformation on the Tensor image keeping image center invariant. From f33b39166bb85e0661b29f50faf5534d9a3ee0fc Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 9 Jul 2020 17:42:45 +0200 Subject: [PATCH 4/5] Removed commented code --- test/test_transforms.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 390d372a4b3..140167ebe08 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1311,14 +1311,11 @@ def test_rotate_fill(self): def test_affine(self): input_img = np.zeros((40, 40, 3), dtype=np.uint8) - # pts = [] cnt = [20, 20] for pt in [(16, 16), (20, 16), (20, 20)]: for i in range(-5, 5): for j in range(-5, 5): input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55] - # pts.append((pt[0] + i, pt[1] + j)) - # pts = list(set(pts)) with self.assertRaises(TypeError, msg="Argument translate should be a sequence"): F.affine(input_img, 10, translate=0, scale=1, shear=1) From 8e85d7b39794badb898c203292aca553fc53865f Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 15 Jul 2020 19:37:22 +0200 Subject: [PATCH 5/5] Removed unused imports --- torchvision/transforms/functional.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 1cbaff08b82..340592a01f0 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,10 +1,9 @@ import math import numbers import warnings -from typing import Any, Optional, Sequence +from typing import Any, Optional import numpy as np -from numpy import sin, cos, tan from PIL import Image import torch