From 36fef0d6182b6075900094095e8514b2a4cf88ad Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 21 Jul 2020 09:29:45 +0200
Subject: [PATCH 01/11] Added code for F_t.rotate with test - updated F.affine tests

---
 test/test_functional_tensor.py              |  57 ++++++--
 torchvision/transforms/functional.py        |  88 +++++++-----
 torchvision/transforms/functional_pil.py    |  31 +++++
 torchvision/transforms/functional_tensor.py | 147 ++++++++++++++++----
 4 files changed, 252 insertions(+), 71 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 7b4b9b490da..909629a2806 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -406,7 +406,7 @@ def test_affine(self):
         )
         # 3) Test translation
         test_configs = [
-            [10, 12], (12, 13)
+            [10, 12], (-12, -13)
         ]
         for t in test_configs:
             for fn in [F.affine, scripted_affine]:
@@ -418,21 +418,21 @@ def test_affine(self):
         test_configs = [
             (45, [5, 6], 1.0, [0.0, 0.0]),
             (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [5, 4], 1.2, [0.0, 0.0]),
-            (33, (4, 8), 2.0, [0.0, 0.0]),
+            (45, [-5, 4], 1.2, [0.0, 0.0]),
+            (33, (-4, -8), 2.0, [0.0, 0.0]),
             (85, (10, -10), 0.7, [0.0, 0.0]),
             (0, [0, 0], 1.0, [35.0, ]),
             (25, [0, 0], 1.2, [0.0, 15.0]),
-            (45, [10, 0], 0.7, [2.0, 5.0]),
-            (45, [10, -10], 1.2, [4.0, 5.0]),
+            (45, [-10, 0], 0.7, [2.0, 5.0]),
+            (45, [-10, -10], 1.2, [4.0, 5.0]),
         ]
         for r in [0, ]:
             for a, t, s, sh in test_configs:
+                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
                 for fn in [F.affine, scripted_affine]:
                     out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
-                    out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
                     num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                     ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                     # Tolerance : less than 5% of different pixels
@@ -444,6 +444,47 @@ def test_affine(self):
                         )
                     )

+    def test_rotate(self):
+        # Tests on square image
+        tensor, pil_img = self._create_data(26, 26)
+        scripted_rotate = torch.jit.script(F.rotate)
+
+        img_size = pil_img.size
+
+        centers = [
+            None,
+            (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
+            [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
+        ]
+
+        for r in [0, ]:
+            for a in range(-120, 120, 23):
+                for e in [True, False]:
+                    for c in centers:
+
+                        out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                        for fn in [F.rotate, scripted_rotate]:
+                            out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
+
+                            self.assertEqual(
+                                out_tensor.shape,
+                                out_pil_tensor.shape,
+                                msg="{}: {} vs {}".format(
+                                    (r, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                )
+                            )
+                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                            # Tolerance : less than 1% of different pixels
+                            self.assertLess(
+                                ratio_diff_pixels,
+                                0.01,
+                                msg="{}: {}\n{} vs \n{}".format(
+                                    (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                                )
+                            )
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 340592a01f0..f62ea49b382 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -758,40 +758,8 @@ def adjust_gamma(img, gamma, gain=1):
     return img


-def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
-    """Rotate the image by angle.
-
-
-    Args:
-        img (PIL Image): PIL Image to be rotated.
-        angle (float or int): In degrees degrees counter clockwise order.
-        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
-            An optional resampling filter. See `filters`_ for more information.
-            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
-        expand (bool, optional): Optional expansion flag.
-            If true, expands the output image to make it large enough to hold the entire rotated image.
-            If false or omitted, make the output image the same size as the input image.
-            Note that the expand flag assumes rotation around the center and no translation.
-        center (2-tuple, optional): Optional center of rotation.
-            Origin is the upper left corner.
-            Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
-            image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
-
-    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
-
-    """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    opts = _parse_fill(fill, img, '5.2.0')
-
-    return img.rotate(angle, resample, expand, center, **opts)
-
-
 def _get_inverse_affine_matrix(
-        center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
+        center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
 ) -> List[float]:
     # Helper method to compute inverse matrix for affine transformation

@@ -840,6 +808,56 @@ def _get_inverse_affine_matrix(
     return matrix


+def rotate(
+        img: Tensor, angle: float, resample: int = 0, expand: bool = False,
+        center: Optional[List[int]] = None, fill: Optional[int] = None
+) -> Tensor:
+    """Rotate the image by angle.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+    Args:
+        img (PIL Image or Tensor): image to be rotated.
+        angle (float or int): rotation angle value in degrees, counter-clockwise.
+        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output image to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
+            Default is the center of the image.
+        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+            image. If int or float, the value is used for all bands respectively.
+            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
+
+    Returns:
+        PIL Image or Tensor: Rotated image.
+
+    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+    """
+    if not isinstance(angle, (int, float)):
+        raise TypeError("Argument angle should be int or float")
+
+    if center is not None and not isinstance(center, (list, tuple)):
+        raise TypeError("Argument center should be a sequence")
+
+    if not isinstance(img, torch.Tensor):
+        return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill)
+
+    center_f = [0.0, 0.0]
+    if center is not None:
+        img_size = _get_image_size(img)
+        # Center is normalized to [-1, +1]
+        center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)]
+    # due to current incoherence of rotation angle direction between affine and rotate implementations
+    # we need to set -angle.
+    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
+    return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill)
+
+
 def affine(
         img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
         resample: int = 0, fillcolor: Optional[int] = None
@@ -849,7 +867,7 @@
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

     Args:
-        img (PIL Image or Tensor): image to be rotated.
+        img (PIL Image or Tensor): image to transform.
         angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
         translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
         scale (float): overall scale
@@ -911,7 +929,7 @@ def affine(
     # we need to rescale translate by image size / 2 as its values can be between -1 and 1
     translate = [2.0 * t / s for s, t in zip(img_size, translate)]

-    matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
+    matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
     return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index f165b65f8d8..fd603d83e4d 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -390,3 +390,34 @@ def affine(img, matrix, resample=0, fillcolor=None):
     output_size = img.size
     opts = _parse_fill(fillcolor, img, '5.0.0')
     return img.transform(output_size, Image.AFFINE, matrix, resample, **opts)
+
+
+@torch.jit.unused
+def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
+    """Rotate PIL image by angle.
+
+    Args:
+        img (PIL Image): image to be rotated.
+        angle (float or int): rotation angle value in degrees, counter-clockwise.
+        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output image to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        center (2-tuple, optional): Optional center of rotation.
+            Origin is the upper left corner.
+            Default is the center of the image.
+        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+            image. If int or float, the value is used for all bands respectively.
+            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
+
+    Returns:
+        PIL Image: Rotated image.
+
+    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+    """
+    opts = _parse_fill(fill, img, '5.2.0')
+    return img.rotate(angle, resample, expand, center, **opts)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 2bd4549059e..d45a65cb3bc 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Optional
+from typing import Optional, Dict, Tuple

 import torch
 from torch import Tensor
@@ -578,48 +578,32 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
     return img


-def affine(
-    img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
-) -> Tensor:
-    """Apply affine transformation on the Tensor image keeping image center invariant.
+def _assert_grid_transform_inputs(
+    img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str]
+):
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError("img should be Tensor Image. Got {}".format(type(img)))

-    Args:
-        img (Tensor): image to be rotated.
-        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
-        resample (int, optional): An optional resampling filter. Default is nearest (=2). Other supported values:
-            bilinear(=2).
-        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
-            transform in the output image is always 0.
+    if not isinstance(matrix, list):
+        raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix)))

-    Returns:
-        Tensor: Transformed image.
-    """
-    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
-        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+    if len(matrix) != 6:
+        raise ValueError("Argument matrix should have 6 float values")

     if fillcolor is not None:
-        warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero")
-
-    _interpolation_modes = {
-        0: "nearest",
-        2: "bilinear",
-    }
+        warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero")

     if resample not in _interpolation_modes:
         raise ValueError("This resampling mode is unsupported with Tensor input")

-    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
-    shape = img.shape
-    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)

+def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
     # make image NCHW
     need_squeeze = False
     if img.ndim < 4:
         img = img.unsqueeze(dim=0)
         need_squeeze = True

-    mode = _interpolation_modes[resample]
-
     out_dtype = img.dtype
     need_cast = False
     if img.dtype not in (torch.float32, torch.float64):
@@ -636,3 +620,110 @@
     img = torch.round(img).to(out_dtype)

     return img
+
+
+def affine(
+    img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
+) -> Tensor:
+    """Apply affine transformation on the Tensor image keeping image center invariant.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear(=2).
+        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
+            transform in the output image is always 0.
+
+    Returns:
+        Tensor: Transformed image.
+    """
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
+
+    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
+    shape = img.shape
+    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+    mode = _interpolation_modes[resample]
+
+    return _apply_grid_transform(img, grid, mode)
+
+
+def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
+    point = torch.tensor([0.0, 0.0, 1.0])
+    pts = []
+    for i in [0.0, float(h)]:
+        for j in [0.0, float(w)]:
+            # we need to normalize coordinates according to
+            # [0, s] is mapped [-1, +1] as theta translation parameters are
+            # normalized like that
+            point[1], point[0] = 2.0 * i / w - 1.0, 2.0 * j / h - 1.0
+            new_point = torch.matmul(theta, point)
+            # denormalize back to w, h:
+            new_point = (new_point + 1.0) * torch.tensor([w, h]) / 2.0
+            pts.append(new_point)
+    pts = torch.stack(pts)
+    min_vals, _ = pts.min(dim=0)
+    max_vals, _ = pts.max(dim=0)
+    size = torch.ceil(max_vals) - torch.floor(min_vals)
+    return int(size[0]), int(size[1])
+
+
+def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
+    if expand:
+        ow, oh = _compute_output_size(theta, w, h)
+    else:
+        ow, oh = w, h
+    output_grid = torch.zeros(1, oh, ow, 2)
+
+    d = 0.5  # if not align_corners
+
+    point = torch.tensor([0.0, 0.0, 1.0])
+    for i in range(oh):
+        for j in range(ow):
+            point[1] = (i + d - oh * 0.5) / (0.5 * h)
+            point[0] = (j + d - ow * 0.5) / (0.5 * w)
+            output_grid[0, i, j, :] = torch.matmul(theta, point)
+    return output_grid
+
+
+def rotate(
+    img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None
+) -> Tensor:
+    """Rotate the Tensor image by angle.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear(=2).
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output image to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        fill (n-tuple or int or float): this option is not supported for Tensor input.
+            Fill value for the area outside the transform in the output image is always 0.
+
+    Returns:
+        Tensor: Rotated image.
+
+    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+    """
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
+
+    theta = torch.tensor(matrix).reshape(2, 3)
+    shape = img.shape
+    grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand)
+    mode = _interpolation_modes[resample]
+
+    return _apply_grid_transform(img, grid, mode)

From 2b98bdc6392e21b0d3941f850b7f67c9b2187e0b Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 21 Jul 2020 09:40:44 +0200
Subject: [PATCH 02/11] Rotate test tolerance to 2%

---
 test/test_functional_tensor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 909629a2806..5972a078db4 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -476,10 +476,10 @@ def test_rotate(self):
                             )
                             num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                             ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                            # Tolerance : less than 1% of different pixels
+                            # Tolerance : less than 2% of different pixels
                             self.assertLess(
                                 ratio_diff_pixels,
-                                0.01,
+                                0.02,
                                 msg="{}: {}\n{} vs \n{}".format(
                                     (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                                 )

From c7231bd5137ba9d26834f07a98a816cb3ba99d83 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 21 Jul 2020 10:17:13 +0200
Subject: [PATCH 03/11] Fixes failing test

---
 test/test_transforms.py                  | 2 +-
 torchvision/transforms/functional_pil.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 125502a3ad5..eb9efb3347e 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1258,7 +1258,7 @@ def test_rotate(self):
         x = np.zeros((100, 100, 3), dtype=np.uint8)
         x[40, 40] = [255, 255, 255]

-        with self.assertRaises(TypeError):
+        with self.assertRaisesRegex(TypeError, r"img should be PIL Image"):
             F.rotate(x, 10)

         img = F.to_pil_image(x)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index fd603d83e4d..d128943c0ed 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -419,5 +419,8 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

     """
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
     opts = _parse_fill(fill, img, '5.2.0')
     return img.rotate(angle, resample, expand, center, **opts)

From d72cb3dfa9686a42a2c68d62e52eea031c06a9b4 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 Aug 2020 17:17:16 +0200
Subject: [PATCH 04/11] Optimized _expanded_affine_grid with a single matmul op

---
 torchvision/transforms/functional_tensor.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 5e1de46f61e..597702cd274 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -719,17 +719,14 @@ def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -
         ow, oh = _compute_output_size(theta, w, h)
     else:
         ow, oh = w, h
-    output_grid = torch.zeros(1, oh, ow, 2)
-
     d = 0.5  # if not align_corners

-    point = torch.tensor([0.0, 0.0, 1.0])
-    for i in range(oh):
-        for j in range(ow):
-            point[1] = (i + d - oh * 0.5) / (0.5 * h)
-            point[0] = (j + d - ow * 0.5) / (0.5 * w)
-            output_grid[0, i, j, :] = torch.matmul(theta, point)
-    return output_grid
+    x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w)
+    y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h)
+    y, x = torch.meshgrid(y, x)
+    pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
+    output_grid = torch.matmul(pts, theta.t())
+    return output_grid.unsqueeze(dim=0)

From a249f771250364455655cf73f460ce02d7876796 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 Aug 2020 17:30:54 +0200
Subject: [PATCH 05/11] Recoded _compute_output_size

---
 torchvision/transforms/functional_tensor.py | 29 ++++++++++-----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 597702cd274..3641d722730 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -695,21 +695,20 @@ def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
-    point = torch.tensor([0.0, 0.0, 1.0])
-    pts = []
-    for i in [0.0, float(h)]:
-        for j in [0.0, float(w)]:
-            # we need to normalize coordinates according to
-            # [0, s] is mapped [-1, +1] as theta translation parameters are
-            # normalized like that
-            point[1], point[0] = 2.0 * i / w - 1.0, 2.0 * j / h - 1.0
-            new_point = torch.matmul(theta, point)
-            # denormalize back to w, h:
-            new_point = (new_point + 1.0) * torch.tensor([w, h]) / 2.0
-            pts.append(new_point)
-    pts = torch.stack(pts)
-    min_vals, _ = pts.min(dim=0)
-    max_vals, _ = pts.max(dim=0)
+
+    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+    # we need to normalize coordinates according to
+    # [0, s] is mapped [-1, +1] as theta translation parameters are normalized like that
+    pts = torch.tensor([
+        [-1.0, -1.0, 1.0],
+        [-1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0],
+        [1.0, -1.0, 1.0],
+    ])
+    # denormalize back to w, h:
+    new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0
+    min_vals, _ = new_pts.min(dim=0)
+    max_vals, _ = new_pts.max(dim=0)
     size = torch.ceil(max_vals) - torch.floor(min_vals)
     return int(size[0]), int(size[1])

From a2c6dd1fe7f2adfeb5286871517a21b8dc5d154a Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 4 Aug 2020 15:37:07 +0200
Subject: [PATCH 06/11] [WIP] recoded F_t.rotate internal methods

---
 test/test_functional_tensor.py              |  5 ++-
 torchvision/transforms/functional.py        |  2 +-
 torchvision/transforms/functional_tensor.py | 44 +++++++++++++++------
 3 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 0532f171471..2dfea4f97a2 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -454,6 +454,7 @@ def test_affine(self):
             (25, [0, 0], 1.2, [0.0, 15.0]),
             (45, [-10, 0], 0.7, [2.0, 5.0]),
             (45, [-10, -10], 1.2, [4.0, 5.0]),
+            (90, [0, 0], 1.0, [0.0, 0.0]),
         ]
         for r in [0, ]:
             for a, t, s, sh in test_configs:
@@ -475,7 +476,7 @@ def test_affine(self):

     def test_rotate(self):
         # Tests on square image
-        tensor, pil_img = self._create_data(26, 26)
+        tensor, pil_img = self._create_data(32, 26)
         scripted_rotate = torch.jit.script(F.rotate)

         img_size = pil_img.size
@@ -487,7 +488,7 @@ def test_rotate(self):
         ]

         for r in [0, ]:
-            for a in range(-120, 120, 23):
+            for a in range(-180, 180, 23):
                 for e in [True, False]:
                     for c in centers:

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 374c131cf44..0765e8a2baf 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -853,7 +853,7 @@ def rotate(
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
-    return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill)
+    return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill, center=center)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 3641d722730..a49719f2ebc 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -694,28 +694,39 @@ def affine(
     return _apply_grid_transform(img, grid, mode)


-def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
+def _compute_output_size(theta: Tensor, w: int, h: int, center: Optional[List[int]] = None) -> Tuple[int, int]:

     # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
-    # we need to normalize coordinates according to
-    # [0, s] is mapped [-1, +1] as theta translation parameters are normalized like that
+    # To compute extended output image size we should use denormalized theta
+    # where translation part (theta[0, 2] and theta[1, 2]) is computed using original center values in range [0, w]
+    # and [0, h]. Currently, theta[0, 2] and theta[1, 2] are normalized to [-1, 1] range
+
+    center_f = [w * 0.5, h * 0.5]
+    if center is not None:
+        center_f = [float(v) for v in center]
+
+    denorm_theta = theta.clone()
+    denorm_theta[0, 2] = center_f[0] * (1.0 - denorm_theta[0, 0]) - center_f[1] * denorm_theta[0, 1]
+    denorm_theta[1, 2] = center_f[0] * (1.0 - denorm_theta[1, 0]) - center_f[1] * denorm_theta[1, 1]
+
     pts = torch.tensor([
-        [-1.0, -1.0, 1.0],
-        [-1.0, 1.0, 1.0],
-        [1.0, 1.0, 1.0],
-        [1.0, -1.0, 1.0],
+        [0.0, 0.0, 1.0],
+        [0.0, 1.0 * h, 1.0],
+        [1.0 * w, 1.0 * h, 1.0],
+        [1.0 * w, 0.0, 1.0],
     ])
-    # denormalize back to w, h:
-    new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0
+    new_pts = torch.matmul(pts, denorm_theta.t())
     min_vals, _ = new_pts.min(dim=0)
     max_vals, _ = new_pts.max(dim=0)
     size = torch.ceil(max_vals) - torch.floor(min_vals)
     return int(size[0]), int(size[1])


-def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
+def _expanded_affine_grid(
+    theta: Tensor, w: int, h: int, expand: bool = False, center: Optional[List[int]] = None
+) -> Tensor:
     if expand:
-        ow, oh = _compute_output_size(theta, w, h)
+        ow, oh = _compute_output_size(theta, w, h, center)
     else:
         ow, oh = w, h
     d = 0.5  # if not align_corners
@@ -729,13 +740,19 @@

 def rotate(
-    img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None
+    img: Tensor,
+    matrix: List[float],
+    resample: int = 0,
+    expand: bool = False,
+    fill: Optional[int] = None,
+    center: Optional[List[int]] = None  # this argument helps to correctly compute output image size if expand=True
 ) -> Tensor:
     """Rotate the Tensor image by angle.

     Args:
         img (Tensor): image to be rotated.
         matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
+            Translation part (``matrix[2]`` and ``matrix[5]``) should be normalized to ``(-1, +1)`` range.
         resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
             bilinear(=2).
         expand (bool, optional): Optional expansion flag.
             If true, expands the output image to make it large enough to hold the entire rotated image.
             If false or omitted, make the output image the same size as the input image.
             Note that the expand flag assumes rotation around the center and no translation.
         fill (n-tuple or int or float): this option is not supported for Tensor input.
             Fill value for the area outside the transform in the output image is always 0.
+        center (list or tuple, optional): this argument helps to correctly compute output image size if expand=True

     Returns:
         Tensor: Rotated image.
@@ -760,7 +778,7 @@ def rotate(
     _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)

     theta = torch.tensor(matrix).reshape(2, 3)
     shape = img.shape
-    grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand)
+    grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand, center=center)
     mode = _interpolation_modes[resample]

     return _apply_grid_transform(img, grid, mode)

From 407e9c43cb4d5c8f86c3a6fa9d963a44f6f83b7b Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 4 Aug 2020 22:32:15 +0200
Subject: [PATCH 07/11] [WIP] Fixed F.affine to support rectangular images

---
 test/test_functional_tensor.py              | 189 +++++++++++---------
 torchvision/transforms/functional.py        |  10 +-
 torchvision/transforms/functional_tensor.py |  23 ++++++-
 3 files changed, 135 insertions(+), 87 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index d01a357d7b5..7b95ae9af03 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -385,93 +385,116 @@ def test_resized_crop(self):
         )

     def test_affine(self):
-        # Tests on square image
-        tensor, pil_img = self._create_data(26, 26)
-
+        # Tests on square and rectangular images
         scripted_affine = torch.jit.script(F.affine)
-        # 1) identity map
-        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
-        out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
-        # 2) Test rotation
-        test_configs = [
-            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
-            (45, None),
-            (30, None),
-            (-30, None),
-            (-45, None),
-            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
-            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
-        ]
-        for a, true_tensor in test_configs:
-            for fn in [F.affine, scripted_affine]:
-                out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-                if true_tensor is not None:
-                    self.assertTrue(
-                        true_tensor.equal(out_tensor),
-                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
-                    )
-                else:
-                    true_tensor = out_tensor
-
-                out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
-                num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
-                ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
-                # Tolerance : less than 6% of different pixels
-                self.assertLess(
-                    ratio_diff_pixels,
-                    0.06,
-                    msg="{}\n{} vs \n{}".format(
-                        ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
-                    )
-                )
-        # 3) Test translation
-        test_configs = [
-            [10, 12], (12, 13)
-        ]
-        for t in test_configs:
-            for fn in [F.affine, scripted_affine]:
-                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
-                out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
-                self.compareTensorToPIL(out_tensor, out_pil_img)
-
-        # 3) Test rotation + translation + scale + share
-        test_configs = [
-            (45, [5, 6], 1.0, [0.0, 0.0]),
-            (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [5, 4], 1.2, [0.0, 0.0]),
-            (33, (4, 8), 2.0, [0.0, 0.0]),
-            (85, (10, -10), 0.7, [0.0, 0.0]),
-            (0, [0, 0], 1.0, [35.0, ]),
-            (25, [0, 0], 1.2, [0.0, 15.0]),
-            (45, [10, 0], 0.7, [2.0, 5.0]),
-            (45, [10, -10], 1.2, [4.0, 5.0]),
-        ]
-        for r in [0, ]:
-            for a, t, s, sh in test_configs:
+        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
+
+            # 1) identity map
+            out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+            self.assertTrue(
+                tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+            )
+            out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+            self.assertTrue(
+                tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+            )
+
+            if pil_img.size[0] == pil_img.size[1]:
+                # 2) Test rotation
+                test_configs = [
+                    (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
+                    (45, None),
+                    (30, None),
+                    (-30, None),
+                    (-45, None),
+                    (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
+                    (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
+                ]
+                for a, true_tensor in test_configs:
+                    for fn in [F.affine, scripted_affine]:
+                        out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        if true_tensor is not None:
+                            self.assertTrue(
+                                true_tensor.equal(out_tensor),
+                                msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                            )
+                        else:
+                            true_tensor = out_tensor
+
+                        out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

+                        num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
+                        # Tolerance : less than 6% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.06,
+                            msg="{}\n{} vs \n{}".format(
+                                ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            )
+                        )
+            else:
+                test_configs = [
+                    90, 45, 15, -30, -60, -120
+                ]
+                for a in test_configs:
+                    for fn in [F.affine, scripted_affine]:
+                        out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 3% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.03,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            )
+                        )
+
+            # 3) Test translation
+            test_configs = [
+                [10, 12], (12, 13)
+            ]
+            for t in test_configs:
+                for fn in [F.affine, scripted_affine]:
+                    out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                    out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                    self.compareTensorToPIL(out_tensor, out_pil_img)
+
+            # 3) Test rotation + translation + scale + share
+            test_configs = [
+                (45, [5, 6], 1.0, [0.0, 0.0]),
+                (33, (5, -4), 1.0, [0.0, 0.0]),
+                (45, [5, 4], 1.2, [0.0, 0.0]),
+                (33, (4, 8), 2.0, [0.0, 0.0]),
+                (85, (10, -10), 0.7, [0.0, 0.0]),
+                (0, [0, 0], 1.0, [35.0, ]),
+                (25, [0, 0], 1.2, [0.0, 15.0]),
+                (45, [10, 0], 0.7, [2.0, 5.0]),
+                (45, [10, -10], 1.2, [4.0, 5.0]),
+            ]
+            for r in [0, ]:
+                for a, t, s, sh in test_configs:
+                    for fn in [F.affine, scripted_affine]:
+                        out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                        out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 5% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.05,
+                            msg="{}: {}\n{} vs \n{}".format(
                                (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            )
+                        )

 if __name__ == '__main__':
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 74c92d61bbd..c37b14f29b3 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -908,10 +908,14 @@ def affine(
         return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)

-    # we need to rescale translate by image size / 2 as its values can be between -1 and 1
-    translate = [2.0 * t / s for s, t in zip(img_size, translate)]
+    if img_size[0] == img_size[1]:
+        # we need to rescale translate by image size / 2 as its values can be between -1 and 1
+        translate_f = [2.0 * t / s for s, t in zip(img_size, translate)]
+    else:
+        # if rectangular image, we should not rescale translation part
+        translate_f = [1.0 * t for t in translate]

-    matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
+    matrix = _get_inverse_affine_matrix([0, 0], angle, translate_f, scale, shear)
     return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 357f23b88fc..520483bfa91 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -619,6 +619,22 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
     return img


+def _gen_affine_grid(
+    theta: Tensor, w: int, h: int, ow: int, oh: int,
+) -> Tensor:
+    d = 0.5
+    x = torch.arange(ow) + d - ow * 0.5
+    y = torch.arange(oh) + d - oh * 0.5
+
+    y, x = torch.meshgrid(y, x)
+    pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
+    output_grid = torch.matmul(pts, theta.t())
+
+    output_grid = output_grid / torch.tensor([0.5 * w, 0.5 * h])
+
+    return output_grid.unsqueeze(dim=0)
+
+
 def affine(
     img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
 ) -> Tensor:
@@ -651,7 +667,12 @@ def affine(
     theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
     shape = img.shape
-    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+    if shape[-2] == shape[-1]:
+        # here we need normalized translation part of theta
+        grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+    else:
+        # here we need denormalized translation part of theta
+        grid = _gen_affine_grid(theta[0, :, :], w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])

     # make image NCHW
     need_squeeze = False

From 4325f671b05df464fc5eb0b1dd3b01ded3440785 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 6 Aug 2020 11:23:05 +0200
Subject: [PATCH 08/11] Recoded _gen_affine_grid to optimized version ~ affine_grid - Fixes flake8

---
 test/test_functional_tensor.py              |  8 ++++++--
 torchvision/transforms/functional_tensor.py | 21 ++++++++++++---------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 7b95ae9af03..75bd47a7f9c 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -422,7 +422,9 @@ def test_affine(self):
                         else:
                             true_tensor = out_tensor

-                        out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        out_pil_img = F.affine(
+                            pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                        )
                         out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
@@ -442,7 +444,9 @@ def test_affine(self):
                 for a in test_configs:
                     for fn in [F.affine, scripted_affine]:
                         out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-                        out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        out_pil_img = F.affine(
+                            pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                        )
                         out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 520483bfa91..1eabb184812 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -622,17 +622,20 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
 def _gen_affine_grid(
     theta: Tensor, w: int, h: int, ow: int, oh: int,
 ) -> Tensor:
-    d = 0.5
-    x = torch.arange(ow) + d - ow * 0.5
-    y = torch.arange(oh) + d - oh * 0.5
-
-    y, x = torch.meshgrid(y, x)
-    pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
-    output_grid = torch.matmul(pts, theta.t())
+    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
+    # AffineGridGenerator.cpp#L18
+    # Difference with AffineGridGenerator is that:
+    # 1) we normalize grid values after applying theta
+    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate

-    output_grid = output_grid / torch.tensor([0.5 * w, 0.5 * h])
+    d = 0.5
+    base_grid = torch.empty(1, oh, ow, 3)
+    base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow))
+    base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1))
+    base_grid[..., 2].fill_(1)

-    return output_grid.unsqueeze(dim=0)
+    output_grid = base_grid.view(1, oh * ow, 3).bmm(theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h]))
+    return output_grid.view(1, oh, ow, 2)

From 0e2a3c7f751e9012bd602a82c0af915e92cd3924 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 6 Aug 2020 12:23:00 +0200
Subject: [PATCH 09/11] [WIP] Use _gen_affine_grid for affine and rotate

---
 torchvision/transforms/functional.py        |  8 +----
 torchvision/transforms/functional_tensor.py | 34 +++++----------------
 2 files changed, 8 insertions(+), 34 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index ae286cb5438..bbf019a4124 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -926,13 +926,7 @@ def affine(
         return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)

-    if img_size[0] == img_size[1]:
-        # we need to rescale translate by image size / 2 as its values can be between -1 and 1
-        translate_f = [2.0 * t / s for s, t in zip(img_size, translate)]
-    else:
-        # if rectangular image, we should not rescale translation part
-        translate_f = [1.0 * t for t in translate]
-
+    translate_f = [1.0 * t for t in translate]
     matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
     return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 12e6f01485f..00462c257ba 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -707,15 +707,7 @@ def affine(
     theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
     shape = img.shape
-
-    # grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
-    if shape[-2] == shape[-1]:
-        # here we need normalized translation part of theta
-        grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
-    else:
-        # here we need denormalized translation part of theta
-        grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
-
+    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
     mode = _interpolation_modes[resample]

     return _apply_grid_transform(img, grid, mode)
@@ -732,6 +724,7 @@ def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
         [1.0, -1.0, 1.0],
     ])
     # denormalize back to w, h:
+    theta = theta[0, ...]
     new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0
     min_vals, _ = new_pts.min(dim=0)
     max_vals, _ = new_pts.max(dim=0)
@@ -739,21 +732,6 @@ def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
     return int(size[0]), int(size[1])


-def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
-    if expand:
-        ow, oh = _compute_output_size(theta, w, h)
-    else:
-        ow, oh = w, h
-    d = 0.5  # if not align_corners
-
-    x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w)
-    y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h)
-    y, x = torch.meshgrid(y, x)
-    pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
-    output_grid = torch.matmul(pts, theta.t())
-    return output_grid.unsqueeze(dim=0)
-
-
 def rotate(
     img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None
 ) -> Tensor:
@@ -784,9 +762,11 @@ def rotate(

     _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)

-    theta = torch.tensor(matrix).reshape(2, 3)
-    shape = img.shape
-    grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand)
+    theta = torch.tensor(matrix).reshape(1, 2, 3)
+    w, h = img.shape[-1], img.shape[-2]
+
+    ow, oh = _compute_output_size(theta, w, h) if expand else (w, h)
+    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
     mode = _interpolation_modes[resample]

     return _apply_grid_transform(img, grid, mode)

From 781747c924c12780e267509bff351a92b88b9b87 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 6 Aug 2020 13:44:34 +0200
Subject: [PATCH 10/11] Fixed tests on square / rectangular images for affine and rotate ops

---
 test/test_functional_tensor.py              | 67 +++++++++++----------
 torchvision/transforms/functional.py        |  5 +-
 torchvision/transforms/functional_tensor.py | 53 +++++-----------
 3 files changed, 53 insertions(+), 72 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index e736300c11f..4ffa8cf280e 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -503,44 +503,47 @@ def test_affine(self):

     def test_rotate(self):
         # Tests on square image
-        tensor, pil_img = self._create_data(26, 26)
         scripted_rotate = torch.jit.script(F.rotate)

-        img_size = pil_img.size
-
-        centers = [
-            None,
-            (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
-            [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
-        ]
+        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:

-        for r in [0, ]:
-            for a in range(-180, 180, 23):
-                for e in [True, False]:
-                    for c in centers:
+            img_size = pil_img.size
+            centers = [
+                None,
+                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
+                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
+            ]

-                        out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
-                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                        for fn in [F.rotate, scripted_rotate]:
-                            out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
-
-                            self.assertEqual(
-                                out_tensor.shape,
-                                out_pil_tensor.shape,
-                                msg="{}: {} vs {}".format(
-                                    (r, a, e, c), out_tensor.shape, out_pil_tensor.shape
+            for r in [0, ]:
+                for a in range(-180, 180, 17):
+                    for e in [True, False]:
+                        for c in centers:
+
+                            out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
+                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                            for fn in [F.rotate, scripted_rotate]:
+                                out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
+
+                                self.assertEqual(
+                                    out_tensor.shape,
+                                    out_pil_tensor.shape,
+                                    msg="{}: {} vs {}".format(
+                                        (img_size, r, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                    )
                                 )
-                            )
-                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                            # Tolerance : less than 2% of different pixels
-                            self.assertLess(
-                                ratio_diff_pixels,
-                                0.02,
-                                msg="{}: {}\n{} vs \n{}".format(
-                                    (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                                # Tolerance : less than 2% of different pixels
+                                self.assertLess(
+                                    ratio_diff_pixels,
+                                    0.02,
+                                    msg="{}: {}\n{} vs \n{}".format(
+                                        (img_size, r, a, e, c),
+                                        ratio_diff_pixels,
+                                        out_tensor[0, :7, :7],
+                                        out_pil_tensor[0, :7, :7]
+                                    )
                                 )
-                            )

 if __name__ == '__main__':
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index ab952fd2242..6158ccb82da 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -880,8 +880,9 @@ def rotate(
     center_f = [0.0, 0.0]
     if center is not None:
         img_size = _get_image_size(img)
-        # Center is normalized to [-1, +1]
-        center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)]
+        # Center values should be in pixel coordinates but translated such (0, 0) corresponds to image center.
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 7440154ab4a..452e5d37ed8 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -712,49 +712,28 @@ def affine(
     return _apply_grid_transform(img, grid, mode)


-def _compute_output_size(theta: Tensor, w: int, h: int, center: Optional[List[int]] = None) -> Tuple[int, int]:
+def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:

-    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
-    # To compute extended output image size we should use denormalized theta
-    # where translation part (theta[0, 2] and theta[1, 2]) is computed using original center values in range [0, w]
-    # and [0, h]. Currently, theta[0, 2] and theta[1, 2] are normalized to [-1, 1] range
-
-    center_f = [w * 0.5, h * 0.5]
-    if center is not None:
-        center_f = [float(v) for v in center]
-
-    denorm_theta = theta.clone()
-    denorm_theta[0, 2] = center_f[0] * (1.0 - denorm_theta[0, 0]) - center_f[1] * denorm_theta[0, 1]
-    denorm_theta[1, 2] = center_f[0] * (1.0 - denorm_theta[1, 0]) - center_f[1] * denorm_theta[1, 1]
+    # Inspired of PIL implementation:
+    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

+    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+    pts = torch.tensor([
+        [-0.5 * w, -0.5 * h, 1.0],
+        [-0.5 * w, 0.5 * h, 1.0],
+        [0.5 * w, 0.5 * h, 1.0],
+        [0.5 * w, -0.5 * h, 1.0],
+    ])
+    new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
     min_vals, _ = new_pts.min(dim=0)
     max_vals, _ = new_pts.max(dim=0)
-    size = torch.ceil(max_vals) - torch.floor(min_vals)
-    return int(size[0]), int(size[1])
-
-
-def _expanded_affine_grid(
-    theta: Tensor, w: int, h: int, expand: bool = False, center: Optional[List[int]] = None
-) -> Tensor:
-    if expand:
-        ow, oh = _compute_output_size(theta, w, h, center)
-    else:
-        ow, oh = w, h
-    d = 0.5  # if not align_corners
-
-    x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w)
-    y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h)
-    y, x = torch.meshgrid(y, x)
-    pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
-    output_grid = torch.matmul(pts, theta.t())
-    return output_grid.unsqueeze(dim=0)
+
+    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
+    tol = 1e-4
+    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
+    cmin = torch.floor((min_vals / tol).trunc_() * tol)
+    size = cmax - cmin
+    return int(size[0]), int(size[1])


 def rotate(
@@ -787,10 +766,8 @@ def rotate(
     }

     _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
-
     theta = torch.tensor(matrix).reshape(1, 2, 3)
     w, h = img.shape[-1], img.shape[-2]
-
     ow, oh = _compute_output_size(theta, w, h) if expand else (w, h)
     grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
     mode = _interpolation_modes[resample]

     return _apply_grid_transform(img, grid, mode)

From c8621262a9025b7467ce9421a6884d13b0875a77 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 6 Aug 2020 13:48:08 +0200
Subject: [PATCH 11/11] Removed redefinition of F.rotate - due to bad merge

---
 torchvision/transforms/functional.py | 34 +---------------------------
 1 file changed, 1 insertion(+), 33 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 6158ccb82da..689137b44cb 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -756,38 +756,6 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
     return F_t.adjust_gamma(img, gamma, gain)


-def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
-    """Rotate the image by angle.
-
-
-    Args:
-        img (PIL Image): PIL Image to be rotated.
-        angle (float or int): In degrees degrees counter clockwise order.
-        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
-            An optional resampling filter. See `filters`_ for more information.
-            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
-        expand (bool, optional): Optional expansion flag.
-            If true, expands the output image to make it large enough to hold the entire rotated image.
-            If false or omitted, make the output image the same size as the input image.
-            Note that the expand flag assumes rotation around the center and no translation.
-        center (2-tuple, optional): Optional center of rotation.
-            Origin is the upper left corner.
-            Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
-            image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
-
-    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
-
-    """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    opts = _parse_fill(fill, img, '5.2.0')
-
-    return img.rotate(angle, resample, expand, center, **opts)
-
-
 def _get_inverse_affine_matrix(
         center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
 ) -> List[float]:
@@ -880,7 +848,7 @@ def rotate(
     center_f = [0.0, 0.0]
     if center is not None:
         img_size = _get_image_size(img)
-        # Center values should be in pixel coordinates but translated such (0, 0) corresponds to image center.
+        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
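
A note on the math the series keeps reusing: F.rotate ends up calling
_get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) to build the 6-value
inverse matrix that F_t.rotate consumes. As a reference for reviewers, here is a small standalone
sketch of what that inverse matrix reduces to in the pure-rotation case. It is an illustration
only: inverse_rotation_matrix is a hypothetical helper, not a function from these patches, and it
assumes the same row-major [m00, m01, m02, m10, m11, m12] layout used above.

    import math

    def inverse_rotation_matrix(angle_deg, cx, cy):
        # Inverse of T(center) @ R(angle) @ T(-center) for a pure rotation:
        # the linear part is the inverse (transposed) rotation matrix, and the
        # translation part restores the chosen center of rotation.
        a = math.radians(angle_deg)
        cos_a, sin_a = math.cos(a), math.sin(a)
        m = [cos_a, sin_a, 0.0, -sin_a, cos_a, 0.0]
        # m02 = cx - m00 * cx - m01 * cy, m12 = cy - m10 * cx - m11 * cy,
        # so that the center point maps to itself.
        m[2] = cx - m[0] * cx - m[1] * cy
        m[5] = cy - m[3] * cx - m[4] * cy
        return m  # [m00, m01, m02, m10, m11, m12]

Applying this matrix to (cx, cy, 1) returns (cx, cy), which is exactly the "keeps the center
invariant" property the grid-based implementation relies on.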
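The final _compute_output_size is the PIL-style trick: push the four corner points, taken
relative to the image center, through the transform and take the bounding box of the results.
For a pure rotation with no translation this collapses to a closed form, sketched below under
that assumption (expanded_size is an illustrative name, not part of the patch):

    import math

    def expanded_size(w, h, angle_deg):
        # Bounding box of a w x h rectangle rotated about its center:
        # project both half-extents onto the rotated axes and sum them.
        a = math.radians(angle_deg)
        cos_a, sin_a = abs(math.cos(a)), abs(math.sin(a))
        ow = int(math.ceil(w * cos_a + h * sin_a))
        oh = int(math.ceil(w * sin_a + h * cos_a))
        return ow, oh

For example, a 26 x 26 image rotated by 45 degrees needs ceil(26 * sqrt(2)) = 37 pixels per side,
which matches what expand=True produces for the square test images above.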
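Finally, a minimal end-to-end parity check in the spirit of test_rotate, assuming a torchvision
build that already contains these patches. The random image construction and the 2% threshold
mirror the tests above; everything else is illustrative:

    import numpy as np
    import torch
    from PIL import Image
    import torchvision.transforms.functional as F

    # Same random RGB image as a PIL Image and as a CHW uint8 tensor.
    np_img = np.random.randint(0, 256, (26, 26, 3), dtype=np.uint8)
    pil_img = Image.fromarray(np_img)
    tensor = torch.from_numpy(np_img.transpose((2, 0, 1)))

    out_pil = F.rotate(pil_img, angle=45, expand=True)
    out_tensor = F.rotate(tensor, angle=45, expand=True)

    out_pil_tensor = torch.from_numpy(np.array(out_pil).transpose((2, 0, 1)))
    assert out_tensor.shape == out_pil_tensor.shape
    ratio_diff = (out_tensor != out_pil_tensor).float().mean().item()
    print(ratio_diff)  # expected to stay below the ~2% tolerance used in the tests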