diff --git a/test/test_transforms.py b/test/test_transforms.py index 3712e592cc4..e9cbbad3e68 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,6 +1,7 @@ import math import os import random +from functools import partial import numpy as np import pytest @@ -541,9 +542,8 @@ def test_pad_with_mode_F_images(self): assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size]) -@pytest.mark.skipif(stats is None, reason="scipy.stats not available") @pytest.mark.parametrize( - "fn, trans, config", + "fn, trans, kwargs", [ (F.invert, transforms.RandomInvert, {}), (F.posterize, transforms.RandomPosterize, {"bits": 4}), @@ -551,28 +551,26 @@ def test_pad_with_mode_F_images(self): (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), (F.autocontrast, transforms.RandomAutocontrast, {}), (F.equalize, transforms.RandomEqualize, {}), + (F.vflip, transforms.RandomVerticalFlip, {}), + (F.hflip, transforms.RandomHorizontalFlip, {}), + (partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}), ], ) -@pytest.mark.parametrize("p", (0.5, 0.7)) -def test_randomness(fn, trans, config, p): - random_state = random.getstate() - random.seed(42) +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("p", (0, 1)) +def test_randomness(fn, trans, kwargs, seed, p): + torch.manual_seed(seed) img = transforms.ToPILImage()(torch.rand(3, 16, 18)) - inv_img = fn(img, **config) + expected_transformed_img = fn(img, **kwargs) + randomly_transformed_img = trans(p=p, **kwargs)(img) - num_samples = 250 - counts = 0 - for _ in range(num_samples): - tranformation = trans(p=p, **config) - tranformation.__repr__() - out = tranformation(img) - if out == inv_img: - counts += 1 + if p == 0: + assert randomly_transformed_img == img + elif p == 1: + assert randomly_transformed_img == expected_transformed_img - p_value = stats.binom_test(counts, num_samples, p=p) - random.setstate(random_state) - assert p_value 
> 0.0001 + trans(**kwargs).__repr__() class TestToPil: @@ -1362,160 +1360,42 @@ def test_to_grayscale(): trans4.__repr__() -@pytest.mark.skipif(stats is None, reason="scipy.stats not available") -def test_random_grayscale(): - """Unit tests for random grayscale transform""" +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("p", (0, 1)) +def test_random_apply(p, seed): + torch.manual_seed(seed) + random_apply_transform = transforms.RandomApply([transforms.RandomRotation((1, 45))], p=p) + img = transforms.ToPILImage()(torch.rand(3, 30, 40)) + out = random_apply_transform(img) + if p == 0: + assert out == img + elif p == 1: + assert out != img - # Test Set 1: RGB -> 3 channel grayscale - np_rng = np.random.RandomState(0) - random_state = random.getstate() - random.seed(42) - x_shape = [2, 2, 3] - x_np = np_rng.randint(0, 256, x_shape, np.uint8) - x_pil = Image.fromarray(x_np, mode="RGB") - x_pil_2 = x_pil.convert("L") - gray_np = np.array(x_pil_2) - - num_samples = 250 - num_gray = 0 - for _ in range(num_samples): - gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil) - gray_np_2 = np.array(gray_pil_2) - if ( - np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) - and np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) - and np.array_equal(gray_np, gray_np_2[:, :, 0]) - ): - num_gray = num_gray + 1 - - p_value = stats.binom_test(num_gray, num_samples, p=0.5) - random.setstate(random_state) - assert p_value > 0.0001 - - # Test Set 2: grayscale -> 1 channel grayscale - random_state = random.getstate() - random.seed(42) - x_shape = [2, 2, 3] - x_np = np_rng.randint(0, 256, x_shape, np.uint8) - x_pil = Image.fromarray(x_np, mode="RGB") - x_pil_2 = x_pil.convert("L") - gray_np = np.array(x_pil_2) - - num_samples = 250 - num_gray = 0 - for _ in range(num_samples): - gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2) - gray_np_3 = np.array(gray_pil_3) - if np.array_equal(gray_np, gray_np_3): - num_gray = num_gray + 1 - - p_value = 
stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged - random.setstate(random_state) - assert p_value > 0.0001 - - # Test set 3: Explicit tests - x_shape = [2, 2, 3] - x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] - x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode="RGB") - x_pil_2 = x_pil.convert("L") - gray_np = np.array(x_pil_2) - - # Case 3a: RGB -> 3 channel grayscale (grayscaled) - trans2 = transforms.RandomGrayscale(p=1.0) - gray_pil_2 = trans2(x_pil) - gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == "RGB", "mode should be RGB" - assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" - assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) - assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) - assert_equal(gray_np, gray_np_2[:, :, 0]) - - # Case 3b: RGB -> 3 channel grayscale (unchanged) - trans2 = transforms.RandomGrayscale(p=0.0) - gray_pil_2 = trans2(x_pil) - gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == "RGB", "mode should be RGB" - assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" - assert_equal(x_np, gray_np_2) - - # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) - trans3 = transforms.RandomGrayscale(p=1.0) - gray_pil_3 = trans3(x_pil_2) - gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == "L", "mode should be L" - assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" - assert_equal(gray_np, gray_np_3) - - # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) - trans3 = transforms.RandomGrayscale(p=0.0) - gray_pil_3 = trans3(x_pil_2) - gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == "L", "mode should be L" - assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" - assert_equal(gray_np, gray_np_3) + # Checking if RandomApply can be printed as string + random_apply_transform.__repr__() - # Checking if RandomGrayscale can be printed as 
string - trans3.__repr__() +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("proba_passthrough", (0, 1)) +def test_random_choice(proba_passthrough, seed): + random.seed(seed) # RandomChoice relies on python builtin random.choice, not pytorch -@pytest.mark.skipif(stats is None, reason="scipy.stats not available") -def test_random_apply(): - random_state = random.getstate() - random.seed(42) - random_apply_transform = transforms.RandomApply( + random_choice_transform = transforms.RandomChoice( [ - transforms.RandomRotation((-45, 45)), - transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), + lambda x: x, # passthrough + transforms.RandomRotation((1, 45)), ], - p=0.75, + p=[proba_passthrough, 1 - proba_passthrough], ) - img = transforms.ToPILImage()(torch.rand(3, 10, 10)) - num_samples = 250 - num_applies = 0 - for _ in range(num_samples): - out = random_apply_transform(img) - if out != img: - num_applies += 1 - - p_value = stats.binom_test(num_applies, num_samples, p=0.75) - random.setstate(random_state) - assert p_value > 0.0001 - - # Checking if RandomApply can be printed as string - random_apply_transform.__repr__() - -@pytest.mark.skipif(stats is None, reason="scipy.stats not available") -def test_random_choice(): - random_state = random.getstate() - random.seed(42) - random_choice_transform = transforms.RandomChoice( - [transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3] - ) - img = transforms.ToPILImage()(torch.rand(3, 25, 25)) - num_samples = 250 - num_resize_15 = 0 - num_resize_20 = 0 - num_crop_10 = 0 - for _ in range(num_samples): - out = random_choice_transform(img) - if out.size == (15, 15): - num_resize_15 += 1 - elif out.size == (20, 20): - num_resize_20 += 1 - elif out.size == (10, 10): - num_crop_10 += 1 - - p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333) - assert p_value > 0.0001 - p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333) - 
assert p_value > 0.0001 - p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333) - assert p_value > 0.0001 + img = transforms.ToPILImage()(torch.rand(3, 30, 40)) + out = random_choice_transform(img) + if proba_passthrough == 1: + assert out == img + elif proba_passthrough == 0: + assert out != img - random.setstate(random_state) # Checking if RandomChoice can be printed as string random_choice_transform.__repr__() @@ -1888,6 +1768,7 @@ def test_random_erasing(): tol = 0.05 assert 1 / 3 - tol <= aspect_ratio <= 3 + tol + # Make sure that h > w and h < w are equally likely (log-scale sampling) aspect_ratios = [] random.seed(42) trial = 1000 @@ -2011,72 +1892,6 @@ def test_randomperspective_fill(mode): F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands)) -@pytest.mark.skipif(stats is None, reason="scipy.stats not available") -def test_random_vertical_flip(): - random_state = random.getstate() - random.seed(42) - img = transforms.ToPILImage()(torch.rand(3, 10, 10)) - vimg = img.transpose(Image.FLIP_TOP_BOTTOM) - - num_samples = 250 - num_vertical = 0 - for _ in range(num_samples): - out = transforms.RandomVerticalFlip()(img) - if out == vimg: - num_vertical += 1 - - p_value = stats.binom_test(num_vertical, num_samples, p=0.5) - random.setstate(random_state) - assert p_value > 0.0001 - - num_samples = 250 - num_vertical = 0 - for _ in range(num_samples): - out = transforms.RandomVerticalFlip(p=0.7)(img) - if out == vimg: - num_vertical += 1 - - p_value = stats.binom_test(num_vertical, num_samples, p=0.7) - random.setstate(random_state) - assert p_value > 0.0001 - - # Checking if RandomVerticalFlip can be printed as string - transforms.RandomVerticalFlip().__repr__() - - -@pytest.mark.skipif(stats is None, reason="scipy.stats not available") -def test_random_horizontal_flip(): - random_state = random.getstate() - random.seed(42) - img = transforms.ToPILImage()(torch.rand(3, 10, 10)) - himg = img.transpose(Image.FLIP_LEFT_RIGHT) - 
- num_samples = 250 - num_horizontal = 0 - for _ in range(num_samples): - out = transforms.RandomHorizontalFlip()(img) - if out == himg: - num_horizontal += 1 - - p_value = stats.binom_test(num_horizontal, num_samples, p=0.5) - random.setstate(random_state) - assert p_value > 0.0001 - - num_samples = 250 - num_horizontal = 0 - for _ in range(num_samples): - out = transforms.RandomHorizontalFlip(p=0.7)(img) - if out == himg: - num_horizontal += 1 - - p_value = stats.binom_test(num_horizontal, num_samples, p=0.7) - random.setstate(random_state) - assert p_value > 0.0001 - - # Checking if RandomHorizontalFlip can be printed as string - transforms.RandomHorizontalFlip().__repr__() - - @pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_normalize(): def samples_from_standard_normal(tensor): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index e402aa7f9a7..666e4e3a16d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -562,7 +562,7 @@ class RandomChoice(RandomTransforms): def __init__(self, transforms, p=None): super().__init__(transforms) if p is not None and not isinstance(p, Sequence): - raise TypeError("Argument transforms should be a sequence") + raise TypeError("Argument p should be a sequence") self.p = p def __call__(self, *args):