From ce3cfe09eaf27b6bfda88fe6a0032e57e453d382 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 13:50:45 +0200 Subject: [PATCH 01/20] add convert_image_dtype to functionals --- torchvision/transforms/functional.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d19b26e36b2..27accecf1aa 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -113,6 +113,38 @@ def pil_to_tensor(pic): return img +def convert_image_dtype( + image: torch.Tensor, dtype: torch.dtype = torch.float +) -> torch.Tensor: + def scale_factor(dtype: torch.dtype) -> float: + if dtype in ( + torch.float32, + torch.float, + torch.float64, + torch.double, + torch.float16, + torch.half, + ): + return 1.0 + + num_value_bits = { + torch.uint8: 8, + torch.int8: 7, + torch.int16: 15, + torch.short: 15, + torch.int32: 31, + torch.int: 31, + torch.int64: 63, + torch.long: 63, + torch.bool: 1, + } + return float(2 ** num_value_bits[dtype] - 1) + + return ( + image.double().div(scale_factor(image.dtype)).mul(scale_factor(dtype)).to(dtype) + ) + + def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. From 1c54ebafeaadeff1bc83acf940fc265b21b5360a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 14:00:17 +0200 Subject: [PATCH 02/20] add ConvertImageDtype transform --- torchvision/transforms/transforms.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 812d82e3825..6dbb504fc00 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,10 +15,10 @@ from . import functional as F -__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", - "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", - "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", - "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", +__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", + "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", + "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", + "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing"] _pil_interpolation_to_str = { @@ -115,6 +115,20 @@ def __repr__(self): return self.__class__.__name__ + '()' +class ConvertImageDtype(object): + """Convert a tensor to the given ``dtype`` and scale the values accordingly + + Args: + dtype (torch.dtype): Desired data type of the output + + """ + def __init__(self, dtype: torch.dtype) -> None: + self.dtype = dtype + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + return F.convert_image_dtype(image, self.dtype) + + class ToPILImage(object): """Convert a tensor or an ndarray to PIL Image. From f15dbdb60208f3d860c997eee8dbe665791c5e02 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 14:17:59 +0200 Subject: [PATCH 03/20] add test --- test/test_transforms.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 945bf5fde4b..5f5a5bffed7 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -510,6 +510,40 @@ def test_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + def test_convert_image_dtype(self): + dtype_max_value = { + torch.float32: 1.0, + torch.float: 1.0, + torch.float64: 1.0, + torch.double: 1.0, + torch.float16: 1.0, + torch.half: 1.0, + torch.uint8: 255, + torch.int8: 127, + torch.int16: 32_767, + torch.short: 32_767, + torch.int32: 2_147_483_647, + torch.int: 2_147_483_647, + torch.int64: 9_223_372_036_854_775_807, + torch.long: 9_223_372_036_854_775_807, + torch.bool: 1, + } + + def cycle_over(objs): + objs = list(objs) + for idx, obj in enumerate(objs): + yield obj, objs[:idx] + objs[idx + 1:] + + for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()): + input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] + + for output_dtype in output_dtypes: + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + self.assertEqual(output_image.dtype, output_dtype) + self.assertEqual(torch.max(output_image), dtype_max_value[output_dtype]) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): trans = transforms.ToTensor() From dbe79c190cb1ab786f49a8e454ddb6c37a6fc285 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 14:46:49 +0200 Subject: [PATCH 04/20] remove underscores from numbers since they are not compatible with python<3.6 --- test/test_transforms.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 5f5a5bffed7..58c1990e2b6 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -518,14 +518,14 @@ def test_convert_image_dtype(self): torch.double: 1.0, torch.float16: 1.0, torch.half: 1.0, - torch.uint8: 255, - torch.int8: 127, - torch.int16: 32_767, - torch.short: 32_767, - torch.int32: 2_147_483_647, - torch.int: 2_147_483_647, - torch.int64: 9_223_372_036_854_775_807, - torch.long: 9_223_372_036_854_775_807, + torch.uint8: 2 ** 8 - 1, + torch.int8: 2 ** 7 - 1, + torch.int16: 2 ** 15 - 1, + torch.short: 2 ** 15 - 1, + torch.int32: 2 ** 31 - 1, + torch.int: 2 ** 31 - 1, + torch.int64: 2 ** 63 - 1, + torch.long: 2 ** 63 - 1, torch.bool: 1, } @@ -542,7 +542,9 @@ def cycle_over(objs): output_image = transform(input_image) self.assertEqual(output_image.dtype, output_dtype) - self.assertEqual(torch.max(output_image), dtype_max_value[output_dtype]) + self.assertEqual( + torch.max(output_image).item(), dtype_max_value[output_dtype] + ) @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): From dbde8cb3b2ac5a56a7d1a8c57b73bbdff38c7679 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 15:29:43 +0200 Subject: [PATCH 05/20] address review comments 1/3 --- torchvision/transforms/functional.py | 32 +++++++--------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 27accecf1aa..03ba477c2ba 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -117,32 +117,16 @@ def convert_image_dtype( image: torch.Tensor, dtype: torch.dtype = torch.float ) -> torch.Tensor: def scale_factor(dtype: torch.dtype) -> float: - if dtype in ( - torch.float32, - torch.float, - torch.float64, - torch.double, - torch.float16, - torch.half, - ): + if dtype.is_floating_point: return 1.0 + else: + return float(torch.iinfo(dtype)) - num_value_bits = { - torch.uint8: 8, - torch.int8: 7, - torch.int16: 15, - torch.short: 15, - torch.int32: 31, - torch.int: 31, - torch.int64: 63, - torch.long: 63, - torch.bool: 1, - } - return float(2 ** num_value_bits[dtype] - 1) - - return ( - image.double().div(scale_factor(image.dtype)).mul(scale_factor(dtype)).to(dtype) - ) + old_dtype = image.dtype + image = image.double() + image = image / scale_factor(old_dtype) + image = image * scale_factor(dtype) + return image.to(dtype) def to_pil_image(pic, mode=None): From b4d74f90a4e39be66f918f8732766b0b5ea67cc7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 15:32:26 +0200 Subject: [PATCH 06/20] fix torch.bool --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 03ba477c2ba..84f8cd6c150 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -117,7 +117,7 @@ def convert_image_dtype( image: torch.Tensor, dtype: torch.dtype = torch.float ) -> torch.Tensor: def scale_factor(dtype: torch.dtype) -> float: - if dtype.is_floating_point: + if dtype.is_floating_point or dtype == torch.bool: return 1.0 else: return float(torch.iinfo(dtype)) From 0225cf996e897aca8058c3606e9973e67027d62f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 15:45:59 +0200 Subject: [PATCH 07/20] use torch.iinfo in test --- test/test_transforms.py | 46 +++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 58c1990e2b6..2b9190162b2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -512,27 +512,37 @@ def test_to_tensor(self): def test_convert_image_dtype(self): dtype_max_value = { - torch.float32: 1.0, - torch.float: 1.0, - torch.float64: 1.0, - torch.double: 1.0, - torch.float16: 1.0, - torch.half: 1.0, - torch.uint8: 2 ** 8 - 1, - torch.int8: 2 ** 7 - 1, - torch.int16: 2 ** 15 - 1, - torch.short: 2 ** 15 - 1, - torch.int32: 2 ** 31 - 1, - torch.int: 2 ** 31 - 1, - torch.int64: 2 ** 63 - 1, - torch.long: 2 ** 63 - 1, - torch.bool: 1, - } + dtype: 1.0 + for dtype in ( + torch.float32, + torch.float, + torch.float64, + torch.double, + torch.float16, + torch.half, + torch.bool, + ) + } + dtype_max_value.update( + { + dtype: torch.iinfo(dtype).max + for dtype in ( + torch.uint8, + torch.int8, + torch.int16, + torch.short, + torch.int32, + torch.int, + torch.int64, + torch.long, + ) + } + ) def cycle_over(objs): objs = list(objs) for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1:] + yield obj, objs[:idx] + objs[idx + 1 :] for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()): input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] @@ -542,7 +552,7 @@ def cycle_over(objs): output_image = transform(input_image) self.assertEqual(output_image.dtype, output_dtype) - self.assertEqual( + self.assertAlmostEqual( torch.max(output_image).item(), dtype_max_value[output_dtype] ) From cc1223c052d04b833d60712d7d4df8982c22cbd9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 15:52:02 +0200 Subject: [PATCH 08/20] fix flake8 --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 2b9190162b2..c52ac388d60 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -542,7 +542,7 @@ def test_convert_image_dtype(self): def cycle_over(objs): objs = list(objs) for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1 :] + yield obj, objs[:idx] + objs[idx + 1:] for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()): input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] From 377861962551ace0d7bcc3459d41ee3507a0d708 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 16:05:01 +0200 Subject: [PATCH 09/20] remove double conversion --- torchvision/transforms/functional.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 84f8cd6c150..5fe2cae7be6 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -122,9 +122,7 @@ def scale_factor(dtype: torch.dtype) -> float: else: return float(torch.iinfo(dtype)) - old_dtype = image.dtype - image = image.double() - image = image / scale_factor(old_dtype) + image = image / scale_factor(image.dtype) image = image * scale_factor(dtype) return image.to(dtype) From 36a908f6cfb124c1a0d857eea70980f62c379569 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 16:06:46 +0200 Subject: [PATCH 10/20] fix flake9 --- test/test_transforms.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index c52ac388d60..889078d68c4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -512,17 +512,17 @@ def test_to_tensor(self): def test_convert_image_dtype(self): dtype_max_value = { - dtype: 1.0 - for dtype in ( - torch.float32, - torch.float, - torch.float64, - torch.double, - torch.float16, - torch.half, - torch.bool, - ) - } + dtype: 1.0 + for dtype in ( + torch.float32, + torch.float, + torch.float64, + torch.double, + torch.float16, + torch.half, + torch.bool, + ) + } dtype_max_value.update( { dtype: torch.iinfo(dtype).max From 6a7a95fb40d7310db82d0e9a1913917e6816c93e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Apr 2020 16:16:55 +0200 Subject: [PATCH 11/20] bug fix --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5fe2cae7be6..93cfa629e5f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -120,7 +120,7 @@ def scale_factor(dtype: torch.dtype) -> float: if dtype.is_floating_point or dtype == torch.bool: return 1.0 else: - return float(torch.iinfo(dtype)) + return float(torch.iinfo(dtype).max) image = image / scale_factor(image.dtype) image = image * scale_factor(dtype) From 570893988daf32d47ecc0332e76d7a2589310def Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Apr 2020 09:47:39 +0200 Subject: [PATCH 12/20] add error messages to test --- test/test_transforms.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 889078d68c4..bb70a4ae964 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -550,11 +550,21 @@ def cycle_over(objs): for output_dtype in output_dtypes: transform = transforms.ConvertImageDtype(output_dtype) output_image = transform(input_image) + msg_prefix = "Conversion from {input_dtype} to {output_dtype} resulted in ".format( + input_dtype=input_dtype, output_dtype=output_dtype + ) + + actual = output_image.dtype + desired = output_dtype + msg = msg_prefix + "{actual_dtype}.".format(actual_dtype=actual) + self.assertEqual(actual, desired, msg=msg) - self.assertEqual(output_image.dtype, output_dtype) - self.assertAlmostEqual( - torch.max(output_image).item(), dtype_max_value[output_dtype] + actual = torch.max(output_image).item() + desired = dtype_max_value[output_dtype] + msg = msg_prefix + "{actual_max_value} instead of {desired_max_value}.".format( + actual_max_value=actual, desired_max_value=desired ) + self.assertAlmostEqual(actual, desired, msg=msg) @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): From f2f26f158622b567293e5ad489b4134b912d1af9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Apr 2020 09:51:01 +0200 Subject: [PATCH 13/20] disable torch.float16 and torch.half for now --- test/test_transforms.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index bb70a4ae964..3c9391b57a4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -513,15 +513,10 @@ def test_to_tensor(self): def test_convert_image_dtype(self): dtype_max_value = { dtype: 1.0 - for dtype in ( - torch.float32, - torch.float, - torch.float64, - torch.double, - torch.float16, - torch.half, - torch.bool, - ) + for dtype in (torch.float32, torch.float, torch.float64, torch.double, torch.bool,) + # torch.float16 and torch.half are disabled for now since they do not support torch.max + # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051 + # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, ) } dtype_max_value.update( { From d3e403161b1bb65c426a49bc530154e1c0eaa3fc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Apr 2020 09:52:49 +0200 Subject: [PATCH 14/20] add docstring --- torchvision/transforms/functional.py | 9 +++++++++ torchvision/transforms/transforms.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 93cfa629e5f..d64861eaf96 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -116,6 +116,15 @@ def pil_to_tensor(pic): def convert_image_dtype( image: torch.Tensor, dtype: torch.dtype = torch.float ) -> torch.Tensor: + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + + Args: + image (torch.Tensor): Image to be converted + dtype (torch.dtype): Desired data type of the output + + Returns: + (torch.Tensor): Converted image + """ def scale_factor(dtype: torch.dtype) -> float: if dtype.is_floating_point or dtype == torch.bool: return 1.0 diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 6dbb504fc00..011ce226f5a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -116,7 +116,7 @@ def __repr__(self): class ConvertImageDtype(object): - """Convert a tensor to the given ``dtype`` and scale the values accordingly + """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: dtype (torch.dtype): Desired data type of the output From eb3aab17ab2753982c2a7f7fc102c816cd9887f4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Apr 2020 10:18:42 +0200 Subject: [PATCH 15/20] add test for consistency --- test/test_transforms.py | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 3c9391b57a4..844d3510c4c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -561,6 +561,53 @@ def cycle_over(objs): ) self.assertAlmostEqual(actual, desired, msg=msg) + def test_convert_image_dtype_consistency(self): + def cycle_over(objs): + objs = list(objs) + for idx, obj in enumerate(objs): + yield obj, objs[:idx] + objs[idx + 1:] + + def create_input_image(dtype, seed=0, size=(3, 4, 4)): + torch.manual_seed(seed) + if dtype.is_floating_point: + return torch.rand(size, dtype=dtype) + elif dtype == torch.bool: + return torch.rand(size) > 0.5 + else: + return torch.randint(torch.iinfo(dtype).max, size, dtype=dtype) + + dtypes = ( + torch.float32, + torch.float, + torch.float64, + torch.double, + # torch.float16, + # torch.half, + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.short, + torch.int32, + torch.int, + torch.int64, + torch.long, + ) + + for dtype, intermediate_dtypes in cycle_over(dtypes): + input_image = create_input_image(dtype) + inverse_transform = transforms.ConvertImageDtype(dtype) + + for intermediate_dtype in intermediate_dtypes: + transform = transforms.ConvertImageDtype(intermediate_dtype) + output_image = inverse_transform(transform(input_image)) + + msg = "Cycle conversion from {dtype} via {intermediate_dtype} resulted in inconsistent results.".format( + dtype=dtype, intermediate_dtype=intermediate_dtype + ) + # FIXME: This should compare floating point tensor for almost equality + # self.assertEqual(output_image, input_image, msg=msg) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): trans = transforms.ToTensor() From 49e6a9c43921d840b8f9720c66bf9184331fba29 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Apr 2020 10:19:21 +0200 Subject: [PATCH 16/20] move nested function to top --- test/test_transforms.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 844d3510c4c..4470a1028f0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -511,6 +511,11 @@ def test_to_tensor(self): self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) def test_convert_image_dtype(self): + def cycle_over(objs): + objs = list(objs) + for idx, obj in enumerate(objs): + yield obj, objs[:idx] + objs[idx + 1:] + dtype_max_value = { dtype: 1.0 for dtype in (torch.float32, torch.float, torch.float64, torch.double, torch.bool,) @@ -534,11 +539,6 @@ def test_convert_image_dtype(self): } ) - def cycle_over(objs): - objs = list(objs) - for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1:] - for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()): input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] From b6becf17d5ddce829a8a53653fdf0912e4afcf05 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 10 Apr 2020 19:54:51 +0200 Subject: [PATCH 17/20] test in CI --- test/test_transforms.py | 93 +++++++--------------------- torchvision/transforms/functional.py | 45 +++++++++++--- 2 files changed, 60 insertions(+), 78 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 4470a1028f0..33d976b37bd 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -516,13 +516,14 @@ def cycle_over(objs): for idx, obj in enumerate(objs): yield obj, objs[:idx] + objs[idx + 1:] - dtype_max_value = { - dtype: 1.0 - for dtype in (torch.float32, torch.float, torch.float64, torch.double, torch.bool,) - # torch.float16 and torch.half are disabled for now since they do not support torch.max - # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051 - # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, ) - } + # dtype_max_value = { + # dtype: 1.0 + # for dtype in (torch.float32, torch.float, torch.float64, torch.double)#, torch.bool,) + # # torch.float16 and torch.half are disabled for now since they do not support torch.max + # # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051 + # # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, ) + # } + dtype_max_value = {} dtype_max_value.update( { dtype: torch.iinfo(dtype).max @@ -543,70 +544,20 @@ def cycle_over(objs): input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] for output_dtype in output_dtypes: - transform = transforms.ConvertImageDtype(output_dtype) - output_image = transform(input_image) - msg_prefix = "Conversion from {input_dtype} to {output_dtype} resulted in ".format( - input_dtype=input_dtype, output_dtype=output_dtype - ) - - actual = output_image.dtype - desired = output_dtype - msg = msg_prefix + "{actual_dtype}.".format(actual_dtype=actual) - self.assertEqual(actual, desired, msg=msg) - - actual = torch.max(output_image).item() - desired = dtype_max_value[output_dtype] - msg = msg_prefix + "{actual_max_value} instead of {desired_max_value}.".format( - actual_max_value=actual, desired_max_value=desired - ) - self.assertAlmostEqual(actual, desired, msg=msg) - - def test_convert_image_dtype_consistency(self): - def cycle_over(objs): - objs = list(objs) - for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1:] - - def create_input_image(dtype, seed=0, size=(3, 4, 4)): - torch.manual_seed(seed) - if dtype.is_floating_point: - return torch.rand(size, dtype=dtype) - elif dtype == torch.bool: - return torch.rand(size) > 0.5 - else: - return torch.randint(torch.iinfo(dtype).max, size, dtype=dtype) - - dtypes = ( - torch.float32, - torch.float, - torch.float64, - torch.double, - # torch.float16, - # torch.half, - torch.bool, - torch.uint8, - torch.int8, - torch.int16, - torch.short, - torch.int32, - torch.int, - torch.int64, - torch.long, - ) - - for dtype, intermediate_dtypes in cycle_over(dtypes): - input_image = create_input_image(dtype) - inverse_transform = transforms.ConvertImageDtype(dtype) - - for intermediate_dtype in intermediate_dtypes: - transform = transforms.ConvertImageDtype(intermediate_dtype) - output_image = inverse_transform(transform(input_image)) - - msg = "Cycle conversion from {dtype} via {intermediate_dtype} resulted in inconsistent results.".format( - dtype=dtype, intermediate_dtype=intermediate_dtype - ) - # FIXME: This should compare floating point tensor for almost equality - # self.assertEqual(output_image, input_image, msg=msg) + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + actual = output_image.dtype + desired = output_dtype + self.assertEqual(actual, desired) + + actual = torch.max(output_image).item() + desired = dtype_max_value[output_dtype] + if output_dtype.is_floating_point: + self.assertAlmostEqual(actual, desired) + else: + self.assertEqual(actual, desired) @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d64861eaf96..d821514834f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -125,15 +125,46 @@ def convert_image_dtype( Returns: (torch.Tensor): Converted image """ - def scale_factor(dtype: torch.dtype) -> float: - if dtype.is_floating_point or dtype == torch.bool: - return 1.0 + def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return image.to(dtype) + + def float_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + max = float(torch.iinfo(dtype).max) + image = image * (max + 1.0) + image = torch.clamp(image, max) + return image.to(dtype) + + def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + max = torch.iinfo(image.dtype).max + image = image.to(dtype) + return image / max + + def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + input_max = torch.iinfo(image.dtype).max + output_max = torch.iinfo(dtype).max + + if input_max > output_max: + factor = (input_max + 1) // (output_max + 1) + image = image // factor + return image.to(dtype) else: - return float(torch.iinfo(dtype).max) + factor = (output_max + 1) // (input_max + 1) + image = image.to(dtype) + return (image + 1) * factor - 1 - image = image / scale_factor(image.dtype) - image = image * scale_factor(dtype) - return image.to(dtype) + if image.dtype == dtype: + return image + + if image.dtype.is_floating_point: + if dtype.is_floating_point: + return float_to_float(image, dtype) + else: + return float_to_int(image, dtype) + else: + if dtype.is_floating_point: + return int_to_float(image, dtype) + else: + return int_to_int(image, dtype) def to_pil_image(pic, mode=None): From adfb0964e3cee46ba1ea81e5cba8dc15d8c37918 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jun 2020 13:26:24 +0200 Subject: [PATCH 18/20] dirty progress --- torchvision/transforms/functional.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d821514834f..63316569848 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -124,15 +124,22 @@ def convert_image_dtype( Returns: (torch.Tensor): Converted image + + Raises: + TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` + or :class:`torch.int64` as well as for trying to cast + :class:`torch.float64` to :class:`torch.int64`. These conversions are + unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors """ def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return image.to(dtype) - def float_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - max = float(torch.iinfo(dtype).max) - image = image * (max + 1.0) - image = torch.clamp(image, max) - return image.to(dtype) + def float_to_int(image: torch.Tensor, dtype: torch.dtype, eps=1e-3) -> torch.Tensor: + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (image.dtype == torch.float64 and dtype == torch.int64): + msg = (f"The cast from {image.dtype} to {dtype} cannot be performed safely, " + f"since {image.dtype} cannot ") + raise TypeError(msg) + return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: max = torch.iinfo(image.dtype).max From 28e2fbf4697d928cbce1d4673aac75fd85a59c46 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jun 2020 14:56:51 +0200 Subject: [PATCH 19/20] add int to int and cleanup --- test/test_transforms.py | 137 +++++++++++++++++++-------- torchvision/transforms/functional.py | 70 +++++++------- torchvision/transforms/transforms.py | 11 +++ 3 files changed, 143 insertions(+), 75 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 33d976b37bd..11330635882 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -23,6 +23,20 @@ os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') +def cycle_over(objs): + objs = list(objs) + for idx, obj in enumerate(objs): + yield obj, objs[:idx] + objs[idx + 1:] + +def int_dtypes(): + yield from iter( + (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,) + ) + +def float_dtypes(): + yield from iter((torch.float32, torch.float, torch.float64, torch.double)) + + class Tester(unittest.TestCase): def test_crop(self): @@ -510,54 +524,99 @@ def test_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) - def test_convert_image_dtype(self): - def cycle_over(objs): - objs = list(objs) - for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1:] - - # dtype_max_value = { - # dtype: 1.0 - # for dtype in (torch.float32, torch.float, torch.float64, torch.double)#, torch.bool,) - # # torch.float16 and torch.half are disabled for now since they do not support torch.max - # # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051 - # # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, ) - # } - dtype_max_value = {} - dtype_max_value.update( - { - dtype: torch.iinfo(dtype).max - for dtype in ( - torch.uint8, - torch.int8, - torch.int16, - torch.short, - torch.int32, - torch.int, - torch.int64, - torch.long, - ) - } - ) + def test_convert_image_dtype_float_to_float(self): + for input_dtype, output_dtypes in cycle_over(float_dtypes()): + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) + for output_dtype in output_dtypes: + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0.0, 1.0 + + self.assertAlmostEqual(actual_min, desired_min) + self.assertAlmostEqual(actual_max, desired_max) + + def test_convert_image_dtype_float_to_int(self): + for input_dtype in float_dtypes(): + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) + for output_dtype in int_dtypes(): + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + + if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( + input_dtype == torch.float64 and output_dtype == torch.int64 + ): + with self.assertRaises(RuntimeError): + transform(input_image) + else: + output_image = transform(input_image) - for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()): - input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, torch.iinfo(output_dtype).max + self.assertEqual(actual_min, desired_min) + self.assertEqual(actual_max, desired_max) + + def test_convert_image_dtype_int_to_float(self): + for input_dtype in int_dtypes(): + input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) + for output_dtype in float_dtypes(): + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0.0, 1.0 + + self.assertAlmostEqual(actual_min, desired_min) + self.assertGreaterEqual(actual_min, desired_min) + self.assertAlmostEqual(actual_max, desired_max) + self.assertLessEqual(actual_max, desired_max) + + def test_convert_image_dtype_int_to_int(self): + for input_dtype, output_dtypes in cycle_over(int_dtypes()): + input_max = torch.iinfo(input_dtype).max + input_image = torch.tensor((0, input_max), dtype=input_dtype) for output_dtype in output_dtypes: + output_max = torch.iinfo(output_dtype).max + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) output_image = transform(input_image) - actual = output_image.dtype - desired = output_dtype - self.assertEqual(actual, desired) + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, output_max - actual = torch.max(output_image).item() - desired = dtype_max_value[output_dtype] - if output_dtype.is_floating_point: - self.assertAlmostEqual(actual, desired) + # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details + if input_max >= output_max: + error_term = 0 else: - self.assertEqual(actual, desired) + error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1) + + self.assertEqual(actual_min, desired_min) + self.assertEqual(actual_max, desired_max + error_term) + + def test_convert_image_dtype_int_to_int_consistency(self): + for input_dtype, output_dtypes in cycle_over(int_dtypes()): + input_max = torch.iinfo(input_dtype).max + input_image = torch.tensor((0, input_max), dtype=input_dtype) + for output_dtype in output_dtypes: + output_max = torch.iinfo(output_dtype).max + if output_max <= input_max: + continue + + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + inverse_transfrom = transforms.ConvertImageDtype(input_dtype) + output_image = inverse_transfrom(transform(input_image)) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, input_max + + self.assertEqual(actual_min, desired_min) + self.assertEqual(actual_max, desired_max) @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 63316569848..e49ff063dc8 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -113,9 +113,7 @@ def pil_to_tensor(pic): return img -def convert_image_dtype( - image: torch.Tensor, dtype: torch.dtype = torch.float -) -> torch.Tensor: +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: @@ -125,28 +123,42 @@ def convert_image_dtype( Returns: (torch.Tensor): Converted image + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + Raises: - TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` - or :class:`torch.int64` as well as for trying to cast - :class:`torch.float64` to :class:`torch.int64`. These conversions are - unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. """ - def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - return image.to(dtype) - - def float_to_int(image: torch.Tensor, dtype: torch.dtype, eps=1e-3) -> torch.Tensor: - if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (image.dtype == torch.float64 and dtype == torch.int64): - msg = (f"The cast from {image.dtype} to {dtype} cannot be performed safely, " - f"since {image.dtype} cannot ") - raise TypeError(msg) - return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) + if image.dtype == dtype: + return image - def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - max = torch.iinfo(image.dtype).max - image = image.to(dtype) - return image / max + if image.dtype.is_floating_point: + # float to float + if dtype.is_floating_point: + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) - def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + eps = 1e-3 + return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) + else: + # int to float + if dtype.is_floating_point: + max = torch.iinfo(image.dtype).max + image = image.to(dtype) + return image / max + + # int to int input_max = torch.iinfo(image.dtype).max output_max = torch.iinfo(dtype).max @@ -157,21 +169,7 @@ def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: else: factor = (output_max + 1) // (input_max + 1) image = image.to(dtype) - return (image + 1) * factor - 1 - - if image.dtype == dtype: - return image - - if image.dtype.is_floating_point: - if dtype.is_floating_point: - return float_to_float(image, dtype) - else: - return float_to_int(image, dtype) - else: - if dtype.is_floating_point: - return int_to_float(image, dtype) - else: - return int_to_int(image, dtype) + return image * factor def to_pil_image(pic, mode=None): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 011ce226f5a..d54aa5099f2 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -121,7 +121,18 @@ class ConvertImageDtype(object): Args: dtype (torch.dtype): Desired data type of the output + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. """ + def __init__(self, dtype: torch.dtype) -> None: self.dtype = dtype From 56ba421363ee39d6708e7f2fc89d1fa6a7fdd231 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Jun 2020 15:17:53 +0200 Subject: [PATCH 20/20] lint --- test/test_transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 11330635882..8423bf99ee3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -28,11 +28,13 @@ def cycle_over(objs): for idx, obj in enumerate(objs): yield obj, objs[:idx] + objs[idx + 1:] + def int_dtypes(): yield from iter( (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,) ) + def float_dtypes(): yield from iter((torch.float32, torch.float, torch.float64, torch.double))