Skip to content

Commit 7666252

Browse files
authored
Unified inputs for F.rotate (#2495)
* Added code for F_t.rotate with test - updated F.affine tests * Rotate test tolerance to 2% * Fixes failing test * Optimized _expanded_affine_grid with a single matmul op * Recoded _compute_output_size
1 parent 23295fb commit 7666252

File tree

5 files changed

+252
-72
lines changed

5 files changed

+252
-72
lines changed

test/test_functional_tensor.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def test_affine(self):
435435
)
436436
# 3) Test translation
437437
test_configs = [
438-
[10, 12], (12, 13)
438+
[10, 12], (-12, -13)
439439
]
440440
for t in test_configs:
441441
for fn in [F.affine, scripted_affine]:
@@ -447,21 +447,21 @@ def test_affine(self):
447447
test_configs = [
448448
(45, [5, 6], 1.0, [0.0, 0.0]),
449449
(33, (5, -4), 1.0, [0.0, 0.0]),
450-
(45, [5, 4], 1.2, [0.0, 0.0]),
451-
(33, (4, 8), 2.0, [0.0, 0.0]),
450+
(45, [-5, 4], 1.2, [0.0, 0.0]),
451+
(33, (-4, -8), 2.0, [0.0, 0.0]),
452452
(85, (10, -10), 0.7, [0.0, 0.0]),
453453
(0, [0, 0], 1.0, [35.0, ]),
454454
(25, [0, 0], 1.2, [0.0, 15.0]),
455-
(45, [10, 0], 0.7, [2.0, 5.0]),
456-
(45, [10, -10], 1.2, [4.0, 5.0]),
455+
(45, [-10, 0], 0.7, [2.0, 5.0]),
456+
(45, [-10, -10], 1.2, [4.0, 5.0]),
457457
]
458458
for r in [0, ]:
459459
for a, t, s, sh in test_configs:
460+
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
461+
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
462+
460463
for fn in [F.affine, scripted_affine]:
461464
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
462-
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
463-
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
464-
465465
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
466466
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
467467
# Tolerance : less than 5% of different pixels
@@ -473,6 +473,47 @@ def test_affine(self):
473473
)
474474
)
475475

