diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 6c33bd9a548..70695a3e7c5 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -138,6 +138,36 @@ def test_ten_crop(self):
                                         (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
 
+    def test_resize(self):
+        height = random.randint(24, 32) * 2
+        width = random.randint(24, 32) * 2
+        img = torch.ones(3, height, width)
+        img_clone = img.clone()
+        modes = ["bilinear", "nearest", "bicubic"]
+
+        for mode in modes:
+            # int size: the smaller edge should be matched to output_size
+            output_size = random.randint(5, 12) * 2
+            result = F_t.resize(img, output_size, interpolation=mode)
+            if height < width:
+                self.assertEqual(output_size, result.shape[1])
+            else:
+                self.assertEqual(output_size, result.shape[2])
+
+            # (h, w) size: the output shape should match exactly
+            output_size = (random.randint(5, 12) * 2, random.randint(5, 12) * 2)
+            result = F_t.resize(img, output_size, interpolation=mode)
+            self.assertEqual((output_size[0], output_size[1]), (result.shape[1], result.shape[2]))
+
+        # check that the input tensor is not mutated
+        self.assertTrue(torch.equal(img, img_clone))
+
+        # check that bicubic overshoot is clamped to [0, 255]
+        output_size = (random.randint(5, 12) * 2, random.randint(5, 12) * 2)
+        result = F_t.resize(img, output_size, interpolation="bicubic")
+        clamped_tensor = result.clamp(min=0, max=255)
+        self.assertTrue(torch.equal(result, clamped_tensor))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index bd56ae3a131..f63e5bf7cf7 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1,5 +1,6 @@
 import torch
 import torchvision.transforms.functional as F
+import torch.nn.functional as Fn
 
 
 def vflip(img_tensor):
@@ -219,3 +220,43 @@ def ten_crop(img, size, vertical_flip=False):
 def _blend(img1, img2, ratio):
     bound = 1 if img1.dtype.is_floating_point else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
+
+
+def resize(img, size, interpolation="bilinear"):
+    r"""Resize the input Tensor Image to the given size.
+
+    Args:
+        img (Tensor): Image to be resized.
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number, maintaining
+            the aspect ratio, i.e. if height > width, then the image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+        interpolation (str, optional): Desired interpolation, one of ``"bilinear"``,
+            ``"nearest"`` or ``"bicubic"``. Default is ``"bilinear"``.
+
+    Returns:
+        Tensor: Resized image Tensor.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    if isinstance(size, int):
+        w, h = img.shape[2], img.shape[1]
+        # the smaller edge already matches the requested size: nothing to do
+        if (w <= h and w == size) or (h <= w and h == size):
+            return img
+        # match the smaller edge to size, scaling the other edge to keep the aspect ratio
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+        else:
+            oh = size
+            ow = int(size * w / h)
+        # interpolate expects a batched NCHW input, so add a batch dimension
+        out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation)
+    else:
+        out_img = Fn.interpolate(img.unsqueeze(0), size=size, mode=interpolation)
+
+    # bicubic interpolation can overshoot the input range, so clamp before returning
+    return out_img.clamp(min=0, max=255).squeeze(0)
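A quick usage sketch for anyone trying the branch locally (the shapes in the comments follow from the smaller-edge rule in the docstring; the `F_t` alias matches the one used in the tests):

```python
import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.ones(3, 64, 48)  # CHW tensor: height=64, width=48

# int size: the smaller edge (width=48) is matched to 24, height scales to 32
out = F_t.resize(img, 24, interpolation="bilinear")
print(out.shape)  # torch.Size([3, 32, 24])

# (h, w) size: the output shape matches exactly
out = F_t.resize(img, (20, 30), interpolation="nearest")
print(out.shape)  # torch.Size([3, 20, 30])
```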
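The final `clamp` in `resize` exists because the bicubic kernel has negative lobes and can overshoot the input range, which is what the last test asserts. A standalone sketch (not part of the diff) that reproduces the overshoot with `torch.nn.functional.interpolate` directly, assuming an image in the 0-255 range:

```python
import torch
import torch.nn.functional as Fn

# A hard step edge is the classic overshoot case for bicubic interpolation.
x = torch.tensor([[[[0.0, 0.0, 255.0, 255.0]]]])  # NCHW input, values in [0, 255]
y = Fn.interpolate(x, size=(1, 8), mode="bicubic", align_corners=False)
print(y.min().item(), y.max().item())  # roughly -27 and 282: outside [0, 255]

y = y.clamp(min=0, max=255)  # the same clamp resize() applies before returning
```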