Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,53 +602,6 @@ def test__get_params(self):
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()


class TestRandomErasing:
def test_assertions(self):
with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
transforms.RandomErasing(value={})

with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
transforms.RandomErasing(value="abc")

with pytest.raises(TypeError, match="Scale should be a sequence"):
transforms.RandomErasing(scale=123)

with pytest.raises(TypeError, match="Ratio should be a sequence"):
transforms.RandomErasing(ratio=123)

with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
transforms.RandomErasing(scale=[-1, 2])

image = make_image((24, 32))

transform = transforms.RandomErasing(value=[1, 2, 3, 4])

with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
transform._get_params([image])

@pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
def test__get_params(self, value):
image = make_image((24, 32))
num_channels, height, width = F.get_dimensions(image)

transform = transforms.RandomErasing(value=value)
params = transform._get_params([image])

v = params["v"]
h, w = params["h"], params["w"]
i, j = params["i"], params["j"]
assert isinstance(v, torch.Tensor)
if value == "random":
assert v.shape == (num_channels, h, w)
elif isinstance(value, (int, float)):
assert v.shape == (1, 1, 1)
elif isinstance(value, (list, tuple)):
assert v.shape == (num_channels, 1, 1)

assert 0 <= i <= height - h
assert 0 <= j <= width - w


class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
Expand Down
15 changes: 0 additions & 15 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,6 @@ def __init__(
],
closeness_kwargs=dict(rtol=0, atol=21),
),
ConsistencyConfig(
v2_transforms.RandomErasing,
legacy_transforms.RandomErasing,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
ArgsKwargs(p=1, scale=(0.3, 0.7)),
ArgsKwargs(p=1, ratio=(0.5, 1.5)),
ArgsKwargs(p=1, value=1),
ArgsKwargs(p=1, value=(1, 2, 3)),
ArgsKwargs(p=1, value="random"),
],
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.ColorJitter,
legacy_transforms.ColorJitter,
Expand Down Expand Up @@ -570,7 +556,6 @@ def test_call_consistency(config, args_kwargs):
)
for transform_cls, get_params_args_kwargs in [
(v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
Expand Down
142 changes: 142 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2485,3 +2485,145 @@ def test_correctness(self):
assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
else:
assert isinstance(out_value, type(input_value))


class TestErase:
INPUT_SIZE = (17, 11)
FUNCTIONAL_KWARGS = dict(
zip("ijhwv", [2, 2, 10, 8, torch.tensor(0.0, dtype=torch.float32, device="cpu").reshape(-1, 1, 1)])
)

@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image(self, dtype, device):
check_kernel(F.erase_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **self.FUNCTIONAL_KWARGS)

@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_inplace(self, dtype, device):
input = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
input_version = input._version

output_out_of_place = F.erase_image(input, **self.FUNCTIONAL_KWARGS)
assert output_out_of_place.data_ptr() != input.data_ptr()

output_inplace = F.erase_image(input, **self.FUNCTIONAL_KWARGS, inplace=True)
assert output_inplace.data_ptr() == input.data_ptr()
assert output_inplace._version > input_version

assert_equal(output_inplace, output_out_of_place)

def test_kernel_video(self):
check_kernel(F.erase_video, make_video(self.INPUT_SIZE), **self.FUNCTIONAL_KWARGS)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
)
def test_functional(self, make_input):
check_functional(F.erase, make_input(), **self.FUNCTIONAL_KWARGS)

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.erase_image, torch.Tensor),
(F._erase_image_pil, PIL.Image.Image),
(F.erase_image, datapoints.Image),
(F.erase_video, datapoints.Video),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
check_transform(transforms.RandomErasing(p=1), make_input(device=device))

def _reference_erase_image(self, image, *, i, j, h, w, v):
mask = torch.zeros_like(image, dtype=torch.bool)
mask[..., i : i + h, j : j + w] = True

# The broadcasting and type casting logic is handled automagically in the kernel through indexing
value = torch.broadcast_to(v, (*image.shape[:-2], h, w)).to(image)

erased_image = torch.empty_like(image)
erased_image[mask] = value.flatten()
erased_image[~mask] = image[~mask]

return erased_image

@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_functional_image_correctness(self, dtype, device):
image = make_image(dtype=dtype, device=device)

actual = F.erase(image, **self.FUNCTIONAL_KWARGS)
expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS)

assert_equal(actual, expected)

@param_value_parametrization(
scale=[(0.1, 0.2), [0.0, 1.0]],
ratio=[(0.3, 0.7), [0.1, 5.0]],
value=[0, 0.5, (0, 1, 0), [-0.2, 0.0, 1.3], "random"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_image_correctness(self, param, value, dtype, device, seed):
transform = transforms.RandomErasing(**{param: value}, p=1)

image = make_image(dtype=dtype, device=device)

with freeze_rng_state():
torch.manual_seed(seed)
# This emulates the random apply check that happens before _get_params is called
torch.rand(1)
params = transform._get_params([image])

torch.manual_seed(seed)
actual = transform(image)

expected = self._reference_erase_image(image, **params)

assert_equal(actual, expected)

def test_transform_errors(self):
with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
transforms.RandomErasing(value={})

with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
transforms.RandomErasing(value="abc")

with pytest.raises(TypeError, match="Scale should be a sequence"):
transforms.RandomErasing(scale=123)

with pytest.raises(TypeError, match="Ratio should be a sequence"):
transforms.RandomErasing(ratio=123)

with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
transforms.RandomErasing(scale=[-1, 2])

transform = transforms.RandomErasing(value=[1, 2, 3, 4])

with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
transform._get_params([make_image()])

@pytest.mark.parametrize(
"make_input",
[make_bounding_boxes, make_detection_mask, make_segmentation_mask],
)
def test_transform_passthrough(self, make_input):
transform = transforms.RandomErasing(p=1)

input = make_input(self.INPUT_SIZE)

with pytest.warns(UserWarning, match="currently passing through inputs of type"):
# RandomErasing requires an image or video to be present
_, output = transform(make_image(self.INPUT_SIZE), input)

assert output is input
11 changes: 0 additions & 11 deletions test/transforms_v2_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
),
DispatcherInfo(
F.erase,
kernels={
datapoints.Image: F.erase_image,
datapoints.Video: F.erase_video,
},
pil_kernel_info=PILKernelInfo(F._erase_image_pil),
test_marks=[
skip_dispatch_datapoint,
],
),
DispatcherInfo(
F.adjust_contrast,
kernels={
Expand Down
30 changes: 0 additions & 30 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,36 +1222,6 @@ def sample_inputs_adjust_sharpness_video():
)


def sample_inputs_erase_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
# FIXME: make the parameters more diverse
h, w = 6, 7
v = torch.rand(image_loader.num_channels, h, w)
yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)


def sample_inputs_erase_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
# FIXME: make the parameters more diverse
h, w = 6, 7
v = torch.rand(video_loader.num_channels, h, w)
yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)


KERNEL_INFOS.extend(
[
KernelInfo(
F.erase_image,
kernel_name="erase_image_tensor",
sample_inputs_fn=sample_inputs_erase_image_tensor,
),
KernelInfo(
F.erase_video,
sample_inputs_fn=sample_inputs_erase_video,
),
]
)

_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]


Expand Down