Commit 5209e10

vfdev-5 authored and bryant1410 committed

[BC-breaking] Unified input for RandomPerspective (pytorch#2561)

* Unified input for RandomPerspective
* Updated docs
* Fixed failing test and bug with torch.randint
* Update test_functional_tensor.py

1 parent cd145b1 · commit 5209e10

3 files changed: +62 −34 lines changed
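Before the per-file diffs, a minimal usage sketch of the headline change (not part of the commit itself; the sizes and variable names below are illustrative): after this commit the same transform instance accepts both PIL Images and Tensors.

import torch
from PIL import Image
import torchvision.transforms as T

transform = T.RandomPerspective(distortion_scale=0.5, p=1.0)

pil_img = Image.new("RGB", (56, 44))                                     # any PIL image
tensor_img = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)  # [..., H, W] tensor

out_pil = transform(pil_img)        # PIL Image in, PIL Image out
out_tensor = transform(tensor_img)  # Tensor in, Tensor out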

test/test_functional_tensor.py

Lines changed: 2 additions & 2 deletions

@@ -573,10 +573,10 @@ def test_perspective(self):

             num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
             ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-            # Tolerance : less than 3% of different pixels
+            # Tolerance : less than 5% of different pixels
             self.assertLess(
                 ratio_diff_pixels,
-                0.03,
+                0.05,
                 msg="{}: {}\n{} vs \n{}".format(
                     (r, spoints, epoints),
                     ratio_diff_pixels,
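The relaxed tolerance accounts for the tensor and PIL backends disagreeing on a small fraction of interpolated pixels. A self-contained sketch of the mismatch-ratio metric used above (the tensors here are hypothetical stand-ins, not from the test):

import torch

a = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
b = a.clone()
b[:, :4, :4] += 1  # perturb a 4x4 corner in every channel

num_diff = (a != b).sum().item() / 3.0          # mismatch count averaged over the 3 channels
ratio = num_diff / (a.shape[-1] * a.shape[-2])  # normalized by H * W
print(ratio)  # 16 / (44 * 56) ≈ 0.0065, well under the 5% tolerance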

test/test_transforms_tensor.py

Lines changed: 17 additions & 0 deletions

@@ -301,6 +301,23 @@ def test_random_rotate(self):
         out2 = s_transform(tensor)
         self.assertTrue(out1.equal(out2))

+    def test_random_perspective(self):
+        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
+
+        for distortion_scale in np.linspace(0.1, 1.0, num=20):
+            for interpolation in [NEAREST, BILINEAR]:
+                transform = T.RandomPerspective(
+                    distortion_scale=distortion_scale,
+                    interpolation=interpolation
+                )
+                s_transform = torch.jit.script(transform)
+
+                torch.manual_seed(12)
+                out1 = transform(tensor)
+                torch.manual_seed(12)
+                out2 = s_transform(tensor)
+                self.assertTrue(out1.equal(out2))
+

 if __name__ == '__main__':
     unittest.main()
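The new test distills to the following standalone check (a sketch, not part of the commit): because RandomPerspective is now a torch.nn.Module with a forward method, a type-annotated get_params, and torch-based RNG instead of the random module, it can be compiled with torch.jit.script, and the scripted module matches eager execution under the same seed.

import torch
import torchvision.transforms as T

transform = T.RandomPerspective(distortion_scale=0.5)
s_transform = torch.jit.script(transform)  # would fail before this commit

tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
assert out1.equal(out2)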

torchvision/transforms/transforms.py

Lines changed: 43 additions & 32 deletions

@@ -627,66 +627,77 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)


-class RandomPerspective(object):
-    """Performs Perspective transformation of the given PIL Image randomly with a given probability.
+class RandomPerspective(torch.nn.Module):
+    """Performs a random perspective transformation of the given image with a given probability.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

     Args:
-        interpolation : Default- Image.BICUBIC
-
-        p (float): probability of the image being perspectively transformed. Default value is 0.5
-
-        distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
+        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
+            Default is 0.5.
+        p (float): probability of the image being transformed. Default is 0.5.
+        interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
+            ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors.
+        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+            image. If int or float, the value is used for all bands respectively. Default is 0.
+            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
+            input. Fill value for the area outside the transform in the output image is always 0.

-        fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
-            If int, it is used for all channels respectively. Default value is 0.
     """

-    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
+    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0):
+        super().__init__()
         self.p = p
         self.interpolation = interpolation
         self.distortion_scale = distortion_scale
         self.fill = fill

-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be Perspectively transformed.
+            img (PIL Image or Tensor): Image to be Perspectively transformed.

         Returns:
-            PIL Image: Random perspectivley transformed image.
+            PIL Image or Tensor: Randomly transformed image.
         """
-        if not F._is_pil_image(img):
-            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-        if random.random() < self.p:
-            width, height = img.size
+        if torch.rand(1) < self.p:
+            width, height = F._get_image_size(img)
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
             return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
         return img

     @staticmethod
-    def get_params(width, height, distortion_scale):
+    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
         """Get parameters for ``perspective`` for a random perspective transform.

         Args:
-            width : width of the image.
-            height : height of the image.
+            width (int): width of the image.
+            height (int): height of the image.
+            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

         Returns:
             List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
             List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
         """
-        half_height = int(height / 2)
-        half_width = int(width / 2)
-        topleft = (random.randint(0, int(distortion_scale * half_width)),
-                   random.randint(0, int(distortion_scale * half_height)))
-        topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
-                    random.randint(0, int(distortion_scale * half_height)))
-        botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
-                    random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
-        botleft = (random.randint(0, int(distortion_scale * half_width)),
-                   random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
-        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
+        half_height = height // 2
+        half_width = width // 2
+        topleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
+        ]
+        topright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
+        ]
+        botright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
+        ]
+        botleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
+        ]
+        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
         endpoints = [topleft, topright, botright, botleft]
         return startpoints, endpoints
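A quick illustration of the reworked get_params (the printed endpoint values are illustrative; they are random): each destination corner is sampled within a distortion_scale-sized margin of the matching source corner, and the resulting point lists feed directly into F.perspective, as forward does.

import torchvision.transforms as T

startpoints, endpoints = T.RandomPerspective.get_params(width=56, height=44, distortion_scale=0.5)
print(startpoints)  # always [[0, 0], [55, 0], [55, 43], [0, 43]]
print(endpoints)    # e.g. [[9, 3], [48, 10], [52, 38], [2, 36]]

# With distortion_scale=0.0 every randint range collapses to a single value,
# so the transform degenerates to the identity mapping.
sp, ep = T.RandomPerspective.get_params(56, 44, 0.0)
assert sp == ep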
