diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 362a7a1c0e8..7d72463260e 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -254,9 +254,10 @@ def __init__(
         legacy_transforms.RandomAdjustSharpness,
         [
             ArgsKwargs(p=0, sharpness_factor=0.5),
-            ArgsKwargs(p=1, sharpness_factor=0.3),
+            ArgsKwargs(p=1, sharpness_factor=0.2),
             ArgsKwargs(p=1, sharpness_factor=0.99),
         ],
+        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
     ),
     ConsistencyConfig(
         prototype_transforms.RandomGrayscale,
@@ -306,8 +307,9 @@ def __init__(
             ArgsKwargs(saturation=(0.8, 0.9)),
             ArgsKwargs(hue=0.3),
             ArgsKwargs(hue=(-0.1, 0.2)),
-            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
+            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
         ],
+        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
     ),
     *[
         ConsistencyConfig(
@@ -753,7 +755,7 @@ def test_randaug(self, inpt, interpolation, mocker):
         expected_output = t_ref(inpt)
         output = t(inpt)
 
-        assert_equal(expected_output, output)
+        assert_close(expected_output, output, atol=1, rtol=0.1)
 
     @pytest.mark.parametrize(
         "inpt",
@@ -801,7 +803,7 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
         expected_output = t_ref(inpt)
         output = t(inpt)
 
-        assert_equal(expected_output, output)
+        assert_close(expected_output, output, atol=1, rtol=0.1)
 
     @pytest.mark.parametrize(
         "inpt",
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 49a769e04e0..ae07cc0056d 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -2,9 +2,29 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-from ._meta import get_dimensions_image_tensor
+from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor
+
+
+def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
+    ratio = float(ratio)
+    fp = image1.is_floating_point()
+    bound = 1.0 if fp else 255.0
+    output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
+    return output if fp else output.to(image1.dtype)
+
+
+def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
+    if brightness_factor < 0:
+        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
+
+    _FT._assert_channels(image, [1, 3])
+
+    fp = image.is_floating_point()
+    bound = 1.0 if fp else 255.0
+    output = image.mul(brightness_factor).clamp_(0, bound)
+    return output if fp else output.to(image.dtype)
+
 
-adjust_brightness_image_tensor = _FT.adjust_brightness
 adjust_brightness_image_pil = _FP.adjust_brightness
 
 
@@ -21,7 +41,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) ->
         return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
 
 
-adjust_saturation_image_tensor = _FT.adjust_saturation
+def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
+    if saturation_factor < 0:
+        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
+
+    c = get_num_channels_image_tensor(image)
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+
+    if c == 1:  # Match PIL behaviour
+        return image
+
+    return _blend(image, _rgb_to_gray(image), saturation_factor)
+
+
 adjust_saturation_image_pil = _FP.adjust_saturation
 
 
@@ -38,7 +71,19 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) ->
         return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
 
 
-adjust_contrast_image_tensor = _FT.adjust_contrast
+def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
+    if contrast_factor < 0:
+        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
+
+    c = get_num_channels_image_tensor(image)
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    grayscale_image = _rgb_to_gray(image) if c == 3 else image
+    mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
+    return _blend(image, mean, contrast_factor)
+
+
 adjust_contrast_image_pil = _FP.adjust_contrast
 
 
@@ -74,7 +119,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     else:
         needs_unsquash = False
 
-    output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
+    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
 
     if needs_unsquash:
         output = output.reshape(shape)
@@ -183,13 +228,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)
 
 
-def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
-    # input img shape should be [N, H, W]
-    shape = img.shape
+def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
+    # input image shape should be [N, H, W]
+    shape = image.shape
     # Compute image histogram:
-    flat_img = img.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
-    hist = flat_img.new_zeros(shape[0], 256)
-    hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
+    flat_image = image.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
+    hist = flat_image.new_zeros(shape[0], 256)
+    hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))
 
     # Compute image cdf
     chist = hist.cumsum_(dim=1)
@@ -213,7 +258,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
     zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
     lut = torch.cat([zeros, lut[:, :-1]], dim=1)
 
-    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img))
+    return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))
 
 
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 2903d73ce95..61a54f01cc9 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -184,7 +184,11 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
     return grayscale.repeat(repeats)
 
 
-_rgb_to_gray = _FT.rgb_to_grayscale
+def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
+    r, g, b = image.unbind(dim=-3)
+    l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
+    return l_img
 
 
 def convert_color_space_image_tensor(
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 4944c75fab8..ca641faf161 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -816,12 +816,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
     kernel /= kernel.sum()
     kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
 
-    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
-        img,
-        [
-            kernel.dtype,
-        ],
-    )
+    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
     result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
     result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
 
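
Illustrative aside (not part of the patch): the relaxed tolerances added to the consistency configs above can be reproduced with a small standalone check. The sketch below assumes a torchvision build that includes this patch; the import path for the prototype kernel and the use of the stable torchvision.transforms.functional.adjust_saturation as a reference are assumptions for illustration only.

    # Sketch: compare the new prototype saturation kernel against the stable kernel
    # on a float image, using the tolerances from the updated ColorJitter config.
    import torch
    from torchvision.transforms import functional as F_stable
    from torchvision.prototype.transforms.functional import adjust_saturation_image_tensor  # assumed export

    image = torch.rand(3, 32, 32)  # random float RGB image in [0, 1]

    out_new = adjust_saturation_image_tensor(image, saturation_factor=0.5)
    out_ref = F_stable.adjust_saturation(image, saturation_factor=0.5)

    # Small numerical differences between the two kernels are possible, hence
    # assert_close with tolerances rather than exact equality.
    torch.testing.assert_close(out_new, out_ref, atol=1e-5, rtol=1e-5)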