Skip to content

Commit 5f33f4a

Browse files
committed
fix resize
1 parent d0394b7 commit 5f33f4a

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,9 @@ def resize_image_tensor(
143143
)
144144

145145
if need_cast:
146-
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
147-
image = image.clamp_(min=0, max=255)
146+
# bicubic interpolation can overshoot
147+
if interpolation == InterpolationMode.BICUBIC:
148+
image = image.clamp_(min=0, max=_FT._max_value(dtype))
148149
image = image.round_().to(dtype=dtype)
149150

150151
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

torchvision/transforms/functional_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ def resize(
458458

459459
img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
460460

461-
if interpolation == "bicubic" and out_dtype == torch.uint8:
462-
img = img.clamp(min=0, max=255)
461+
if interpolation == "bicubic" and out_dtype not in (torch.float32, torch.float64):
462+
img = img.clamp(min=0, max=_max_value(out_dtype))
463463

464464
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
465465

0 commit comments

Comments
 (0)