476+
def test_rotate(self):
477+
# Tests on square image
478+
tensor, pil_img = self._create_data(26, 26)
479+
scripted_rotate = torch.jit.script(F.rotate)
480+
481+
img_size = pil_img.size
482+
483+
centers = [
484+
None,
485+
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
486+
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
487+
]
488+
489+
for r in [0, ]:
490+
for a in range(-120, 120, 23):
491+
for e in [True, False]:
492+
for c in centers:
493+
494+
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
495+
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
496+
for fn in [F.rotate, scripted_rotate]:
497+
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
498+
499+
self.assertEqual(
500+
out_tensor.shape,
501+
out_pil_tensor.shape,
502+
msg="{}: {} vs {}".format(
503+
(r, a, e, c), out_tensor.shape, out_pil_tensor.shape
504+
)
505+
)
506+
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
507+
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
508+
# Tolerance : less than 2% of different pixels
509+
self.assertLess(
510+
ratio_diff_pixels,
511+
0.02,
512+
msg="{}: {}\n{} vs \n{}".format(
513+
(r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
514+
)
515+
)
516+
476517

477518
if __name__ == '__main__':
478519
unittest.main()

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,7 @@ def test_rotate(self):
12661266
x = np.zeros((100, 100, 3), dtype=np.uint8)
12671267
x[40, 40] = [255, 255, 255]
12681268

1269-
with self.assertRaises(TypeError):
1269+
with self.assertRaisesRegex(TypeError, r"img should be PIL Image"):
12701270
F.rotate(x, 10)
12711271

12721272
img = F.to_pil_image(x)

torchvision/transforms/functional.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -756,40 +756,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
756756
return F_t.adjust_gamma(img, gamma, gain)
757757

758758

759-
def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
760-
"""Rotate the image by angle.
761-
762-
763-
Args:
764-
img (PIL Image): PIL Image to be rotated.
765-
angle (float or int): In degrees degrees counter clockwise order.
766-
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
767-
An optional resampling filter. See `filters`_ for more information.
768-
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
769-
expand (bool, optional): Optional expansion flag.
770-
If true, expands the output image to make it large enough to hold the entire rotated image.
771-
If false or omitted, make the output image the same size as the input image.
772-
Note that the expand flag assumes rotation around the center and no translation.
773-
center (2-tuple, optional): Optional center of rotation.
774-
Origin is the upper left corner.
775-
Default is the center of the image.
776-
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
777-
image. If int or float, the value is used for all bands respectively.
778-
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
779-
780-
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
781-
782-
"""
783-
if not F_pil._is_pil_image(img):
784-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
785-
786-
opts = _parse_fill(fill, img, '5.2.0')
787-
788-
return img.rotate(angle, resample, expand, center, **opts)
789-
790-
791759
def _get_inverse_affine_matrix(
792-
center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
760+
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
793761
) -> List[float]:
794762
# Helper method to compute inverse matrix for affine transformation
795763

@@ -838,6 +806,56 @@ def _get_inverse_affine_matrix(
838806
return matrix
839807

840808

809+
def rotate(
810+
img: Tensor, angle: float, resample: int = 0, expand: bool = False,
811+
center: Optional[List[int]] = None, fill: Optional[int] = None
812+
) -> Tensor:
813+
"""Rotate the image by angle.
814+
The image can be a PIL Image or a Tensor, in which case it is expected
815+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
816+
817+
Args:
818+
img (PIL Image or Tensor): image to be rotated.
819+
angle (float or int): rotation angle value in degrees, counter-clockwise.
820+
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
821+
An optional resampling filter. See `filters`_ for more information.
822+
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
823+
expand (bool, optional): Optional expansion flag.
824+
If true, expands the output image to make it large enough to hold the entire rotated image.
825+
If false or omitted, make the output image the same size as the input image.
826+
Note that the expand flag assumes rotation around the center and no translation.
827+
center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
828+
Default is the center of the image.
829+
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
830+
image. If int or float, the value is used for all bands respectively.
831+
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
832+
833+
Returns:
834+
PIL Image or Tensor: Rotated image.
835+
836+
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
837+
838+
"""
839+
if not isinstance(angle, (int, float)):
840+
raise TypeError("Argument angle should be int or float")
841+
842+
if center is not None and not isinstance(center, (list, tuple)):
843+
raise TypeError("Argument center should be a sequence")
844+
845+
if not isinstance(img, torch.Tensor):
846+
return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill)
847+
848+
center_f = [0.0, 0.0]
849+
if center is not None:
850+
img_size = _get_image_size(img)
851+
# Center is normalized to [-1, +1]
852+
center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)]
853+
# due to current incoherence of rotation angle direction between affine and rotate implementations
854+
# we need to set -angle.
855+
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
856+
return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill)
857+
858+
841859
def affine(
842860
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
843861
resample: int = 0, fillcolor: Optional[int] = None
@@ -847,7 +865,7 @@ def affine(
847865
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
848866
849867
Args:
850-
img (PIL Image or Tensor): image to be rotated.
868+
img (PIL Image or Tensor): image to transform.
851869
angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
852870
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
853871
scale (float): overall scale
@@ -911,7 +929,7 @@ def affine(
911929
# we need to rescale translate by image size / 2 as its values can be between -1 and 1
912930
translate = [2.0 * t / s for s, t in zip(img_size, translate)]
913931

914-
matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
932+
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
915933
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
916934

917935

torchvision/transforms/functional_pil.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,37 @@ def affine(img, matrix, resample=0, fillcolor=None):
422422
output_size = img.size
423423
opts = _parse_fill(fillcolor, img, '5.0.0')
424424
return img.transform(output_size, Image.AFFINE, matrix, resample, **opts)
425+
426+
427+
@torch.jit.unused
428+
def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
429+
"""Rotate PIL image by angle.
430+
431+
Args:
432+
img (PIL Image): image to be rotated.
433+
angle (float or int): rotation angle value in degrees, counter-clockwise.
434+
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
435+
An optional resampling filter. See `filters`_ for more information.
436+
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
437+
expand (bool, optional): Optional expansion flag.
438+
If true, expands the output image to make it large enough to hold the entire rotated image.
439+
If false or omitted, make the output image the same size as the input image.
440+
Note that the expand flag assumes rotation around the center and no translation.
441+
center (2-tuple, optional): Optional center of rotation.
442+
Origin is the upper left corner.
443+
Default is the center of the image.
444+
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
445+
image. If int or float, the value is used for all bands respectively.
446+
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
447+
448+
Returns:
449+
PIL Image: Rotated image.
450+
451+
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
452+
453+
"""
454+
if not _is_pil_image(img):
455+
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
456+
457+
opts = _parse_fill(fill, img, '5.2.0')
458+
return img.rotate(angle, resample, expand, center, **opts)

0 commit comments

Comments
 (0)