diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 0d5bdc31d3c..17878b0c698 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -312,7 +312,13 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: return posterize_image_pil(inpt, bits=bits) -solarize_image_tensor = _FT.solarize +def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: + if threshold > _FT._max_value(image.dtype): + raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") + + return torch.where(image >= threshold, invert_image_tensor(image), image) + + solarize_image_pil = _FP.solarize @@ -456,7 +462,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return equalize_image_pil(inpt) -invert_image_tensor = _FT.invert +def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: + if image.dtype == torch.uint8: + return image.bitwise_not() + else: + return _FT._max_value(image.dtype) - image # type: ignore[no-any-return] + + invert_image_pil = _FP.invert