Skip to content

Commit 0de3e5b

Browse files
authored
[proto] Aligned fill, padding typehints between Features and F (#6616)
1 parent 841b9a1 commit 0de3e5b

File tree

5 files changed

+45
-52
lines changed

5 files changed

+45
-52
lines changed

torchvision/prototype/features/_bounding_box.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def resized_crop(
115115
def pad(
116116
self,
117117
padding: Union[int, Sequence[int]],
118-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
118+
fill: Optional[Union[int, float, List[float]]] = None,
119119
padding_mode: str = "constant",
120120
) -> BoundingBox:
121121
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
@@ -137,7 +137,7 @@ def rotate(
137137
angle: float,
138138
interpolation: InterpolationMode = InterpolationMode.NEAREST,
139139
expand: bool = False,
140-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
140+
fill: Optional[Union[int, float, List[float]]] = None,
141141
center: Optional[List[float]] = None,
142142
) -> BoundingBox:
143143
output = self._F.rotate_bounding_box(
@@ -165,7 +165,7 @@ def affine(
165165
scale: float,
166166
shear: List[float],
167167
interpolation: InterpolationMode = InterpolationMode.NEAREST,
168-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
168+
fill: Optional[Union[int, float, List[float]]] = None,
169169
center: Optional[List[float]] = None,
170170
) -> BoundingBox:
171171
output = self._F.affine_bounding_box(
@@ -184,7 +184,7 @@ def perspective(
184184
self,
185185
perspective_coeffs: List[float],
186186
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
187-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
187+
fill: Optional[Union[int, float, List[float]]] = None,
188188
) -> BoundingBox:
189189
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
190190
return BoundingBox.new_like(self, output, dtype=output.dtype)
@@ -193,7 +193,7 @@ def elastic(
193193
self,
194194
displacement: torch.Tensor,
195195
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
196-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
196+
fill: Optional[Union[int, float, List[float]]] = None,
197197
) -> BoundingBox:
198198
output = self._F.elastic_bounding_box(self, self.format, displacement)
199199
return BoundingBox.new_like(self, output, dtype=output.dtype)

torchvision/prototype/features/_feature.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def resized_crop(
153153

154154
def pad(
155155
self,
156-
padding: Union[int, Sequence[int]],
157-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
156+
padding: Union[int, List[int]],
157+
fill: Optional[Union[int, float, List[float]]] = None,
158158
padding_mode: str = "constant",
159159
) -> _Feature:
160160
return self
@@ -164,7 +164,7 @@ def rotate(
164164
angle: float,
165165
interpolation: InterpolationMode = InterpolationMode.NEAREST,
166166
expand: bool = False,
167-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
167+
fill: Optional[Union[int, float, List[float]]] = None,
168168
center: Optional[List[float]] = None,
169169
) -> _Feature:
170170
return self
@@ -176,7 +176,7 @@ def affine(
176176
scale: float,
177177
shear: List[float],
178178
interpolation: InterpolationMode = InterpolationMode.NEAREST,
179-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
179+
fill: Optional[Union[int, float, List[float]]] = None,
180180
center: Optional[List[float]] = None,
181181
) -> _Feature:
182182
return self
@@ -185,15 +185,15 @@ def perspective(
185185
self,
186186
perspective_coeffs: List[float],
187187
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
188-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
188+
fill: Optional[Union[int, float, List[float]]] = None,
189189
) -> _Feature:
190190
return self
191191

192192
def elastic(
193193
self,
194194
displacement: torch.Tensor,
195195
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
196-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
196+
fill: Optional[Union[int, float, List[float]]] = None,
197197
) -> _Feature:
198198
return self
199199

torchvision/prototype/features/_image.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Any, cast, List, Optional, Sequence, Tuple, Union
4+
from typing import Any, cast, List, Optional, Tuple, Union
55

66
import torch
77
from torchvision._utils import StrEnum
@@ -180,16 +180,10 @@ def resized_crop(
180180

181181
def pad(
182182
self,
183-
padding: Union[int, Sequence[int]],
184-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
183+
padding: Union[int, List[int]],
184+
fill: Optional[Union[int, float, List[float]]] = None,
185185
padding_mode: str = "constant",
186186
) -> Image:
187-
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
188-
if not isinstance(padding, int):
189-
padding = list(padding)
190-
191-
fill = self._F._geometry._convert_fill_arg(fill)
192-
193187
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
194188
return Image.new_like(self, output)
195189

@@ -198,11 +192,9 @@ def rotate(
198192
angle: float,
199193
interpolation: InterpolationMode = InterpolationMode.NEAREST,
200194
expand: bool = False,
201-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
195+
fill: Optional[Union[int, float, List[float]]] = None,
202196
center: Optional[List[float]] = None,
203197
) -> Image:
204-
fill = self._F._geometry._convert_fill_arg(fill)
205-
206198
output = self._F._geometry.rotate_image_tensor(
207199
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
208200
)
@@ -215,11 +207,9 @@ def affine(
215207
scale: float,
216208
shear: List[float],
217209
interpolation: InterpolationMode = InterpolationMode.NEAREST,
218-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
210+
fill: Optional[Union[int, float, List[float]]] = None,
219211
center: Optional[List[float]] = None,
220212
) -> Image:
221-
fill = self._F._geometry._convert_fill_arg(fill)
222-
223213
output = self._F._geometry.affine_image_tensor(
224214
self,
225215
angle,
@@ -236,10 +226,8 @@ def perspective(
236226
self,
237227
perspective_coeffs: List[float],
238228
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
239-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
229+
fill: Optional[Union[int, float, List[float]]] = None,
240230
) -> Image:
241-
fill = self._F._geometry._convert_fill_arg(fill)
242-
243231
output = self._F._geometry.perspective_image_tensor(
244232
self, perspective_coeffs, interpolation=interpolation, fill=fill
245233
)
@@ -249,10 +237,8 @@ def elastic(
249237
self,
250238
displacement: torch.Tensor,
251239
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
252-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
240+
fill: Optional[Union[int, float, List[float]]] = None,
253241
) -> Image:
254-
fill = self._F._geometry._convert_fill_arg(fill)
255-
256242
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
257243
return Image.new_like(self, output)
258244

torchvision/prototype/features/_mask.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import List, Optional, Sequence, Union
3+
from typing import List, Optional, Union
44

55
import torch
66
from torchvision.transforms import InterpolationMode
@@ -50,16 +50,10 @@ def resized_crop(
5050

5151
def pad(
5252
self,
53-
padding: Union[int, Sequence[int]],
54-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
53+
padding: Union[int, List[int]],
54+
fill: Optional[Union[int, float, List[float]]] = None,
5555
padding_mode: str = "constant",
5656
) -> Mask:
57-
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
58-
if not isinstance(padding, int):
59-
padding = list(padding)
60-
61-
fill = self._F._geometry._convert_fill_arg(fill)
62-
6357
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
6458
return Mask.new_like(self, output)
6559

@@ -68,10 +62,10 @@ def rotate(
6862
angle: float,
6963
interpolation: InterpolationMode = InterpolationMode.NEAREST,
7064
expand: bool = False,
71-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
65+
fill: Optional[Union[int, float, List[float]]] = None,
7266
center: Optional[List[float]] = None,
7367
) -> Mask:
74-
output = self._F.rotate_mask(self, angle, expand=expand, center=center)
68+
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
7569
return Mask.new_like(self, output)
7670

7771
def affine(
@@ -81,7 +75,7 @@ def affine(
8175
scale: float,
8276
shear: List[float],
8377
interpolation: InterpolationMode = InterpolationMode.NEAREST,
84-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
78+
fill: Optional[Union[int, float, List[float]]] = None,
8579
center: Optional[List[float]] = None,
8680
) -> Mask:
8781
output = self._F.affine_mask(
@@ -90,6 +84,7 @@ def affine(
9084
translate=translate,
9185
scale=scale,
9286
shear=shear,
87+
fill=fill,
9388
center=center,
9489
)
9590
return Mask.new_like(self, output)
@@ -98,16 +93,16 @@ def perspective(
9893
self,
9994
perspective_coeffs: List[float],
10095
interpolation: InterpolationMode = InterpolationMode.NEAREST,
101-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
96+
fill: Optional[Union[int, float, List[float]]] = None,
10297
) -> Mask:
103-
output = self._F.perspective_mask(self, perspective_coeffs)
98+
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
10499
return Mask.new_like(self, output)
105100

106101
def elastic(
107102
self,
108103
displacement: torch.Tensor,
109104
interpolation: InterpolationMode = InterpolationMode.NEAREST,
110-
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
105+
fill: Optional[Union[int, float, List[float]]] = None,
111106
) -> Mask:
112-
output = self._F.elastic_mask(self, displacement)
107+
output = self._F.elastic_mask(self, displacement, fill=fill)
113108
return Mask.new_like(self, output, dtype=output.dtype)

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def affine_mask(
379379
translate: List[float],
380380
scale: float,
381381
shear: List[float],
382+
fill: Optional[Union[int, float, List[float]]] = None,
382383
center: Optional[List[float]] = None,
383384
) -> torch.Tensor:
384385
if mask.ndim < 3:
@@ -394,6 +395,7 @@ def affine_mask(
394395
scale=scale,
395396
shear=shear,
396397
interpolation=InterpolationMode.NEAREST,
398+
fill=fill,
397399
center=center,
398400
)
399401

@@ -541,6 +543,7 @@ def rotate_mask(
541543
mask: torch.Tensor,
542544
angle: float,
543545
expand: bool = False,
546+
fill: Optional[Union[int, float, List[float]]] = None,
544547
center: Optional[List[float]] = None,
545548
) -> torch.Tensor:
546549
if mask.ndim < 3:
@@ -554,6 +557,7 @@ def rotate_mask(
554557
angle=angle,
555558
expand=expand,
556559
interpolation=InterpolationMode.NEAREST,
560+
fill=fill,
557561
center=center,
558562
)
559563

@@ -849,15 +853,19 @@ def perspective_bounding_box(
849853
).view(original_shape)
850854

851855

852-
def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
856+
def perspective_mask(
857+
mask: torch.Tensor,
858+
perspective_coeffs: List[float],
859+
fill: Optional[Union[int, float, List[float]]] = None,
860+
) -> torch.Tensor:
853861
if mask.ndim < 3:
854862
mask = mask.unsqueeze(0)
855863
needs_squeeze = True
856864
else:
857865
needs_squeeze = False
858866

859867
output = perspective_image_tensor(
860-
mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
868+
mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST, fill=fill
861869
)
862870

863871
if needs_squeeze:
@@ -944,14 +952,18 @@ def elastic_bounding_box(
944952
).view(original_shape)
945953

946954

947-
def elastic_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
955+
def elastic_mask(
956+
mask: torch.Tensor,
957+
displacement: torch.Tensor,
958+
fill: Optional[Union[int, float, List[float]]] = None,
959+
) -> torch.Tensor:
948960
if mask.ndim < 3:
949961
mask = mask.unsqueeze(0)
950962
needs_squeeze = True
951963
else:
952964
needs_squeeze = False
953965

954-
output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
966+
output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)
955967

956968
if needs_squeeze:
957969
output = output.squeeze(0)

0 commit comments

Comments
 (0)