diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index f8b237f2e96..133508f5f94 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1,3 +1,4 @@ +import decimal import functools import itertools import math @@ -21,6 +22,7 @@ mark_framework_limitation, TestMark, ) +from torch.utils._pytree import tree_map from torchvision.prototype import features from torchvision.transforms.functional_tensor import _max_value as get_max_value @@ -1947,3 +1949,119 @@ def sample_inputs_normalize_video(): ), ] ) + + +def sample_inputs_convert_image_dtype(): + for input_dtype, output_dtype in itertools.product( + [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2 + ): + if input_dtype.is_floating_point and output_dtype == torch.int64: + # conversion cannot be performed safely + continue + + for image_loader in make_image_loaders( + sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[input_dtype] + ): + yield ArgsKwargs(image_loader, dtype=output_dtype) + + yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8) + + +def reference_convert_image_dtype(image, dtype=torch.float): + input_dtype = image.dtype + output_dtype = dtype + + if output_dtype == input_dtype: + return image + + def fn(value): + if input_dtype.is_floating_point: + if output_dtype.is_floating_point: + return value + else: + return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max) + else: + input_max_value = torch.iinfo(input_dtype).max + + if output_dtype.is_floating_point: + return float(decimal.Decimal(value) / input_max_value) + else: + output_max_value = torch.iinfo(output_dtype).max + + if input_max_value > output_max_value: + factor = (input_max_value + 1) // (output_max_value + 1) + return value // factor + else: + factor = (output_max_value + 1) // (input_max_value + 1) + return value * factor + + return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype) + + +def reference_inputs_convert_image_dtype(): + for input_dtype, output_dtype in itertools.product( + [ + torch.uint8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + repeat=2, + ): + if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or ( + input_dtype == torch.float64 and output_dtype == torch.int64 + ): + continue + + if input_dtype.is_floating_point: + data = [0.0, 0.5, 1.0] + else: + max_value = torch.iinfo(input_dtype).max + data = [0, max_value // 2, max_value] + image = torch.tensor(data, dtype=input_dtype) + + yield ArgsKwargs(image, dtype=output_dtype) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.convert_image_dtype, + sample_inputs_fn=sample_inputs_convert_image_dtype, + reference_fn=reference_convert_image_dtype, + reference_inputs_fn=reference_inputs_convert_image_dtype, + test_marks=[ + TestMark( + ("TestKernels", "test_scripted_vs_eager"), + pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"), + ), + TestMark( + ("TestKernels", "test_dtype_and_device_consistency"), + pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"), + condition=lambda args_kwargs: args_kwargs.args[0].dtype + != args_kwargs.kwargs.get("dtype", torch.float32), + ), + TestMark( + ("TestKernels", "test_against_reference"), + pytest.mark.xfail(reason="Conversion overflows"), + condition=lambda args_kwargs: ( + args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16} + and not args_kwargs.kwargs["dtype"].is_floating_point + ) + or ( + args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16} + and args_kwargs.kwargs["dtype"] == torch.int64 + ) + or ( + args_kwargs.args[0].dtype in {torch.int32, torch.int64} + and args_kwargs.kwargs["dtype"] == torch.float16 + ), + ), + ], + ), + ] +) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index bafe1f13459..3423006e2eb 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -26,6 +26,20 @@ def script(fn): raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error +def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None): + args_kwargs = list(args_kwargs_fn(info)) + idx_field_len = len(str(len(args_kwargs))) + return [ + pytest.param( + info, + args_kwargs_, + marks=info.get_marks(test_id, args_kwargs_) if test_id else [], + id=f"{info.id}-{idx:0{idx_field_len}}", + ) + for idx, args_kwargs_ in enumerate(args_kwargs) + ] + + def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None): if condition is None: @@ -49,18 +63,7 @@ def decorator(test_fn): if not condition(info): continue - args_kwargs = list(args_kwargs_fn(info)) - idx_field_len = len(str(len(args_kwargs))) - - for idx, args_kwargs_ in enumerate(args_kwargs): - argvalues.append( - pytest.param( - info, - args_kwargs_, - marks=info.get_marks(test_id, args_kwargs_), - id=f"{info.id}-{idx:0{idx_field_len}}", - ) - ) + argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id)) return pytest.mark.parametrize(argnames, argvalues)(test_fn) @@ -232,7 +235,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): [ F.clamp_bounding_box, F.convert_color_space, - F.convert_image_dtype, F.get_dimensions, F.get_image_num_channels, F.get_image_size, @@ -312,6 +314,24 @@ def test_alias(alias, target): assert alias is target +@pytest.mark.parametrize( + ("info", "args_kwargs"), + make_info_args_kwargs_params( + next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype), + args_kwargs_fn=lambda info: info.sample_inputs_fn(), + ), +) +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device): + (input, *other_args), kwargs = args_kwargs.load(device) + dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32) + + output = info.kernel(input, dtype) + + assert output.dtype == dtype + assert output.device == input.device + + # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in # `prototype_transforms_kernel_infos.py` diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index 5fe990eb727..a57fbc65536 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -7,7 +7,7 @@ from torchvision.io.video import read_video from torchvision.prototype import features from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer -from torchvision.transforms import functional as _F +from torchvision.transforms import functional as _F, functional_tensor as _FT @torch.jit.unused @@ -42,4 +42,77 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> # prevalent and well understood. Thus, we just alias it without deprecating the old name. to_pil_image = to_image_pil -convert_image_dtype = _F.convert_image_dtype + +def _num_value_bits(dtype: torch.dtype) -> int: + if dtype == torch.uint8: + return 8 + elif dtype == torch.int8: + return 7 + elif dtype == torch.int16: + return 15 + elif dtype == torch.int32: + return 31 + elif dtype == torch.int64: + return 63 + else: + raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") + + +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + if not isinstance(image, torch.Tensor): + raise TypeError("Input img should be Tensor Image") + + if image.dtype == dtype: + return image + + float_input = image.is_floating_point() + if torch.jit.is_scripting(): + # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT + float_output = torch.tensor(0, dtype=dtype).is_floating_point() + else: + float_output = dtype.is_floating_point + + if float_input: + # float to float + if float_output: + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.") + + # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting + # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only + # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # for a detailed analysis. + # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation. + # Instead, we can also multiply by the maximum value plus something close to `1`. See + # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details. + eps = 1e-3 + max_value = float(_FT._max_value(dtype)) + # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the + # discrete set `{0, 1}`. + return image.mul(max_value + 1.0 - eps).to(dtype) + else: + # int to float + if float_output: + return image.to(dtype).div_(_FT._max_value(image.dtype)) + + # int to int + num_value_bits_input = _num_value_bits(image.dtype) + num_value_bits_output = _num_value_bits(dtype) + + if num_value_bits_input > num_value_bits_output: + return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) + else: + # The bitshift kernel is not vectorized + # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322 + # This results in the multiplication actually being faster. + # TODO: If the bitshift kernel is optimized in core, replace the computation below with + # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)` + max_value_input = float(_FT._max_value(dtype)) + max_value_output = float(_FT._max_value(image.dtype)) + factor = int((max_value_input + 1) // (max_value_output + 1)) + return image.to(dtype).mul_(factor)