Commit 8cd15cb

Add FiveCrop and TenCrop transforms (#261)
add FiveCrop and TenCrop
1 parent a5b75c8 commit 8cd15cb

File tree

2 files changed: 129 additions & 0 deletions


test/test_transforms.py

Lines changed: 60 additions & 0 deletions
@@ -61,6 +61,66 @@ def test_crop(self):
         assert sum2 > sum1, "height: " + str(height) + " width: " \
             + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
 
+    def test_five_crop(self):
+        to_pil_image = transforms.ToPILImage()
+        h = random.randint(5, 25)
+        w = random.randint(5, 25)
+        for single_dim in [True, False]:
+            crop_h = random.randint(1, h)
+            crop_w = random.randint(1, w)
+            if single_dim:
+                crop_h = min(crop_h, crop_w)
+                crop_w = crop_h
+                transform = transforms.FiveCrop(crop_h)
+            else:
+                transform = transforms.FiveCrop((crop_h, crop_w))
+
+            img = torch.FloatTensor(3, h, w).uniform_()
+            results = transform(to_pil_image(img))
+
+            assert len(results) == 5
+            for crop in results:
+                assert crop.size == (crop_w, crop_h)
+
+            to_pil_image = transforms.ToPILImage()
+            tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
+            tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
+            bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
+            br = to_pil_image(img[:, h - crop_h:, w - crop_w:])
+            center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
+            expected_output = (tl, tr, bl, br, center)
+            assert results == expected_output
+
+    def test_ten_crop(self):
+        to_pil_image = transforms.ToPILImage()
+        h = random.randint(5, 25)
+        w = random.randint(5, 25)
+        for should_vflip in [True, False]:
+            for single_dim in [True, False]:
+                crop_h = random.randint(1, h)
+                crop_w = random.randint(1, w)
+                if single_dim:
+                    crop_h = min(crop_h, crop_w)
+                    crop_w = crop_h
+                    transform = transforms.TenCrop(crop_h, vflip=should_vflip)
+                    five_crop = transforms.FiveCrop(crop_h)
+                else:
+                    transform = transforms.TenCrop((crop_h, crop_w), vflip=should_vflip)
+                    five_crop = transforms.FiveCrop((crop_h, crop_w))
+
+                img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
+                results = transform(img)
+                expected_output = five_crop(img)
+                if should_vflip:
+                    vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
+                    expected_output += five_crop(vflipped_img)
+                else:
+                    hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
+                    expected_output += five_crop(hflipped_img)
+
+                assert len(results) == 10
+                assert expected_output == results
+
     def test_scale(self):
         height = random.randint(24, 32) * 2
         width = random.randint(24, 32) * 2

torchvision/transforms.py

Lines changed: 69 additions & 0 deletions
@@ -638,3 +638,72 @@ def __call__(self, img):
         """
         i, j, h, w = self.get_params(img)
         return scaled_crop(img, i, j, h, w, self.size, self.interpolation)
+
+
+class FiveCrop(object):
+    """Crop the given PIL.Image into four corners and the central crop.
+
+    Note: this transform returns a tuple of images and there may be a mismatch in the number of
+    inputs and targets your `Dataset` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+
+    def __call__(self, img):
+        w, h = img.size
+        crop_h, crop_w = self.size
+        if crop_w > w or crop_h > h:
+            raise ValueError("Requested crop size {} is bigger than input size {}".format(self.size,
+                                                                                          (h, w)))
+        tl = img.crop((0, 0, crop_w, crop_h))
+        tr = img.crop((w - crop_w, 0, w, crop_h))
+        bl = img.crop((0, h - crop_h, crop_w, h))
+        br = img.crop((w - crop_w, h - crop_h, w, h))
+        center = CenterCrop((crop_h, crop_w))(img)
+        return (tl, tr, bl, br, center)
+
+
+class TenCrop(object):
+    """Crop the given PIL.Image into four corners and the central crop plus the
+    flipped version of these (horizontal flipping is used by default).
+
+    Note: this transform returns a tuple of images and there may be a mismatch in the number of
+    inputs and targets your `Dataset` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vflip (bool): Use vertical flipping instead of horizontal.
+    """
+
+    def __init__(self, size, vflip=False):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+        self.vflip = vflip
+
+    def __call__(self, img):
+        five_crop = FiveCrop(self.size)
+        first_five = five_crop(img)
+        if self.vflip:
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+        else:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+
+        second_five = five_crop(img)
+        return first_five + second_five
