-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Remove p-value checks in test_transforms.py #4756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
806aea4
a9c428d
b1b4bd2
f3d3be7
5c254d7
1b76146
accbc74
84ed356
efb94f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import math | ||
import os | ||
import random | ||
from functools import partial | ||
|
||
import numpy as np | ||
import pytest | ||
|
@@ -541,38 +542,35 @@ def test_pad_with_mode_F_images(self): | |
assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size]) | ||
|
||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
@pytest.mark.parametrize( | ||
"fn, trans, config", | ||
"fn, trans, kwargs", | ||
[ | ||
(F.invert, transforms.RandomInvert, {}), | ||
(F.posterize, transforms.RandomPosterize, {"bits": 4}), | ||
(F.solarize, transforms.RandomSolarize, {"threshold": 192}), | ||
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), | ||
(F.autocontrast, transforms.RandomAutocontrast, {}), | ||
(F.equalize, transforms.RandomEqualize, {}), | ||
(F.vflip, transforms.RandomVerticalFlip, {}), | ||
(F.hflip, transforms.RandomHorizontalFlip, {}), | ||
(partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}), | ||
], | ||
) | ||
@pytest.mark.parametrize("p", (0.5, 0.7)) | ||
def test_randomness(fn, trans, config, p): | ||
random_state = random.getstate() | ||
random.seed(42) | ||
@pytest.mark.parametrize("seed", range(10)) | ||
@pytest.mark.parametrize("p", (0, 1)) | ||
def test_randomness(fn, trans, kwargs, seed, p): | ||
torch.manual_seed(seed) | ||
img = transforms.ToPILImage()(torch.rand(3, 16, 18)) | ||
|
||
inv_img = fn(img, **config) | ||
expected_transformed_img = fn(img, **kwargs) | ||
randomly_transformed_img = trans(p=p, **kwargs)(img) | ||
|
||
num_samples = 250 | ||
counts = 0 | ||
for _ in range(num_samples): | ||
tranformation = trans(p=p, **config) | ||
tranformation.__repr__() | ||
out = tranformation(img) | ||
if out == inv_img: | ||
counts += 1 | ||
if p == 0: | ||
assert randomly_transformed_img == img | ||
elif p == 1: | ||
assert randomly_transformed_img == expected_transformed_img | ||
|
||
p_value = stats.binom_test(counts, num_samples, p=p) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
trans(**kwargs).__repr__() | ||
|
||
|
||
class TestToPil: | ||
|
@@ -1362,160 +1360,42 @@ def test_to_grayscale(): | |
trans4.__repr__() | ||
|
||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
def test_random_grayscale(): | ||
"""Unit tests for random grayscale transform""" | ||
@pytest.mark.parametrize("seed", range(10)) | ||
@pytest.mark.parametrize("p", (0, 1)) | ||
def test_random_apply(p, seed): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could move this one up with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Including this above will also have the negative side-effect of requiring you to modify your check condition for p=1 to |
||
torch.manual_seed(seed) | ||
random_apply_transform = transforms.RandomApply([transforms.RandomRotation((1, 45))], p=p) | ||
img = transforms.ToPILImage()(torch.rand(3, 30, 40)) | ||
out = random_apply_transform(img) | ||
if p == 0: | ||
assert out == img | ||
elif p == 1: | ||
assert out != img | ||
|
||
# Test Set 1: RGB -> 3 channel grayscale | ||
np_rng = np.random.RandomState(0) | ||
random_state = random.getstate() | ||
random.seed(42) | ||
x_shape = [2, 2, 3] | ||
x_np = np_rng.randint(0, 256, x_shape, np.uint8) | ||
x_pil = Image.fromarray(x_np, mode="RGB") | ||
x_pil_2 = x_pil.convert("L") | ||
gray_np = np.array(x_pil_2) | ||
|
||
num_samples = 250 | ||
num_gray = 0 | ||
for _ in range(num_samples): | ||
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil) | ||
gray_np_2 = np.array(gray_pil_2) | ||
if ( | ||
np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) | ||
and np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) | ||
and np.array_equal(gray_np, gray_np_2[:, :, 0]) | ||
): | ||
num_gray = num_gray + 1 | ||
|
||
p_value = stats.binom_test(num_gray, num_samples, p=0.5) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
# Test Set 2: grayscale -> 1 channel grayscale | ||
random_state = random.getstate() | ||
random.seed(42) | ||
x_shape = [2, 2, 3] | ||
x_np = np_rng.randint(0, 256, x_shape, np.uint8) | ||
x_pil = Image.fromarray(x_np, mode="RGB") | ||
x_pil_2 = x_pil.convert("L") | ||
gray_np = np.array(x_pil_2) | ||
|
||
num_samples = 250 | ||
num_gray = 0 | ||
for _ in range(num_samples): | ||
gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2) | ||
gray_np_3 = np.array(gray_pil_3) | ||
if np.array_equal(gray_np, gray_np_3): | ||
num_gray = num_gray + 1 | ||
|
||
p_value = stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
# Test set 3: Explicit tests | ||
x_shape = [2, 2, 3] | ||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] | ||
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) | ||
x_pil = Image.fromarray(x_np, mode="RGB") | ||
x_pil_2 = x_pil.convert("L") | ||
gray_np = np.array(x_pil_2) | ||
|
||
# Case 3a: RGB -> 3 channel grayscale (grayscaled) | ||
trans2 = transforms.RandomGrayscale(p=1.0) | ||
gray_pil_2 = trans2(x_pil) | ||
gray_np_2 = np.array(gray_pil_2) | ||
assert gray_pil_2.mode == "RGB", "mode should be RGB" | ||
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" | ||
assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) | ||
assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) | ||
assert_equal(gray_np, gray_np_2[:, :, 0]) | ||
|
||
# Case 3b: RGB -> 3 channel grayscale (unchanged) | ||
trans2 = transforms.RandomGrayscale(p=0.0) | ||
gray_pil_2 = trans2(x_pil) | ||
gray_np_2 = np.array(gray_pil_2) | ||
assert gray_pil_2.mode == "RGB", "mode should be RGB" | ||
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" | ||
assert_equal(x_np, gray_np_2) | ||
|
||
# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) | ||
trans3 = transforms.RandomGrayscale(p=1.0) | ||
gray_pil_3 = trans3(x_pil_2) | ||
gray_np_3 = np.array(gray_pil_3) | ||
assert gray_pil_3.mode == "L", "mode should be L" | ||
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" | ||
assert_equal(gray_np, gray_np_3) | ||
|
||
# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) | ||
trans3 = transforms.RandomGrayscale(p=0.0) | ||
gray_pil_3 = trans3(x_pil_2) | ||
gray_np_3 = np.array(gray_pil_3) | ||
assert gray_pil_3.mode == "L", "mode should be L" | ||
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" | ||
assert_equal(gray_np, gray_np_3) | ||
# Checking if RandomApply can be printed as string | ||
random_apply_transform.__repr__() | ||
|
||
# Checking if RandomGrayscale can be printed as string | ||
trans3.__repr__() | ||
|
||
@pytest.mark.parametrize("seed", range(10)) | ||
@pytest.mark.parametrize("proba_passthrough", (0, 1)) | ||
def test_random_choice(proba_passthrough, seed): | ||
random.seed(seed) # RandomChoice relies on python builtin random.choice, not pytorch | ||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
def test_random_apply(): | ||
random_state = random.getstate() | ||
random.seed(42) | ||
random_apply_transform = transforms.RandomApply( | ||
random_choice_transform = transforms.RandomChoice( | ||
[ | ||
transforms.RandomRotation((-45, 45)), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.RandomVerticalFlip(), | ||
lambda x: x, # passthrough | ||
transforms.RandomRotation((1, 45)), | ||
], | ||
p=0.75, | ||
p=[proba_passthrough, 1 - proba_passthrough], | ||
) | ||
img = transforms.ToPILImage()(torch.rand(3, 10, 10)) | ||
num_samples = 250 | ||
num_applies = 0 | ||
for _ in range(num_samples): | ||
out = random_apply_transform(img) | ||
if out != img: | ||
num_applies += 1 | ||
|
||
p_value = stats.binom_test(num_applies, num_samples, p=0.75) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
# Checking if RandomApply can be printed as string | ||
random_apply_transform.__repr__() | ||
|
||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
def test_random_choice(): | ||
random_state = random.getstate() | ||
random.seed(42) | ||
random_choice_transform = transforms.RandomChoice( | ||
[transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3] | ||
) | ||
img = transforms.ToPILImage()(torch.rand(3, 25, 25)) | ||
num_samples = 250 | ||
num_resize_15 = 0 | ||
num_resize_20 = 0 | ||
num_crop_10 = 0 | ||
for _ in range(num_samples): | ||
out = random_choice_transform(img) | ||
if out.size == (15, 15): | ||
num_resize_15 += 1 | ||
elif out.size == (20, 20): | ||
num_resize_20 += 1 | ||
elif out.size == (10, 10): | ||
num_crop_10 += 1 | ||
|
||
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333) | ||
assert p_value > 0.0001 | ||
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333) | ||
assert p_value > 0.0001 | ||
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333) | ||
assert p_value > 0.0001 | ||
img = transforms.ToPILImage()(torch.rand(3, 30, 40)) | ||
out = random_choice_transform(img) | ||
if proba_passthrough == 1: | ||
assert out == img | ||
elif proba_passthrough == 0: | ||
assert out != img | ||
|
||
random.setstate(random_state) | ||
# Checking if RandomChoice can be printed as string | ||
random_choice_transform.__repr__() | ||
|
||
|
@@ -1888,6 +1768,7 @@ def test_random_erasing(): | |
tol = 0.05 | ||
assert 1 / 3 - tol <= aspect_ratio <= 3 + tol | ||
|
||
# Make sure that h > w and h < w are equaly likely (log-scale sampling) | ||
aspect_ratios = [] | ||
random.seed(42) | ||
trial = 1000 | ||
|
@@ -2011,72 +1892,6 @@ def test_randomperspective_fill(mode): | |
F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands)) | ||
|
||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
def test_random_vertical_flip(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moved this one in |
||
random_state = random.getstate() | ||
random.seed(42) | ||
img = transforms.ToPILImage()(torch.rand(3, 10, 10)) | ||
vimg = img.transpose(Image.FLIP_TOP_BOTTOM) | ||
|
||
num_samples = 250 | ||
num_vertical = 0 | ||
for _ in range(num_samples): | ||
out = transforms.RandomVerticalFlip()(img) | ||
if out == vimg: | ||
num_vertical += 1 | ||
|
||
p_value = stats.binom_test(num_vertical, num_samples, p=0.5) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
num_samples = 250 | ||
num_vertical = 0 | ||
for _ in range(num_samples): | ||
out = transforms.RandomVerticalFlip(p=0.7)(img) | ||
if out == vimg: | ||
num_vertical += 1 | ||
|
||
p_value = stats.binom_test(num_vertical, num_samples, p=0.7) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
# Checking if RandomVerticalFlip can be printed as string | ||
transforms.RandomVerticalFlip().__repr__() | ||
|
||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
def test_random_horizontal_flip(): | ||
random_state = random.getstate() | ||
random.seed(42) | ||
img = transforms.ToPILImage()(torch.rand(3, 10, 10)) | ||
himg = img.transpose(Image.FLIP_LEFT_RIGHT) | ||
|
||
num_samples = 250 | ||
num_horizontal = 0 | ||
for _ in range(num_samples): | ||
out = transforms.RandomHorizontalFlip()(img) | ||
if out == himg: | ||
num_horizontal += 1 | ||
|
||
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
num_samples = 250 | ||
num_horizontal = 0 | ||
for _ in range(num_samples): | ||
out = transforms.RandomHorizontalFlip(p=0.7)(img) | ||
if out == himg: | ||
num_horizontal += 1 | ||
|
||
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7) | ||
random.setstate(random_state) | ||
assert p_value > 0.0001 | ||
|
||
# Checking if RandomHorizontalFlip can be printed as string | ||
transforms.RandomHorizontalFlip().__repr__() | ||
|
||
|
||
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") | ||
def test_normalize(): | ||
def samples_from_standard_normal(tensor): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved this one into
test_randomness
above. It will cover all the randomness-related checks in this test. For the rest (expected values), everything is already covered intest_to_grayscale()
just above.