Skip to content

Commit 4e6faa7

Browse files
anguelosfmassa
authored andcommitted
A minor change for transforms.RandomCrop (#462)
* Made transorms.RandomCrop tolerate images smaller than the given size. * Extended the tescase for transforms.RandomCrop * Made the testcase test for owidth=width+1 * Fixed the one pixel pading and the testcase * Fixed minor lintint errors. flake8 passes.
1 parent 3b45057 commit 4e6faa7

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

test/test_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ def test_random_crop(self):
205205
assert result.size(2) == width
206206
assert np.allclose(img.numpy(), result.numpy())
207207

208+
result = transforms.Compose([
209+
transforms.ToPILImage(),
210+
transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
211+
transforms.ToTensor(),
212+
])(img)
213+
assert result.size(1) == height + 1
214+
assert result.size(2) == width + 1
215+
208216
def test_pad(self):
209217
height = random.randint(10, 32) * 2
210218
width = random.randint(10, 32) * 2

torchvision/transforms/transforms.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,14 +368,17 @@ class RandomCrop(object):
368368
of the image. Default is 0, i.e no padding. If a sequence of length
369369
4 is provided, it is used to pad left, top, right, bottom borders
370370
respectively.
371+
pad_if_needed (boolean): It will pad the image if smaller than the
372+
desired size to avoid raising an exception.
371373
"""
372374

373-
def __init__(self, size, padding=0):
375+
def __init__(self, size, padding=0, pad_if_needed=False):
374376
if isinstance(size, numbers.Number):
375377
self.size = (int(size), int(size))
376378
else:
377379
self.size = size
378380
self.padding = padding
381+
self.pad_if_needed = pad_if_needed
379382

380383
@staticmethod
381384
def get_params(img, output_size):
@@ -408,6 +411,13 @@ def __call__(self, img):
408411
if self.padding > 0:
409412
img = F.pad(img, self.padding)
410413

414+
# pad the width if needed
415+
if self.pad_if_needed and img.size[0] < self.size[1]:
416+
img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
417+
# pad the height if needed
418+
if self.pad_if_needed and img.size[1] < self.size[0]:
419+
img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
420+
411421
i, j, h, w = self.get_params(img, self.size)
412422

413423
return F.crop(img, i, j, h, w)

0 commit comments

Comments
 (0)