Commit 6155808

[proto] Improvements for functional API and tests (#6187)
* Added base tests for rotate_image_tensor
* Updated resize_image_tensor API and tests and fixed a bug with max_size
* Refactored and modified private api for resize functional op
* Fixed failures
* More updates
* Updated proto functional op: resize_image_*
* Added max_size arg to resize_bounding_box and updated basic tests
* Update functional.py
* Reverted fill/center order for rotate
* Other nits
1 parent aeafa91 · commit 6155808

File tree

2 files changed: +66 -17 lines


test/test_prototype_transforms_functional.py

Lines changed: 49 additions & 7 deletions
@@ -201,32 +201,58 @@ def horizontal_flip_bounding_box():
 
 @register_kernel_info_from_sample_inputs_fn
 def resize_image_tensor():
-    for image, interpolation in itertools.product(
+    for image, interpolation, max_size, antialias in itertools.product(
         make_images(),
-        [
-            F.InterpolationMode.BILINEAR,
-            F.InterpolationMode.NEAREST,
-        ],
+        [F.InterpolationMode.BILINEAR, F.InterpolationMode.NEAREST],  # interpolation
+        [None, 34],  # max_size
+        [False, True],  # antialias
     ):
+
+        if antialias and interpolation == F.InterpolationMode.NEAREST:
+            continue
+
         height, width = image.shape[-2:]
         for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
-            yield SampleInput(image, size=size, interpolation=interpolation)
+            if max_size is not None:
+                size = [size[0]]
+            yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 
 
 @register_kernel_info_from_sample_inputs_fn
 def resize_bounding_box():
-    for bounding_box in make_bounding_boxes():
+    for bounding_box, max_size in itertools.product(
+        make_bounding_boxes(),
+        [None, 34],  # max_size
+    ):
         height, width = bounding_box.image_size
         for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
+            if max_size is not None:
+                size = [size[0]]
             yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def resize_segmentation_mask():
+    for mask, max_size in itertools.product(
+        make_segmentation_masks(),
+        [None, 34],  # max_size
+    ):
+        height, width = mask.shape[-2:]
+        for size in [
+            (height, width),
+            (int(height * 0.75), int(width * 1.25)),
+        ]:
+            if max_size is not None:
+                size = [size[0]]
+            yield SampleInput(mask, size=size, max_size=max_size)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def affine_image_tensor():
     for image, angle, translate, scale, shear in itertools.product(
@@ -284,6 +310,22 @@ def affine_segmentation_mask():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def rotate_image_tensor():
+    for image, angle, expand, center, fill in itertools.product(
+        make_images(extra_dims=((), (4,))),
+        [-87, 15, 90],  # angle
+        [True, False],  # expand
+        [None, [12, 23]],  # center
+        [None, [128]],  # fill
+    ):
+        if center is not None and expand:
+            # Skip warning: The provided center argument is ignored if expand is True
+            continue
+
+        yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def rotate_bounding_box():
     for bounding_box, angle, expand, center in itertools.product(
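
All of the new sample-input generators share one pattern: itertools.product enumerates the full grid of kernel kwargs, and combinations the kernel cannot handle are skipped before yielding. Below is a standalone sketch of that control flow, using placeholder strings in place of real images and F.InterpolationMode members:

import itertools

for interpolation, max_size, antialias in itertools.product(
    ["bilinear", "nearest"],  # interpolation
    [None, 34],               # max_size
    [False, True],            # antialias
):
    if antialias and interpolation == "nearest":
        # nearest-neighbor resize does not support antialiasing,
        # which is why the tests above skip this combination
        continue
    size = [24, 32]
    if max_size is not None:
        # max_size is only valid when size specifies a single edge,
        # hence the size = [size[0]] reduction in the generators above
        size = [size[0]]
    print(size, interpolation, max_size, antialias)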

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 17 additions & 10 deletions
@@ -6,7 +6,12 @@
 import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
-from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode
+from torchvision.transforms.functional import (
+    pil_modes_mapping,
+    _get_inverse_affine_matrix,
+    InterpolationMode,
+    _compute_output_size,
+)
 
 from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
 
@@ -42,14 +47,12 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
-    # TODO: use _compute_output_size to enable max_size option
-    max_size  # ununsed right now
-    new_height, new_width = size
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+    new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
-        size=size,
+        size=[new_height, new_width],
         interpolation=interpolation.value,
         antialias=antialias,
     ).reshape(batch_shape + (num_channels, new_height, new_width))
@@ -61,8 +64,11 @@ def resize_image_pil(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
-    # TODO: use _compute_output_size to enable max_size option
-    max_size  # ununsed right now
+    if isinstance(size, int):
+        size = [size, size]
+    # Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]"
+    size: List[int] = list(size)
+    size = _compute_output_size(img.size[::-1], size=size, max_size=max_size)
     return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
 
 
@@ -72,10 +78,11 @@ def resize_segmentation_mask(
     return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
 
 
-# TODO: handle max_size
-def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
+def resize_bounding_box(
+    bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
+) -> torch.Tensor:
     old_height, old_width = image_size
-    new_height, new_width = size
+    new_height, new_width = _compute_output_size(image_size, size=size, max_size=max_size)
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
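
The diff delegates the sizing logic to _compute_output_size (imported from torchvision.transforms.functional), whose body is not part of this commit. The sketch below is a minimal stand-in reproducing its assumed smaller-edge semantics; compute_output_size_sketch is a hypothetical name for illustration, not the private helper itself:

from typing import List, Optional, Tuple

def compute_output_size_sketch(
    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[int, int]:
    height, width = image_size
    if len(size) == 2:
        # An explicit (height, width) pair is used as-is; torchvision
        # rejects max_size combined with a two-element size, which is
        # why the tests reduce size to [size[0]] when max_size is set.
        return size[0], size[1]
    # A single value sets the length of the smaller edge, keeping
    # the aspect ratio.
    short, long = min(height, width), max(height, width)
    new_short, new_long = size[0], int(size[0] * long / short)
    if max_size is not None and new_long > max_size:
        # Cap the longer edge at max_size and shrink the shorter
        # edge proportionally.
        new_short, new_long = int(max_size * new_short / new_long), max_size
    return (new_short, new_long) if height <= width else (new_long, new_short)

# A 100x200 image with size=[50] capped at max_size=80 becomes 40x80:
assert compute_output_size_sketch((100, 200), [50], max_size=80) == (40, 80)

resize_bounding_box then rescales each (x, y) pair by the width and height ratios derived from that same output size. A quick numeric check of the view(-1, 2, 2).mul(ratios) trick, assuming XYXY boxes:

import torch

old_h, old_w, new_h, new_w = 100, 200, 40, 80
boxes = torch.tensor([[10.0, 20.0, 110.0, 60.0]])
ratios = torch.tensor([new_w / old_w, new_h / old_h])
print(boxes.view(-1, 2, 2).mul(ratios).view(boxes.shape))
# tensor([[ 4.,  8., 44., 24.]])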
