Skip to content

Remove p-value checks in test_transforms.py #4756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 44 additions & 229 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import os
import random
from functools import partial

import numpy as np
import pytest
Expand Down Expand Up @@ -541,38 +542,35 @@ def test_pad_with_mode_F_images(self):
assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size])


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@pytest.mark.parametrize(
"fn, trans, config",
"fn, trans, kwargs",
[
(F.invert, transforms.RandomInvert, {}),
(F.posterize, transforms.RandomPosterize, {"bits": 4}),
(F.solarize, transforms.RandomSolarize, {"threshold": 192}),
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
(F.autocontrast, transforms.RandomAutocontrast, {}),
(F.equalize, transforms.RandomEqualize, {}),
(F.vflip, transforms.RandomVerticalFlip, {}),
(F.hflip, transforms.RandomHorizontalFlip, {}),
(partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}),
],
)
@pytest.mark.parametrize("p", (0.5, 0.7))
def test_randomness(fn, trans, config, p):
random_state = random.getstate()
random.seed(42)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_randomness(fn, trans, kwargs, seed, p):
torch.manual_seed(seed)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))

inv_img = fn(img, **config)
expected_transformed_img = fn(img, **kwargs)
randomly_transformed_img = trans(p=p, **kwargs)(img)

num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1
if p == 0:
assert randomly_transformed_img == img
elif p == 1:
assert randomly_transformed_img == expected_transformed_img

p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
assert p_value > 0.0001
trans(**kwargs).__repr__()


class TestToPil:
Expand Down Expand Up @@ -1362,160 +1360,42 @@ def test_to_grayscale():
trans4.__repr__()


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_grayscale():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this one into test_randomness above. It will cover all the randomness-related checks in this test. For the rest (expected values), everything is already covered in test_to_grayscale() just above.

"""Unit tests for random grayscale transform"""
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_random_apply(p, seed):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could move this one up with test_randomness but this would require taking care of the case where fn is None -- this transform doesn't have a functional equivalent. Not worth it IMHO considering how simple the code is, but I don't mind

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including this above will also have the negative side-effect of requiring you to modify your check condition for p=1 to out != img which IMO is less strong to what you have above. I agree keeping this as you have it.

torch.manual_seed(seed)
random_apply_transform = transforms.RandomApply([transforms.RandomRotation((1, 45))], p=p)
img = transforms.ToPILImage()(torch.rand(3, 30, 40))
out = random_apply_transform(img)
if p == 0:
assert out == img
elif p == 1:
assert out != img

# Test Set 1: RGB -> 3 channel grayscale
np_rng = np.random.RandomState(0)
random_state = random.getstate()
random.seed(42)
x_shape = [2, 2, 3]
x_np = np_rng.randint(0, 256, x_shape, np.uint8)
x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)

num_samples = 250
num_gray = 0
for _ in range(num_samples):
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
gray_np_2 = np.array(gray_pil_2)
if (
np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
and np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
and np.array_equal(gray_np, gray_np_2[:, :, 0])
):
num_gray = num_gray + 1

p_value = stats.binom_test(num_gray, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001

# Test Set 2: grayscale -> 1 channel grayscale
random_state = random.getstate()
random.seed(42)
x_shape = [2, 2, 3]
x_np = np_rng.randint(0, 256, x_shape, np.uint8)
x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)

num_samples = 250
num_gray = 0
for _ in range(num_samples):
gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
gray_np_3 = np.array(gray_pil_3)
if np.array_equal(gray_np, gray_np_3):
num_gray = num_gray + 1

p_value = stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged
random.setstate(random_state)
assert p_value > 0.0001

# Test set 3: Explicit tests
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)

# Case 3a: RGB -> 3 channel grayscale (grayscaled)
trans2 = transforms.RandomGrayscale(p=1.0)
gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
assert_equal(gray_np, gray_np_2[:, :, 0])

# Case 3b: RGB -> 3 channel grayscale (unchanged)
trans2 = transforms.RandomGrayscale(p=0.0)
gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(x_np, gray_np_2)

# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
trans3 = transforms.RandomGrayscale(p=1.0)
gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3)

# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
trans3 = transforms.RandomGrayscale(p=0.0)
gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3)
# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()

# Checking if RandomGrayscale can be printed as string
trans3.__repr__()

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("proba_passthrough", (0, 1))
def test_random_choice(proba_passthrough, seed):
random.seed(seed) # RandomChoice relies on python builtin random.choice, not pytorch

@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_apply():
random_state = random.getstate()
random.seed(42)
random_apply_transform = transforms.RandomApply(
random_choice_transform = transforms.RandomChoice(
[
transforms.RandomRotation((-45, 45)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
lambda x: x, # passthrough
transforms.RandomRotation((1, 45)),
],
p=0.75,
p=[proba_passthrough, 1 - proba_passthrough],
)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
num_samples = 250
num_applies = 0
for _ in range(num_samples):
out = random_apply_transform(img)
if out != img:
num_applies += 1

p_value = stats.binom_test(num_applies, num_samples, p=0.75)
random.setstate(random_state)
assert p_value > 0.0001

# Checking if RandomApply can be printed as string
random_apply_transform.__repr__()


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_choice():
random_state = random.getstate()
random.seed(42)
random_choice_transform = transforms.RandomChoice(
[transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_resize_15 = 0
num_resize_20 = 0
num_crop_10 = 0
for _ in range(num_samples):
out = random_choice_transform(img)
if out.size == (15, 15):
num_resize_15 += 1
elif out.size == (20, 20):
num_resize_20 += 1
elif out.size == (10, 10):
num_crop_10 += 1

p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
assert p_value > 0.0001
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
assert p_value > 0.0001
img = transforms.ToPILImage()(torch.rand(3, 30, 40))
out = random_choice_transform(img)
if proba_passthrough == 1:
assert out == img
elif proba_passthrough == 0:
assert out != img

random.setstate(random_state)
# Checking if RandomChoice can be printed as string
random_choice_transform.__repr__()

Expand Down Expand Up @@ -1888,6 +1768,7 @@ def test_random_erasing():
tol = 0.05
assert 1 / 3 - tol <= aspect_ratio <= 3 + tol

# Make sure that h > w and h < w are equaly likely (log-scale sampling)
aspect_ratios = []
random.seed(42)
trial = 1000
Expand Down Expand Up @@ -2011,72 +1892,6 @@ def test_randomperspective_fill(mode):
F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_vertical_flip():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved this one in test_randomness above, same for horizontal split.

random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)

num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip()(img)
if out == vimg:
num_vertical += 1

p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001

num_samples = 250
num_vertical = 0
for _ in range(num_samples):
out = transforms.RandomVerticalFlip(p=0.7)(img)
if out == vimg:
num_vertical += 1

p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001

# Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__()


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_horizontal_flip():
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
himg = img.transpose(Image.FLIP_LEFT_RIGHT)

num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip()(img)
if out == himg:
num_horizontal += 1

p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
assert p_value > 0.0001

num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlip(p=0.7)(img)
if out == himg:
num_horizontal += 1

p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
assert p_value > 0.0001

# Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__()


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_normalize():
def samples_from_standard_normal(tensor):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ class RandomChoice(RandomTransforms):
def __init__(self, transforms, p=None):
super().__init__(transforms)
if p is not None and not isinstance(p, Sequence):
raise TypeError("Argument transforms should be a sequence")
raise TypeError("Argument p should be a sequence")
self.p = p

def __call__(self, *args):
Expand Down