Skip to content

Commit 9b0da0c

Browse files
authored
replace tensor division with scalar division and tensor multiplication (#6903)
* replace tensor division with scalar division and tensor multiplication * fix consistency test tolerances
1 parent 4508c84 commit 9b0da0c

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def __init__(
163163
ArgsKwargs(torch.uint8),
164164
],
165165
supports_pil=False,
166+
# Use default tolerances of `torch.testing.assert_close`
167+
closeness_kwargs=dict(rtol=None, atol=None),
166168
),
167169
ConsistencyConfig(
168170
prototype_transforms.ToPILImage,

torchvision/prototype/transforms/functional/_color.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
180180
hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
181181

182182
h = hr.add_(hg).add_(hb)
183-
h = h.div_(6.0).add_(1.0).fmod_(1.0)
183+
h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
184184
return torch.stack((h, s, maxc), dim=-3)
185185

186186

@@ -287,7 +287,7 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
287287
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
288288
if image.is_floating_point():
289289
levels = 1 << bits
290-
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
290+
return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
291291
else:
292292
num_value_bits = _num_value_bits(image.dtype)
293293
if bits >= num_value_bits:

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
367367
else:
368368
# int to float
369369
if float_output:
370-
return image.to(dtype).div_(_FT._max_value(image.dtype))
370+
return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
371371

372372
# int to int
373373
num_value_bits_input = _num_value_bits(image.dtype)

0 commit comments

Comments
 (0)