diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 742b344cf71..cc7f9654b8c 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -295,7 +295,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) -posterize_image_tensor = _FT.posterize +def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: + if image.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") + + # JIT-friendly for: ~(2 ** (8 - bits) - 1) + mask = -int(2 ** (8 - bits)) + return image & mask + + posterize_image_pil = _FP.posterize