diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 0532f171471..4ffa8cf280e 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -385,134 +385,165 @@ def test_resized_crop(self):
         )
 
     def test_affine(self):
-        # Tests on square image
-        tensor, pil_img = self._create_data(26, 26)
-
+        # Tests on square and rectangular images
         scripted_affine = torch.jit.script(F.affine)
-        # 1) identity map
-        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
-        out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
-        # 2) Test rotation
-        test_configs = [
-            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
-            (45, None),
-            (30, None),
-            (-30, None),
-            (-45, None),
-            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
-            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
-        ]
-        for a, true_tensor in test_configs:
-            for fn in [F.affine, scripted_affine]:
-                out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-                if true_tensor is not None:
-                    self.assertTrue(
-                        true_tensor.equal(out_tensor),
-                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
-                    )
-                else:
-                    true_tensor = out_tensor
-
-                out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
-                num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
-                ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
-                # Tolerance : less than 6% of different pixels
-                self.assertLess(
-                    ratio_diff_pixels,
-                    0.06,
-                    msg="{}\n{} vs \n{}".format(
-                        ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
-                    )
-                )
-        # 3) Test translation
-        test_configs = [
-            [10, 12], (-12, -13)
-        ]
-        for t in test_configs:
-            for fn in [F.affine, scripted_affine]:
-                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
-                out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
-                self.compareTensorToPIL(out_tensor, out_pil_img)
-
-        # 3) Test rotation + translation + scale + share
-        test_configs = [
-            (45, [5, 6], 1.0, [0.0, 0.0]),
-            (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [-5, 4], 1.2, [0.0, 0.0]),
-            (33, (-4, -8), 2.0, [0.0, 0.0]),
-            (85, (10, -10), 0.7, [0.0, 0.0]),
-            (0, [0, 0], 1.0, [35.0, ]),
-            (25, [0, 0], 1.2, [0.0, 15.0]),
-            (45, [-10, 0], 0.7, [2.0, 5.0]),
-            (45, [-10, -10], 1.2, [4.0, 5.0]),
-        ]
-        for r in [0, ]:
-            for a, t, s, sh in test_configs:
-                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
 
+        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
+
+            # 1) identity map
+            out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+            self.assertTrue(
+                tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+            )
+            out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+            self.assertTrue(
+                tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+            )
+
+            if pil_img.size[0] == pil_img.size[1]:
+                # 2) Test rotation
+                test_configs = [
+                    (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
+                    (45, None),
+                    (30, None),
+                    (-30, None),
+                    (-45, None),
+                    (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
+                    (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
+                ]
+                for a, true_tensor in test_configs:
+                    for fn in [F.affine, scripted_affine]:
+                        out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        if true_tensor is not None:
+                            self.assertTrue(
+                                true_tensor.equal(out_tensor),
+                                msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                            )
+                        else:
+                            true_tensor = out_tensor
+
+                        out_pil_img = F.affine(
+                            pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                        )
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                        num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
+                        # Tolerance: less than 6% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.06,
+                            msg="{}\n{} vs \n{}".format(
+                                ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            )
+                        )
+            else:
+                test_configs = [
+                    90, 45, 15, -30, -60, -120
+                ]
+                for a in test_configs:
+                    for fn in [F.affine, scripted_affine]:
+                        out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                        out_pil_img = F.affine(
+                            pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
+                        )
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance: less than 3% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.03,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            )
+                        )
+
+            # 3) Test translation
+            test_configs = [
+                [10, 12], (-12, -13)
+            ]
+            for t in test_configs:
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
-                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
-                    self.assertLess(
-                        ratio_diff_pixels,
-                        0.05,
-                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                    out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                    out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                    self.compareTensorToPIL(out_tensor, out_pil_img)
+
+            # 4) Test rotation + translation + scale + shear
+            test_configs = [
+                (45, [5, 6], 1.0, [0.0, 0.0]),
+                (33, (5, -4), 1.0, [0.0, 0.0]),
+                (45, [-5, 4], 1.2, [0.0, 0.0]),
+                (33, (-4, -8), 2.0, [0.0, 0.0]),
+                (85, (10, -10), 0.7, [0.0, 0.0]),
+                (0, [0, 0], 1.0, [35.0, ]),
+                (-25, [0, 0], 1.2, [0.0, 15.0]),
+                (-45, [-10, 0], 0.7, [2.0, 5.0]),
+                (-45, [-10, -10], 1.2, [4.0, 5.0]),
+                (-90, [0, 0], 1.0, [0.0, 0.0]),
+            ]
+            for r in [0, ]:
+                for a, t, s, sh in test_configs:
+                    out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                    for fn in [F.affine, scripted_affine]:
+                        out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance: less than 5% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.05,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            )
                         )
-                    )
 
     def test_rotate(self):
         # Tests on square image
-        tensor, pil_img = self._create_data(26, 26)
         scripted_rotate = torch.jit.script(F.rotate)
-        img_size = pil_img.size
-
-        centers = [
-            None,
-            (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
-            [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
-        ]
-
-        for r in [0, ]:
-            for a in range(-120, 120, 23):
-                for e in [True, False]:
-                    for c in centers:
-
-                        out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
-                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                        for fn in [F.rotate, scripted_rotate]:
-                            out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
-
-                            self.assertEqual(
-                                out_tensor.shape,
-                                out_pil_tensor.shape,
-                                msg="{}: {} vs {}".format(
-                                    (r, a, e, c), out_tensor.shape, out_pil_tensor.shape
+        for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
+
+            img_size = pil_img.size
+            centers = [
+                None,
+                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
+                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
+            ]
+
+            for r in [0, ]:
+                for a in range(-180, 180, 17):
+                    for e in [True, False]:
+                        for c in centers:
+
+                            out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
+                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                            for fn in [F.rotate, scripted_rotate]:
+                                out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
+
+                                self.assertEqual(
+                                    out_tensor.shape,
+                                    out_pil_tensor.shape,
+                                    msg="{}: {} vs {}".format(
+                                        (img_size, r, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                    )
                                 )
-                            )
-                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                            # Tolerance : less than 2% of different pixels
-                            self.assertLess(
-                                ratio_diff_pixels,
-                                0.02,
-                                msg="{}: {}\n{} vs \n{}".format(
-                                    (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                                # Tolerance: less than 2% of different pixels
+                                self.assertLess(
+                                    ratio_diff_pixels,
+                                    0.02,
+                                    msg="{}: {}\n{} vs \n{}".format(
+                                        (img_size, r, a, e, c),
+                                        ratio_diff_pixels,
+                                        out_tensor[0, :7, :7],
+                                        out_pil_tensor[0, :7, :7]
+                                    )
                                 )
-                            )
 
 
 if __name__ == '__main__':
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 374c131cf44..689137b44cb 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -848,8 +848,9 @@ def rotate(
     center_f = [0.0, 0.0]
     if center is not None:
         img_size = _get_image_size(img)
-        # Center is normalized to [-1, +1]
-        center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)]
+        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to the image center.
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
@@ -926,10 +927,8 @@ def affine(
         return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
 
-    # we need to rescale translate by image size / 2 as its values can be between -1 and 1
-    translate = [2.0 * t / s for s, t in zip(img_size, translate)]
-
-    matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
+    translate_f = [1.0 * t for t in translate]
+    matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
     return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 3641d722730..452e5d37ed8 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -663,6 +663,25 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
     return img
 
 
+def _gen_affine_grid(
+    theta: Tensor, w: int, h: int, ow: int, oh: int,
+) -> Tensor:
+    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
+    # AffineGridGenerator.cpp#L18
+    # Differences from AffineGridGenerator are that:
+    # 1) we normalize grid values after applying theta
+    # 2) we can normalize by another image size, such that it covers the "expand" option as in PIL.Image.rotate
+
+    d = 0.5
+    base_grid = torch.empty(1, oh, ow, 3)
+    base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow))
+    base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1))
+    base_grid[..., 2].fill_(1)
+
+    output_grid = base_grid.view(1, oh * ow, 3).bmm(theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h]))
+    return output_grid.view(1, oh, ow, 2)
+
+
 def affine(
     img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
 ) -> Tensor:
@@ -688,44 +707,33 @@ def affine(
     theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
     shape = img.shape
-    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
     mode = _interpolation_modes[resample]
-
     return _apply_grid_transform(img, grid, mode)
 
 
 def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
+    # Inspired by the PIL implementation:
+    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
+    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
-    # we need to normalize coordinates according to
-    # [0, s] is mapped [-1, +1] as theta translation parameters are normalized like that
     pts = torch.tensor([
-        [-1.0, -1.0, 1.0],
-        [-1.0, 1.0, 1.0],
-        [1.0, 1.0, 1.0],
-        [1.0, -1.0, 1.0],
+        [-0.5 * w, -0.5 * h, 1.0],
+        [-0.5 * w, 0.5 * h, 1.0],
+        [0.5 * w, 0.5 * h, 1.0],
+        [0.5 * w, -0.5 * h, 1.0],
     ])
-    # denormalize back to w, h:
-    new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0
+    new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
     min_vals, _ = new_pts.min(dim=0)
     max_vals, _ = new_pts.max(dim=0)
-    size = torch.ceil(max_vals) - torch.floor(min_vals)
-    return int(size[0]), int(size[1])
-
-
-def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
-    if expand:
-        ow, oh = _compute_output_size(theta, w, h)
-    else:
-        ow, oh = w, h
-    d = 0.5  # if not align_corners
-    x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w)
-    y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h)
-    y, x = torch.meshgrid(y, x)
-    pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
-    output_grid = torch.matmul(pts, theta.t())
-    return output_grid.unsqueeze(dim=0)
+
+    # Truncate precision to 1e-4 to avoid values like 1e-15 being ceiled up to 1.0
+    tol = 1e-4
+    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
+    cmin = torch.floor((min_vals / tol).trunc_() * tol)
+    size = cmax - cmin
+    return int(size[0]), int(size[1])
 
 
 def rotate(
@@ -736,6 +744,7 @@
     Args:
         img (Tensor): image to be rotated.
        matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
+            Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates.
        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: bilinear(=2).
        expand (bool, optional): Optional expansion flag.
@@ -757,10 +766,10 @@
     }
 
     _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
-
-    theta = torch.tensor(matrix).reshape(2, 3)
-    shape = img.shape
-    grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand)
+    theta = torch.tensor(matrix).reshape(1, 2, 3)
+    w, h = img.shape[-1], img.shape[-2]
+    ow, oh = _compute_output_size(theta, w, h) if expand else (w, h)
+    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
     mode = _interpolation_modes[resample]
 
     return _apply_grid_transform(img, grid, mode)
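As a quick standalone check of the behavior this patch targets, the sketch below mirrors the updated test_rotate above: with translation and center now handled in pixel coordinates, the tensor and PIL paths of F.rotate should produce identically shaped outputs even for expand=True on a rectangular image. This is a minimal sketch assuming a torchvision build that includes this patch; the 32x26 size is borrowed from the tests, the angles are arbitrary, and only output shapes are compared (the tests additionally bound the per-pixel difference ratio).

import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

# Random 3-channel rectangular image, as an HWC uint8 array for PIL
# and a CHW uint8 tensor for the tensor path.
arr = np.random.randint(0, 256, size=(32, 26, 3), dtype=np.uint8)
pil_img = Image.fromarray(arr)
tensor = torch.from_numpy(arr.transpose((2, 0, 1)).copy())

for angle in (17, 45, -60):
    out_tensor = F.rotate(tensor, angle=angle, resample=0, expand=True)
    out_pil = F.rotate(pil_img, angle=angle, resample=0, expand=True)
    out_pil_tensor = torch.from_numpy(np.array(out_pil).transpose((2, 0, 1)))
    # _compute_output_size should reproduce PIL's expanded canvas exactly.
    assert out_tensor.shape == out_pil_tensor.shape, (angle, out_tensor.shape, out_pil_tensor.shape)

Keeping translate and center in pixel units and normalizing only once inside _gen_affine_grid appears to be what removes the earlier double normalization (once in _get_inverse_affine_matrix, once in the grid generator), which did not hold for non-square inputs.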