From ac3ab8108d2982acb4ad85f0613c4c7c661e7515 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Nov 2022 22:23:25 +0100 Subject: [PATCH 1/2] replace tensor division with scalar division and tensor multiplication --- torchvision/prototype/transforms/functional/_color.py | 4 ++-- torchvision/prototype/transforms/functional/_meta.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index c70d746d8b9..5518322080f 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -180,7 +180,7 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r) h = hr.add_(hg).add_(hb) - h = h.div_(6.0).add_(1.0).fmod_(1.0) + h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0) return torch.stack((h, s, maxc), dim=-3) @@ -287,7 +287,7 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: if image.is_floating_point(): levels = 1 << bits - return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels) + return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels) else: num_value_bits = _num_value_bits(image.dtype) if bits >= num_value_bits: diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 81ccd08de5d..8bcd8176733 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -367,7 +367,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f else: # int to float if float_output: - return image.to(dtype).div_(_FT._max_value(image.dtype)) + return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype)) # int to int num_value_bits_input = _num_value_bits(image.dtype) From d02d3982df047875eecc6bb9dcc9218aff96adde Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Nov 2022 00:15:53 +0100 Subject: [PATCH 2/2] fix consistency test tolerances --- test/test_prototype_transforms_consistency.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index a23783b0037..4cba4265ac3 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -163,6 +163,8 @@ def __init__( ArgsKwargs(torch.uint8), ], supports_pil=False, + # Use default tolerances of `torch.testing.assert_close` + closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( prototype_transforms.ToPILImage,