
convert_image_dtype overflows with low precision floating point dtypes #6799


Description

@pmeier

While working on improving the performance of convert_image_dtype in #6795, I found several cases where convert_image_dtype silently fails for the low precision floating point dtypes torch.float16 and torch.bfloat16:

import torch
from torchvision.transforms import functional as F

# torch.{float16, bfloat16} to any integer dtype
image = torch.tensor(1.0, dtype=torch.float16)
print(image, F.convert_image_dtype(image, torch.uint8), F.convert_image_dtype(image, torch.int8))

# torch.{int32, int64} to torch.float16
image = torch.tensor(2**31 - 1, dtype=torch.int32)
print(image, F.convert_image_dtype(image, torch.float16))

Output:

tensor(1., dtype=torch.float16) tensor(0, dtype=torch.uint8) tensor(-128, dtype=torch.int8)
tensor(2147483647, dtype=torch.int32) tensor(nan, dtype=torch.float16)
  1. Converting a valid (b)float16 image in the value range [0.0, 1.0] to any integer dtype overflows the computation. This stems from the fact that eps is fixed:

    eps = 1e-3
    max_val = float(_max_value(dtype))
    result = image.mul(max_val + 1.0 - eps)
    return result.to(dtype)

    This value is simply too large for (b)float16:

    >>> image = torch.tensor(1.0, dtype=torch.float16)
    >>> image.mul(255 + 1.0 - 1e-3)  # float16 -> uint8
    tensor(256., dtype=torch.float16)
    >>> image.to(torch.float32).mul(255 + 1.0 - 1e-3)  # float32 -> uint8
    tensor(255.9990)
    >>> image.mul(255 + 1.0 - 7e-2)  # float16 -> uint8 with adjusted eps
    tensor(255.8750, dtype=torch.float16)

    The whole point of eps is to be as small as possible to have an even value distribution. See Add convert_image_dtype to functionals #2078 (comment) for details.

    We could simply make eps dependent on the input dtype, e.g. with a helper similar to the existing _max_value (a rough sketch is given after this list):

    def _max_value(dtype: torch.dtype) -> int:

  2. Converting an int{32, 64} image to float16 should not be possible, since float16 cannot hold their maximum values:

    >>> torch.finfo(torch.float16).max
    65504.0
    >>> torch.iinfo(torch.int16).max  # ok
    32767
    >>> torch.iinfo(torch.int32).max  # not ok
    2147483647
    >>> torch.iinfo(torch.int64).max  # not ok
    9223372036854775807
    >>> torch.finfo(torch.bfloat16).max  # bfloat does not have this issue
    3.3895313892515355e+38

    We are already raising an error for unsafe float to int conversions

    # float to int
    if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
        image.dtype == torch.float64 and dtype == torch.int64
    ):
        msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
        raise RuntimeError(msg)

    so we could simply do the same here (see the sketch below).
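
Putting both suggestions together, here is a rough sketch of how the fixes could look. The helper names (_EPS, _convert_float_to_int, _check_int_to_float) and the (b)float16 eps values are placeholders for illustration, not the actual convert_image_dtype implementation:

import torch

# Illustrative eps per input dtype: 1e-3 is the current fixed value, 7e-2 comes
# from the float16 example above, and the bfloat16 value is only a guess that
# would need to be verified (bfloat16 has even fewer mantissa bits).
_EPS = {
    torch.float16: 7e-2,
    torch.bfloat16: 6e-1,
    torch.float32: 1e-3,
    torch.float64: 1e-3,
}


def _convert_float_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Fix 1: make eps depend on the input dtype so that max_val + 1.0 - eps
    # still rounds to a value strictly below max_val + 1.0 in image.dtype.
    eps = _EPS[image.dtype]
    max_val = float(torch.iinfo(dtype).max)
    return image.mul(max_val + 1.0 - eps).to(dtype)


def _check_int_to_float(image: torch.Tensor, dtype: torch.dtype) -> None:
    # Fix 2: mirror the existing float -> int safety check, since float16
    # cannot represent the int32 / int64 maximum values.
    if dtype == torch.float16 and image.dtype in (torch.int32, torch.int64):
        msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
        raise RuntimeError(msg)

With the dtype-dependent eps, the float16 example from the top of the issue should map 1.0 to 255 (uint8) and 127 (int8) instead of wrapping around, and the int32 -> float16 case would raise instead of silently producing nan.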

cc @vfdev-5 @datumbox
