diff --git a/test/test_transforms.py b/test/test_transforms.py index 5d0275f946f..9319fb0664a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -9,6 +9,11 @@ except ImportError: accimage = None +try: + from scipy import stats +except ImportError: + stats = None + GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg' @@ -327,6 +332,34 @@ def test_ndarray_gray_int32_to_pil_image(self): assert img.mode == 'I' assert np.allclose(img, img_data[:, :, 0]) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_vertical_flip(self): + img = transforms.ToPILImage()(torch.rand(3, 10, 10)) + vimg = img.transpose(Image.FLIP_TOP_BOTTOM) + + num_vertical = 0 + for _ in range(100): + out = transforms.RandomVerticalFlip()(img) + if out == vimg: + num_vertical += 1 + + p_value = stats.binom_test(num_vertical, 100, p=0.5) + assert p_value > 0.05 + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_horizontal_flip(self): + img = transforms.ToPILImage()(torch.rand(3, 10, 10)) + himg = img.transpose(Image.FLIP_LEFT_RIGHT) + + num_horizontal = 0 + for _ in range(100): + out = transforms.RandomHorizontalFlip()(img) + if out == himg: + num_horizontal += 1 + + p_value = stats.binom_test(num_horizontal, 100, p=0.5) + assert p_value > 0.05 + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms.py b/torchvision/transforms.py index da58aa12b9a..6202ec9c147 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -338,6 +338,22 @@ def __call__(self, img): return img +class RandomVerticalFlip(object): + """Vertically flip the given PIL.Image randomly with a probability of 0.5""" + + def __call__(self, img): + """ + Args: + img (PIL.Image): Image to be flipped. + + Returns: + PIL.Image: Randomly flipped image. + """ + if random.random() < 0.5: + return img.transpose(Image.FLIP_TOP_BOTTOM) + return img + + class RandomSizedCrop(object): """Crop the given PIL.Image to random size and aspect ratio.