Skip to content

Commit 2a5fbcd

Browse files
vfdev-5 and datumbox authored
[proto] Fixed F.perspective signature (#6617)
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent cffb7f7 commit 2a5fbcd

File tree

5 files changed: +15 additions, -37 deletions

5 files changed

+15
-37
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -113,21 +113,14 @@ def sample_inputs(self, *types):
113113
features.Mask: F.pad_mask,
114114
},
115115
),
116-
# FIXME:
117-
# RuntimeError: perspective() is missing value for argument 'startpoints'.
118-
# Declaration: perspective(Tensor inpt, int[][] startpoints, int[][] endpoints,
119-
# Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
120-
# Union(float[], float, int, NoneType) fill=None) -> Tensor
121-
#
122-
# This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
123-
# DispatcherInfo(
124-
# F.perspective,
125-
# kernels={
126-
# features.Image: F.perspective_image_tensor,
127-
# features.BoundingBox: F.perspective_bounding_box,
128-
# features.Mask: F.perspective_mask,
129-
# },
130-
# ),
116+
DispatcherInfo(
117+
F.perspective,
118+
kernels={
119+
features.Image: F.perspective_image_tensor,
120+
features.BoundingBox: F.perspective_bounding_box,
121+
features.Mask: F.perspective_mask,
122+
},
123+
),
131124
DispatcherInfo(
132125
F.center_crop,
133126
kernels={

test/test_prototype_transforms.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -894,21 +894,8 @@ def test__get_params(self, mocker):
894894
params = transform._get_params(image)
895895

896896
h, w = image.image_size
897-
assert len(params["startpoints"]) == 4
898-
for x, y in params["startpoints"]:
899-
assert x in (0, w - 1)
900-
assert y in (0, h - 1)
901-
902-
assert len(params["endpoints"]) == 4
903-
for (x, y), name in zip(params["endpoints"], ["tl", "tr", "br", "bl"]):
904-
if "t" in name:
905-
assert 0 <= y <= int(dscale * h // 2), (x, y, name)
906-
if "b" in name:
907-
assert h - int(dscale * h // 2) - 1 <= y <= h, (x, y, name)
908-
if "l" in name:
909-
assert 0 <= x <= int(dscale * w // 2), (x, y, name)
910-
if "r" in name:
911-
assert w - int(dscale * w // 2) - 1 <= x <= w, (x, y, name)
897+
assert "perspective_coeffs" in params
898+
assert len(params["perspective_coeffs"]) == 8
912899

913900
@pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
914901
def test__transform(self, distortion_scale, mocker):

test/test_prototype_transforms_functional.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1232,7 +1232,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
12321232
np.max(transformed_points[:, 1]),
12331233
]
12341234
out_bbox = features.BoundingBox(
1235-
out_bbox,
1235+
np.array(out_bbox),
12361236
format=features.BoundingBoxFormat.XYXY,
12371237
image_size=bbox.image_size,
12381238
dtype=bbox.dtype,

torchvision/prototype/transforms/_geometry.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from torchvision.ops.boxes import box_iou
1010
from torchvision.prototype import features
1111
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
12+
from torchvision.transforms.functional import _get_perspective_coeffs
1213

1314
from typing_extensions import Literal
1415

@@ -556,7 +557,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
556557
]
557558
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
558559
endpoints = [topleft, topright, botright, botleft]
559-
return dict(startpoints=startpoints, endpoints=endpoints)
560+
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
561+
return dict(perspective_coeffs=perspective_coeffs)
560562

561563
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
562564
fill = self.fill[type(inpt)]

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
from torchvision.transforms.functional import (
1010
_compute_resized_output_size,
1111
_get_inverse_affine_matrix,
12-
_get_perspective_coeffs,
1312
InterpolationMode,
1413
pil_modes_mapping,
1514
pil_to_tensor,
@@ -876,13 +875,10 @@ def perspective_mask(
876875

877876
def perspective(
878877
inpt: features.DType,
879-
startpoints: List[List[int]],
880-
endpoints: List[List[int]],
878+
perspective_coeffs: List[float],
881879
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
882880
fill: Optional[Union[int, float, List[float]]] = None,
883881
) -> features.DType:
884-
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
885-
886882
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
887883
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
888884
elif isinstance(inpt, features._Feature):

0 commit comments

Comments (0)