@@ -627,66 +627,77 @@ def __repr__(self):
627
627
return self .__class__ .__name__ + '(p={})' .format (self .p )
628
628
629
629
630
class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
            ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively. Default is 0.
            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
            input. Fill value for the area outside the transform in the output image is always 0.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0):
        # Subclasses torch.nn.Module (not object) so the transform is scriptable
        # and composable with other nn.Module-based transforms.
        super().__init__()
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """
        # torch.rand (not random.random) keeps randomness under torch's RNG,
        # which is required for torchscript support and torch.manual_seed control.
        if torch.rand(1) < self.p:
            # F._get_image_size handles both PIL Images and Tensors.
            width, height = F._get_image_size(img)
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        # Each corner is jittered inward by at most distortion_scale * half-extent.
        # torch.randint's upper bound is exclusive, hence the "+ 1" on the low-corner
        # ranges (to include the max offset) and the bare "width"/"height" on the
        # high-corner ranges (to include coordinate width-1 / height-1).
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        # Lists (not tuples) of [x, y] pairs: torchscript requires a uniform
        # List[List[int]] type for both point sets.
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints
692
703
0 commit comments