# ---------------------------------------------------------------------------
# test/test_transforms.py -- helpers and Tester methods added by this patch
# ---------------------------------------------------------------------------


def cycle_over(objs):
    """Yield ``(obj, others)`` for every element of ``objs``.

    ``others`` is a new list with all elements except the one at the current
    position, in their original order. ``objs`` is materialised once so that
    generators can be passed in.
    """
    objs = list(objs)
    for idx, obj in enumerate(objs):
        yield obj, objs[:idx] + objs[idx + 1:]


def int_dtypes():
    """Yield the integer dtypes exercised by the dtype conversion tests."""
    # torch.short, torch.int and torch.long are aliases of torch.int16,
    # torch.int32 and torch.int64, so those dtypes are deliberately visited
    # under both spellings.
    yield from (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long)


def float_dtypes():
    """Yield the floating point dtypes exercised by the dtype conversion tests."""
    # torch.float and torch.double are aliases of torch.float32 / torch.float64.
    yield from (torch.float32, torch.float, torch.float64, torch.double)


class Tester(unittest.TestCase):
    # Only the methods introduced by this patch are listed here; the class
    # holds further tests in the full file. ``transforms`` is expected to be
    # imported by the test module outside this excerpt.

    def test_convert_image_dtype_float_to_float(self):
        """Float -> float conversion keeps the [0.0, 1.0] value range."""
        for input_dtype, output_dtypes in cycle_over(float_dtypes()):
            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
            for output_dtype in output_dtypes:
                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    output_image = transform(input_image)

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0.0, 1.0

                    self.assertAlmostEqual(actual_min, desired_min)
                    self.assertAlmostEqual(actual_max, desired_max)

    def test_convert_image_dtype_float_to_int(self):
        """Float -> int maps [0.0, 1.0] onto [0, int max] or raises for unsafe casts."""
        for input_dtype in float_dtypes():
            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
            for output_dtype in int_dtypes():
                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)

                    # Same dtype pairs that the conversion is documented to reject.
                    if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
                        input_dtype == torch.float64 and output_dtype == torch.int64
                    ):
                        with self.assertRaises(RuntimeError):
                            transform(input_image)
                    else:
                        output_image = transform(input_image)

                        actual_min, actual_max = output_image.tolist()
                        desired_min, desired_max = 0, torch.iinfo(output_dtype).max

                        self.assertEqual(actual_min, desired_min)
                        self.assertEqual(actual_max, desired_max)

    def test_convert_image_dtype_int_to_float(self):
        """Int -> float maps [0, int max] into [0.0, 1.0] without leaving that range."""
        for input_dtype in int_dtypes():
            input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
            for output_dtype in float_dtypes():
                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    output_image = transform(input_image)

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0.0, 1.0

                    self.assertAlmostEqual(actual_min, desired_min)
                    self.assertGreaterEqual(actual_min, desired_min)
                    self.assertAlmostEqual(actual_max, desired_max)
                    self.assertLessEqual(actual_max, desired_max)

    def test_convert_image_dtype_int_to_int(self):
        """Int -> int rescales by the ratio of the two value ranges."""
        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
            input_max = torch.iinfo(input_dtype).max
            input_image = torch.tensor((0, input_max), dtype=input_dtype)
            for output_dtype in output_dtypes:
                output_max = torch.iinfo(output_dtype).max

                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    output_image = transform(input_image)

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0, output_max

                    # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
                    # Upscaling multiplies by an integer factor, so the input
                    # maximum lands ``factor - 1`` below the output maximum.
                    if input_max >= output_max:
                        error_term = 0
                    else:
                        error_term = 1 - (output_max + 1) // (input_max + 1)

                    self.assertEqual(actual_min, desired_min)
                    self.assertEqual(actual_max, desired_max + error_term)

    def test_convert_image_dtype_int_to_int_consistency(self):
        """Upscaling to a wider int dtype and converting back is lossless."""
        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
            input_max = torch.iinfo(input_dtype).max
            input_image = torch.tensor((0, input_max), dtype=input_dtype)
            for output_dtype in output_dtypes:
                output_max = torch.iinfo(output_dtype).max
                # Only the widening direction is checked here.
                if output_max <= input_max:
                    continue

                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                    transform = transforms.ConvertImageDtype(output_dtype)
                    inverse_transform = transforms.ConvertImageDtype(input_dtype)
                    output_image = inverse_transform(transform(input_image))

                    actual_min, actual_max = output_image.tolist()
                    desired_min, desired_max = 0, input_max

                    self.assertEqual(actual_min, desired_min)
                    self.assertEqual(actual_max, desired_max)


# ---------------------------------------------------------------------------
# torchvision/transforms/functional.py -- function added by this patch
# ---------------------------------------------------------------------------


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Floating point images are treated as having the value range ``[0.0, 1.0]``, integer images as having the range
    ``[0, torch.iinfo(dtype).max]``.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image. If ``image`` already has the requested ``dtype`` the input tensor itself
        is returned without copying.

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        TypeError: When ``image`` is not a :class:`torch.Tensor`.
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"image should be a torch.Tensor, but got {type(image)}.")

    if image.dtype == dtype:
        return image

    if image.dtype.is_floating_point:
        # float to float
        if dtype.is_floating_point:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # Scaling by slightly less than ``max + 1`` keeps 1.0 below ``max + 1``
        # so that the subsequent cast yields ``max`` instead of overflowing.
        eps = 1e-3
        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)

    input_max = torch.iinfo(image.dtype).max

    # int to float
    if dtype.is_floating_point:
        return image.to(dtype) / input_max

    # int to int
    output_max = torch.iinfo(dtype).max

    if input_max > output_max:
        # Narrowing: divide in the wider input dtype first, then cast.
        factor = (input_max + 1) // (output_max + 1)
        return (image // factor).to(dtype)

    # Widening: cast first so the multiplication cannot overflow the input dtype.
    factor = (output_max + 1) // (input_max + 1)
    return image.to(dtype) * factor
# torchvision/transforms/transforms.py -- export list and class added by this patch.
# NOTE: ``F`` used below is the alias bound by ``from . import functional as F``
# in the import section of that file.

__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
           "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
           "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
           "RandomPerspective", "RandomErasing"]


class ConvertImageDtype(object):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        self.dtype = dtype

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image (torch.Tensor): Image to be converted

        Returns:
            (torch.Tensor): Result of ``F.convert_image_dtype`` for ``image`` and the configured ``dtype``
        """
        return F.convert_image_dtype(image, self.dtype)

    def __repr__(self):
        # Mirrors the ``__repr__`` convention of the other transforms in this module.
        return self.__class__.__name__ + '(dtype={0})'.format(self.dtype)