@@ -61,6 +61,66 @@ def test_crop(self):
61
61
assert sum2 > sum1 , "height: " + str (height ) + " width: " \
62
62
+ str (width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
63
63
64
+ def test_five_crop (self ):
65
+ to_pil_image = transforms .ToPILImage ()
66
+ h = random .randint (5 , 25 )
67
+ w = random .randint (5 , 25 )
68
+ for single_dim in [True , False ]:
69
+ crop_h = random .randint (1 , h )
70
+ crop_w = random .randint (1 , w )
71
+ if single_dim :
72
+ crop_h = min (crop_h , crop_w )
73
+ crop_w = crop_h
74
+ transform = transforms .FiveCrop (crop_h )
75
+ else :
76
+ transform = transforms .FiveCrop ((crop_h , crop_w ))
77
+
78
+ img = torch .FloatTensor (3 , h , w ).uniform_ ()
79
+ results = transform (to_pil_image (img ))
80
+
81
+ assert len (results ) == 5
82
+ for crop in results :
83
+ assert crop .size == (crop_w , crop_h )
84
+
85
+ to_pil_image = transforms .ToPILImage ()
86
+ tl = to_pil_image (img [:, 0 :crop_h , 0 :crop_w ])
87
+ tr = to_pil_image (img [:, 0 :crop_h , w - crop_w :])
88
+ bl = to_pil_image (img [:, h - crop_h :, 0 :crop_w ])
89
+ br = to_pil_image (img [:, h - crop_h :, w - crop_w :])
90
+ center = transforms .CenterCrop ((crop_h , crop_w ))(to_pil_image (img ))
91
+ expected_output = (tl , tr , bl , br , center )
92
+ assert results == expected_output
93
+
94
+ def test_ten_crop (self ):
95
+ to_pil_image = transforms .ToPILImage ()
96
+ h = random .randint (5 , 25 )
97
+ w = random .randint (5 , 25 )
98
+ for should_vflip in [True , False ]:
99
+ for single_dim in [True , False ]:
100
+ crop_h = random .randint (1 , h )
101
+ crop_w = random .randint (1 , w )
102
+ if single_dim :
103
+ crop_h = min (crop_h , crop_w )
104
+ crop_w = crop_h
105
+ transform = transforms .TenCrop (crop_h , vflip = should_vflip )
106
+ five_crop = transforms .FiveCrop (crop_h )
107
+ else :
108
+ transform = transforms .TenCrop ((crop_h , crop_w ), vflip = should_vflip )
109
+ five_crop = transforms .FiveCrop ((crop_h , crop_w ))
110
+
111
+ img = to_pil_image (torch .FloatTensor (3 , h , w ).uniform_ ())
112
+ results = transform (img )
113
+ expected_output = five_crop (img )
114
+ if should_vflip :
115
+ vflipped_img = img .transpose (Image .FLIP_TOP_BOTTOM )
116
+ expected_output += five_crop (vflipped_img )
117
+ else :
118
+ hflipped_img = img .transpose (Image .FLIP_LEFT_RIGHT )
119
+ expected_output += five_crop (hflipped_img )
120
+
121
+ assert len (results ) == 10
122
+ assert expected_output == results
123
+
64
124
def test_scale (self ):
65
125
height = random .randint (24 , 32 ) * 2
66
126
width = random .randint (24 , 32 ) * 2
0 commit comments