diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 09c8bdbcfeb..c79aef6cb04 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -390,7 +390,7 @@ def _compute_resized_output_size( def resize( img: Tensor, - size: List[int], + size: Union[int, Sequence[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn", ) -> Tensor: @@ -492,7 +492,7 @@ def resize( return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) -def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor: +def pad(img: Tensor, padding: Union[int, Sequence[int]], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor: r"""Pad the given image on all sides with the given "pad" value. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, @@ -566,7 +566,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: return F_t.crop(img, top, left, height, width) -def center_crop(img: Tensor, output_size: List[int]) -> Tensor: +def center_crop(img: Tensor, output_size: Union[int, Sequence[int]]) -> Tensor: """Crops the given image at the center. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 
@@ -613,7 +613,7 @@ def resized_crop( left: int, height: int, width: int, - size: List[int], + size: Union[int, Sequence[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", ) -> Tensor: @@ -782,7 +782,7 @@ def vflip(img: Tensor) -> Tensor: return F_t.vflip(img) -def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: +def five_crop(img: Tensor, size: Union[int, Sequence[int]]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Crop the given image into four corners and the central crop. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -828,7 +828,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten def ten_crop( - img: Tensor, size: List[int], vertical_flip: bool = False + img: Tensor, size: Union[int, Sequence[int]], vertical_flip: bool = False ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Generate ten cropped images from the given image. Crop the given image into four corners and the central crop plus the