diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 68b52fff637..7bf412aaf99 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -211,7 +211,34 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp return solarize_image_pil(inpt, threshold=threshold) -autocontrast_image_tensor = _FT.autocontrast +def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: + + if not (isinstance(image, torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + c = get_num_channels_image_tensor(image) + + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + if image.numel() == 0: + # exit earlier on empty images + return image + + bound = 1.0 if image.is_floating_point() else 255.0 + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + + minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = image.amax(dim=(-2, -1), keepdim=True).to(dtype) + + scale = bound / (maximum - minimum) + eq_idxs = maximum == minimum + minimum[eq_idxs] = 0.0 + scale[eq_idxs] = 1.0 + + return (image - minimum).mul_(scale).clamp_(0, bound).to(image.dtype) + + autocontrast_image_pil = _FP.autocontrast