diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index cd3ae5a0a82..95f7383a4f7 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -331,6 +331,23 @@ def test_resize(self):
                 pad_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
                 self.assertTrue(resized_tensor.equal(pad_tensor_script), msg="{}, {}".format(size, interpolation))
 
+    def test_resized_crop(self):
+        # test values of F.resized_crop in several cases:
+        # 1) resize to the same size, crop to the same size => should be identity
+        tensor, _ = self._create_data(26, 36)
+        for i in [0, 2, 3]:
+            out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i)
+            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
+
+        # 2) resize by half and crop a TL corner
+        tensor, _ = self._create_data(26, 36)
+        out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
+        expected_out_tensor = tensor[:, :20:2, :30:2]
+        self.assertTrue(
+            expected_out_tensor.equal(out_tensor),
+            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
+        )
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 9d70744dfc1..fbd3331a490 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -245,6 +245,25 @@ def test_resize(self):
                 s_resized_tensor = script_transform(tensor)
                 self.assertTrue(s_resized_tensor.equal(resized_tensor))
 
+    def test_resized_crop(self):
+        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
+
+        scale = (0.7, 1.2)
+        ratio = (0.75, 1.333)
+
+        for size in [(32, ), [32, ], [32, 32], (32, 32)]:
+            for interpolation in [NEAREST, BILINEAR, BICUBIC]:
+                transform = T.RandomResizedCrop(
+                    size=size, scale=scale, ratio=ratio, interpolation=interpolation
+                )
+                s_transform = torch.jit.script(transform)
+
+                torch.manual_seed(12)
+                out1 = transform(tensor)
+                torch.manual_seed(12)
+                out2 = s_transform(tensor)
+                self.assertTrue(out1.equal(out2))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 72ca54d7260..4b38c7bb92e 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -439,24 +439,26 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     return crop(img, crop_top, crop_left, crop_height, crop_width)
 
 
-def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
-    """Crop the given PIL Image and resize it to desired size.
+def resized_crop(
+        img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR
+) -> Tensor:
+    """Crop the given image and resize it to desired size.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
 
     Args:
-        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
         width (int): Width of the crop box.
         size (sequence or int): Desired output size. Same semantics as ``resize``.
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.
 
     Returns:
-        PIL Image: Cropped image.
+        PIL Image or Tensor: Cropped image.
     """
-    assert F_pil._is_pil_image(img), 'img should be PIL Image'
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img
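
With the PIL-only assert dropped, ``resized_crop`` reduces to ``crop`` followed by
``resize`` for both PIL Images and Tensors. A minimal sanity check of the tensor path,
mirroring the new test above (a nearest-neighbour downscale by exactly 2 picks every
second pixel); a sketch assuming this diff is applied:

    import torch
    import torchvision.transforms.functional as F

    img = torch.randint(0, 255, size=(3, 26, 36), dtype=torch.uint8)
    # crop the 20x30 top-left corner, then halve it with nearest interpolation (0)
    out = F.resized_crop(img, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
    assert out.equal(img[:, :20:2, :30:2])
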
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index be0b7b3a622..59cf6bc2764 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -532,7 +532,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
     elif len(size) < 2:
         size_w, size_h = size[0], size[0]
     else:
-        size_w, size_h = size[0], size[1]
+        size_w, size_h = size[1], size[0]  # Convention (h, w)
 
     if isinstance(size, int) or len(size) < 2:
         if w < h:
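
The one-line convention fix above is what makes ``resize`` on tensors agree with the
PIL path, which interprets a two-element size as (h, w). A sketch of the expected
behaviour once this diff is applied:

    import torch
    import torchvision.transforms.functional as F

    img = torch.zeros(3, 40, 60)     # (C, H, W)
    out = F.resize(img, [20, 30])    # size given as (h, w)
    # before the fix, the tensor path swapped the axes and returned (3, 30, 20)
    assert out.shape == (3, 20, 30)
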
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 9f4ad8175c6..2df2befcb33 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -687,8 +687,10 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
 
 
-class RandomResizedCrop(object):
-    """Crop the given PIL Image to random size and aspect ratio.
+class RandomResizedCrop(torch.nn.Module):
+    """Crop the given image to random size and aspect ratio.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     A crop of random size (default: of 0.08 to 1.0) of the original size and a random
     aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
@@ -696,31 +698,45 @@ class RandomResizedCrop(object):
     This is popularly used to train the Inception networks.
 
     Args:
-        size: expected output size of each edge
-        scale: range of size of the origin size cropped
-        ratio: range of aspect ratio of the origin aspect ratio cropped
-        interpolation: Default: PIL.Image.BILINEAR
+        size (int or sequence): expected output size of each edge. If size is an
+            int instead of sequence like (h, w), a square output size ``(size, size)`` is
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+        scale (tuple of float): range of size of the origin size cropped
+        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
+        interpolation (int): Desired interpolation. Default: ``PIL.Image.BILINEAR``
     """
 
     def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
-        if isinstance(size, (tuple, list)):
-            self.size = size
+        super().__init__()
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        elif isinstance(size, Sequence) and len(size) == 1:
+            self.size = (size[0], size[0])
         else:
-            self.size = (size, size)
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
+            self.size = size
+
+        if not isinstance(scale, (tuple, list)):
+            raise TypeError("Scale should be a sequence")
+        if not isinstance(ratio, (tuple, list)):
+            raise TypeError("Ratio should be a sequence")
         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
-            warnings.warn("range should be of kind (min, max)")
+            warnings.warn("Scale and ratio should be of kind (min, max)")
 
         self.interpolation = interpolation
         self.scale = scale
         self.ratio = ratio
 
     @staticmethod
-    def get_params(img, scale, ratio):
+    def get_params(
+            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
+    ) -> Tuple[int, int, int, int]:
         """Get parameters for ``crop`` for a random sized crop.
 
         Args:
-            img (PIL Image): Image to be cropped.
-            scale (tuple): range of size of the origin size cropped
+            img (PIL Image or Tensor): Input image.
+            scale (tuple): range of scale of the origin size cropped
             ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
 
         Returns:
@@ -731,24 +747,26 @@ def get_params(img, scale, ratio):
         area = height * width
 
         for _ in range(10):
-            target_area = random.uniform(*scale) * area
-            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
-            aspect_ratio = math.exp(random.uniform(*log_ratio))
+            target_area = area * torch.empty(1).uniform_(*scale).item()
+            log_ratio = torch.log(torch.tensor(ratio))
+            aspect_ratio = torch.exp(
+                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+            ).item()
 
             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))
 
             if 0 < w <= width and 0 < h <= height:
-                i = random.randint(0, height - h)
-                j = random.randint(0, width - w)
+                i = torch.randint(0, height - h + 1, size=(1,)).item()
+                j = torch.randint(0, width - w + 1, size=(1,)).item()
                 return i, j, h, w
 
         # Fallback to central crop
         in_ratio = float(width) / float(height)
-        if (in_ratio < min(ratio)):
+        if in_ratio < min(ratio):
             w = width
             h = int(round(w / min(ratio)))
-        elif (in_ratio > max(ratio)):
+        elif in_ratio > max(ratio):
             h = height
             w = int(round(h * max(ratio)))
         else:  # whole image
@@ -758,13 +776,13 @@ def get_params(img, scale, ratio):
         j = (width - w) // 2
         return i, j, h, w
 
-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be cropped and resized.
+            img (PIL Image or Tensor): Image to be cropped and resized.
 
         Returns:
-            PIL Image: Randomly cropped and resized image.
+            PIL Image or Tensor: Randomly cropped and resized image.
         """
         i, j, h, w = self.get_params(img, self.scale, self.ratio)
         return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
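
Replacing the ``random``/``math`` sampling in ``get_params`` with torch RNG ops, together
with the ``torch.nn.Module`` base class and the annotated signatures, is what makes
``RandomResizedCrop`` scriptable and seed-reproducible, as the new test exercises.
A usage sketch, assuming this diff is applied:

    import torch
    import torchvision.transforms as T

    transform = T.RandomResizedCrop(size=32, scale=(0.7, 1.2), ratio=(0.75, 1.333))
    s_transform = torch.jit.script(transform)  # scripting works now that the transform is an nn.Module

    img = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
    torch.manual_seed(12)
    out1 = transform(img)
    torch.manual_seed(12)
    out2 = s_transform(img)
    assert out1.equal(out2)  # same torch RNG stream, same crop parameters
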