diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 8852f9864c8..dc9c3e57d7a 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -108,7 +108,7 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     }
 
 
-def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
     return {
         (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
     }
@@ -211,10 +211,12 @@ def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size
             [-1, 0, spatial_size[1]],
             [0, 1, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
 
     return expected_bboxes
 
@@ -322,7 +324,7 @@ def reference_inputs_resize_image_tensor():
 def sample_inputs_resize_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
-            yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
+            yield ArgsKwargs(bounding_box_loader, spatial_size=bounding_box_loader.spatial_size, size=size)
 
 
 def sample_inputs_resize_mask():
@@ -344,19 +346,20 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
             [new_width / old_width, 0, 0],
             [0, new_height / old_height, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
+        bounding_box,
+        format=bounding_box.format,
+        spatial_size=(new_height, new_width),
+        affine_matrix=affine_matrix,
     )
     return expected_bboxes, (new_height, new_width)
 
 
 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(
-        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
-    ):
+    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
 
@@ -543,14 +546,17 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix
 
 
-def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
-    def transform(bbox, affine_matrix_, format_):
+def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
+    def transform(bbox, affine_matrix_, format_, spatial_size_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox.as_subclass(torch.Tensor),
+            old_format=format_,
+            new_format=datapoints.BoundingBoxFormat.XYXY,
+            inplace=True,
         )
         points = np.array(
             [
@@ -573,12 +579,15 @@ def transform(bbox, affine_matrix_, format_):
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
         )
-        return out_bbox.to(dtype=in_dtype)
+        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = out_bbox.to(dtype=in_dtype)
+        return out_bbox
 
     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]
 
-    expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
+    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
     else:
@@ -594,7 +603,9 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
     affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
     affine_matrix = affine_matrix[:2, :]
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
 
     return expected_bboxes
 
@@ -643,9 +654,6 @@ def sample_inputs_affine_video():
         sample_inputs_fn=sample_inputs_affine_bounding_box,
         reference_fn=reference_affine_bounding_box,
         reference_inputs_fn=reference_inputs_affine_bounding_box,
-        closeness_kwargs={
-            (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
-        },
         test_marks=[
             xfail_jit_python_scalar_arg("shear"),
         ],
@@ -729,10 +737,12 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
             [1, 0, 0],
             [0, -1, spatial_size[0]],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
 
     return expected_bboxes
 
@@ -806,6 +816,43 @@ def sample_inputs_rotate_bounding_box():
         )
 
 
+def reference_inputs_rotate_bounding_box():
+    for bounding_box_loader, angle in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), _ROTATE_ANGLES
+    ):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+            angle=angle,
+        )
+
+    # TODO: add samples with expand=True and center
+
+
+def reference_rotate_bounding_box(bounding_box, *, format, spatial_size, angle, expand=False, center=None):
+
+    if center is None:
+        center = [spatial_size[1] * 0.5, spatial_size[0] * 0.5]
+
+    a = np.cos(angle * np.pi / 180.0)
+    b = np.sin(angle * np.pi / 180.0)
+    cx = center[0]
+    cy = center[1]
+    affine_matrix = np.array(
+        [
+            [a, b, cx - cx * a - b * cy],
+            [-b, a, cy + cx * b - a * cy],
+        ],
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size
+
+
 def sample_inputs_rotate_mask():
     for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
         yield ArgsKwargs(mask_loader, angle=15.0)
@@ -834,9 +881,11 @@ def sample_inputs_rotate_video():
         KernelInfo(
             F.rotate_bounding_box,
             sample_inputs_fn=sample_inputs_rotate_bounding_box,
+            reference_fn=reference_rotate_bounding_box,
+            reference_inputs_fn=reference_inputs_rotate_bounding_box,
             closeness_kwargs={
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(
@@ -897,17 +946,19 @@ def sample_inputs_crop_video():
 
 
 def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
-
     affine_matrix = np.array(
         [
             [1, 0, -left],
             [0, 1, -top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
-    return expected_bboxes, (height, width)
+    spatial_size = (height, width)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size
 
 
 def reference_inputs_crop_bounding_box():
@@ -1119,13 +1170,15 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p
             [1, 0, left],
             [0, 1, top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )
 
     height = spatial_size[0] + top + bottom
     width = spatial_size[1] + left + right
 
-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
+    )
     return expected_bboxes, (height, width)
 
 
@@ -1225,14 +1278,16 @@ def sample_inputs_perspective_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
             startpoints=None,
             endpoints=None,
             coefficients=_PERSPECTIVE_COEFFS[0],
         )
 
     format = datapoints.BoundingBoxFormat.XYXY
+    loader = make_bounding_box_loader(format=format)
     yield ArgsKwargs(
-        make_bounding_box_loader(format=format), format=format, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
+        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
     )
 
 
@@ -1269,13 +1324,17 @@ def sample_inputs_perspective_video():
                 **pil_reference_pixel_difference(2, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
                 **float32_vs_uint8_pixel_difference(),
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(
             F.perspective_bounding_box,
             sample_inputs_fn=sample_inputs_perspective_bounding_box,
+            closeness_kwargs={
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
+            },
         ),
         KernelInfo(
             F.perspective_mask,
@@ -1292,8 +1351,8 @@ def sample_inputs_perspective_video():
             sample_inputs_fn=sample_inputs_perspective_video,
             closeness_kwargs={
                 **cuda_vs_cpu_pixel_difference(),
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
     ]
@@ -1331,6 +1390,7 @@ def sample_inputs_elastic_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
            displacement=displacement,
         )
 
diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 7030d2d1b2e..167b839eef9 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -146,7 +146,7 @@ class TestSmoke:
             (transforms.RandomZoomOut(p=1.0), None),
             (transforms.Resize([16, 16], antialias=True), None),
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
-            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ClampBoundingBox(), None),
             (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 1650d03de73..bb4b6ef1158 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -25,7 +25,7 @@
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F
 from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
-from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
+from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box
 from torchvision.transforms.functional import _get_perspective_coeffs
 
 
@@ -257,16 +257,17 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
-        input = input.as_subclass(torch.Tensor)
 
-        actual = info.kernel(input, *other_args, **kwargs)
+        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
+        # metadata regardless of whether the kernel takes it explicitly or not
         expected = info.reference_fn(input, *other_args, **kwargs)
 
         assert_close(
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, **kwargs),
+            msg=parametrized_error_message(input, *other_args, **kwargs),
         )
 
     @make_info_args_kwargs_parametrization(
@@ -682,6 +683,10 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
         (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221),
     ]
 
+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
+
     output_boxes = F.affine_bounding_box(
         in_boxes,
         format=format,
@@ -762,7 +767,8 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
+        out_bbox = clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
+        return out_bbox, (height, width)
 
     spatial_size = (32, 38)
 
@@ -839,6 +845,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
             [69.27564928, 12.39339828, 74.93250353, 18.05025253],
             [18.36396103, 1.07968978, 46.64823228, 29.36396103],
         ]
+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
 
     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
@@ -905,6 +914,10 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
 
+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
+
     output_boxes, output_spatial_size = F.crop_bounding_box(
         in_boxes,
         format,
@@ -1121,7 +1134,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
+        return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
 
     spatial_size = (32, 38)
 
@@ -1134,6 +1147,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
     output_bboxes = F.perspective_bounding_box(
         bboxes.as_subclass(torch.Tensor),
         format=bboxes.format,
+        spatial_size=bboxes.spatial_size,
         startpoints=None,
         endpoints=None,
         coefficients=pcoeffs,
@@ -1178,6 +1192,7 @@ def _compute_expected_bbox(bbox, output_size_):
         ]
         out_bbox = torch.tensor(out_bbox)
         out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+        out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size)
         return out_bbox.to(dtype=dtype, device=bbox.device)
 
     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
@@ -1201,7 +1216,8 @@ def _compute_expected_bbox(bbox, output_size_):
             expected_bboxes = torch.stack(expected_bboxes)
         else:
             expected_bboxes = expected_bboxes[0]
-        torch.testing.assert_close(output_boxes, expected_bboxes)
+
+        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
         torch.testing.assert_close(output_spatial_size, output_size)
 
 
diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py
index 718c3c2ade8..7c44a9e4b26 100644
--- a/torchvision/prototype/datapoints/_bounding_box.py
+++ b/torchvision/prototype/datapoints/_bounding_box.py
@@ -81,7 +81,10 @@ def resize(  # type: ignore[override]
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> BoundingBox:
         output, spatial_size = self._F.resize_bounding_box(
-            self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
+            self.as_subclass(torch.Tensor),
+            spatial_size=self.spatial_size,
+            size=size,
+            max_size=max_size,
         )
         return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
 
@@ -178,6 +181,7 @@ def perspective(
         output = self._F.perspective_bounding_box(
             self.as_subclass(torch.Tensor),
             format=self.format,
+            spatial_size=self.spatial_size,
             startpoints=startpoints,
             endpoints=endpoints,
             coefficients=coefficients,
@@ -190,5 +194,7 @@ def elastic(
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
-        output = self._F.elastic_bounding_box(self.as_subclass(torch.Tensor), self.format, displacement)
+        output = self._F.elastic_bounding_box(
+            self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
+        )
         return BoundingBox.wrap_like(self, output)
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 132edb1b6fc..a640d726cef 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -41,7 +41,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 75085fff6d5..79bd5549b2e 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -42,7 +42,7 @@ def _transform(
 ConvertImageDtype = ConvertDtype
 
 
-class ClampBoundingBoxes(Transform):
+class ClampBoundingBox(Transform):
     _transformed_types = (datapoints.BoundingBox,)
 
     def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 6fcd87ac91c..5fa6cbb4873 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -22,7 +22,7 @@
 
 from torchvision.utils import _log_api_usage_once
 
-from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
+from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
 from ._utils import is_simple_tensor
 
 
@@ -580,8 +580,9 @@ def affine_image_pil(
     return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
 
 
-def _affine_bounding_box_xyxy(
+def _affine_bounding_box_with_expand(
     bounding_box: torch.Tensor,
+    format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
@@ -593,6 +594,17 @@ def _affine_bounding_box_with_expand(
     if bounding_box.numel() == 0:
         return bounding_box, spatial_size
 
+    original_shape = bounding_box.shape
+    original_dtype = bounding_box.dtype
+    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
+    dtype = bounding_box.dtype
+    device = bounding_box.device
+    bounding_box = (
+        convert_format_bounding_box(
+            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        )
+    ).reshape(-1, 4)
+
     angle, translate, shear, center = _affine_parse_args(
         angle, translate, scale, shear, InterpolationMode.NEAREST, center
     )
@@ -601,9 +613,6 @@ def _affine_bounding_box_with_expand(
         height, width = spatial_size
         center = [width * 0.5, height * 0.5]
 
-    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
-    device = bounding_box.device
-
     affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
     transposed_affine_matrix = (
         torch.tensor(
@@ -651,7 +660,13 @@ def _affine_bounding_box_with_expand(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         spatial_size = (new_height, new_width)
 
-    return out_bboxes.to(bounding_box.dtype), spatial_size
+    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
+    out_bboxes = convert_format_bounding_box(
+        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+    ).reshape(original_shape)
+
+    out_bboxes = out_bboxes.to(original_dtype)
+    return out_bboxes, spatial_size
 
 
 def affine_bounding_box(
@@ -664,19 +679,18 @@
     shear: List[float],
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    original_shape = bounding_box.shape
-
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
-    ).reshape(-1, 4)
-
-    out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)
-
-    # out_bboxes should be of shape [N boxes, 4]
-
-    return convert_format_bounding_box(
-        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-    ).reshape(original_shape)
+    out_box, _ = _affine_bounding_box_with_expand(
+        bounding_box,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        center=center,
+        expand=False,
+    )
+    return out_box
 
 
 def affine_mask(
@@ -852,14 +866,10 @@ def rotate_bounding_box(
         warnings.warn("The provided center argument has no effect on the result if expand is True")
         center = None
 
-    original_shape = bounding_box.shape
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
-    ).reshape(-1, 4)
-
-    out_bboxes, spatial_size = _affine_bounding_box_xyxy(
+    return _affine_bounding_box_with_expand(
         bounding_box,
-        spatial_size,
+        format=format,
+        spatial_size=spatial_size,
         angle=-angle,
         translate=[0.0, 0.0],
         scale=1.0,
@@ -868,13 +878,6 @@
         expand=expand,
     )
 
-    return (
-        convert_format_bounding_box(
-            out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-        ).reshape(original_shape),
-        spatial_size,
-    )
-
 
 def rotate_mask(
     mask: torch.Tensor,
@@ -1112,8 +1115,9 @@ def pad_bounding_box(
     height, width = spatial_size
     height += top + bottom
     width += left + right
+    spatial_size = (height, width)
 
-    return bounding_box, (height, width)
+    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size
 
 
 def pad_video(
@@ -1185,8 +1189,9 @@ def crop_bounding_box(
     sub = [left, top, 0, 0]
 
     bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
+    spatial_size = (height, width)
 
-    return bounding_box, (height, width)
+    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size
 
 
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
@@ -1332,6 +1337,7 @@ def perspective_image_pil(
 def perspective_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
+    spatial_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,
@@ -1342,6 +1348,7 @@ def perspective_bounding_box(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
 
     original_shape = bounding_box.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
     bounding_box = (
         convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
@@ -1408,7 +1415,11 @@ def perspective_bounding_box(
 
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
+    out_bboxes = clamp_bounding_box(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+    )
 
     # out_bboxes should be of shape [N boxes, 4]
 
@@ -1549,6 +1560,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
 def elastic_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
+    spatial_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:
     if bounding_box.numel() == 0:
         return bounding_box
@@ -1562,14 +1574,11 @@ def elastic_bounding_box(
     displacement = displacement.to(dtype=dtype, device=device)
 
     original_shape = bounding_box.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
     bounding_box = (
         convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
 
-    # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
-    # Or add spatial_size arg and check displacement shape
-    spatial_size = displacement.shape[-3], displacement.shape[-2]
-
     id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
@@ -1588,7 +1597,11 @@ def elastic_bounding_box(
 
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
+    out_bboxes = clamp_bounding_box(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+    )
 
     return convert_format_bounding_box(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
@@ -1796,7 +1809,7 @@ def resized_crop_bounding_box(
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
-    return resize_bounding_box(bounding_box, (height, width), size)
+    return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size)
 
 
 def resized_crop_mask(
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 47860451774..5e32516fb8a 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -245,12 +245,17 @@ def _clamp_bounding_box(
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
+    in_dtype = bounding_box.dtype
+    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
     xyxy_boxes = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
    )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
+    out_boxes = convert_format_bounding_box(
+        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
+    )
+    return out_boxes.to(in_dtype)
 
 
 def clamp_bounding_box(
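
Usage sketch (not part of the patch itself): the hunks above route kernel outputs through clamp_bounding_box with an explicit spatial_size, so boxes are confined to the image canvas before any dtype cast. A minimal illustration of that convention, assuming the prototype API exactly as exercised in this diff (F.clamp_bounding_box called with a plain tensor plus format and spatial_size):

import torch

from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

# (height, width) of the image the boxes live in, matching the test fixtures above.
spatial_size = (32, 38)

# An XYXY box whose right/bottom edges spill past the canvas, e.g. after an affine transform.
boxes = torch.tensor([[20.0, 25.0, 45.0, 40.0]])

# x coordinates are clamped to [0, width] and y coordinates to [0, height],
# mirroring _clamp_bounding_box in transforms/functional/_meta.py.
clamped = F.clamp_bounding_box(
    boxes,
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=spatial_size,
)
print(clamped)  # expected: tensor([[20., 25., 38., 32.]])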