Skip to content

Commit a129b6b

Browse files
Philip Meierfmassa
authored andcommitted
Adds optional fill colour to rotate (#1280)
* Adds optional fill colour to rotate * bug fix
1 parent a91fe72 commit a129b6b

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

torchvision/transforms/functional.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def adjust_gamma(img, gamma, gain=1):
686686
return img
687687

688688

689-
def rotate(img, angle, resample=False, expand=False, center=None):
689+
def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
690690
"""Rotate the image by angle.
691691
692692
@@ -703,6 +703,8 @@ def rotate(img, angle, resample=False, expand=False, center=None):
703703
center (2-tuple, optional): Optional center of rotation.
704704
Origin is the upper left corner.
705705
Default is the center of the image.
706+
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
707+
If int, it is used for all channels respectively.
706708
707709
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
708710
@@ -711,7 +713,10 @@ def rotate(img, angle, resample=False, expand=False, center=None):
711713
if not _is_pil_image(img):
712714
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
713715

714-
return img.rotate(angle, resample, expand, center)
716+
if isinstance(fill, int):
717+
fill = tuple([fill] * 3)
718+
719+
return img.rotate(angle, resample, expand, center, fillcolor=fill)
715720

716721

717722
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):

torchvision/transforms/transforms.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -946,12 +946,14 @@ class RandomRotation(object):
946946
center (2-tuple, optional): Optional center of rotation.
947947
Origin is the upper left corner.
948948
Default is the center of the image.
949+
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
950+
If int, it is used for all channels respectively.
949951
950952
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
951953
952954
"""
953955

954-
def __init__(self, degrees, resample=False, expand=False, center=None):
956+
def __init__(self, degrees, resample=False, expand=False, center=None, fill=0):
955957
if isinstance(degrees, numbers.Number):
956958
if degrees < 0:
957959
raise ValueError("If degrees is a single number, it must be positive.")
@@ -964,6 +966,7 @@ def __init__(self, degrees, resample=False, expand=False, center=None):
964966
self.resample = resample
965967
self.expand = expand
966968
self.center = center
969+
self.fill = fill
967970

968971
@staticmethod
969972
def get_params(degrees):
@@ -987,7 +990,7 @@ def __call__(self, img):
987990

988991
angle = self.get_params(self.degrees)
989992

990-
return F.rotate(img, angle, self.resample, self.expand, self.center)
993+
return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
991994

992995
def __repr__(self):
993996
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)

0 commit comments

Comments
 (0)