diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 8fe5333aa51..442dd526ed3 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -3,7 +3,7 @@ import pytest import torchvision.prototype.transforms.functional as F from prototype_common_utils import InfoBase, TestMark -from prototype_transforms_kernel_infos import KERNEL_INFOS +from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition from torchvision.prototype import datapoints __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] @@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None): ) -def xfail_jit_tuple_instead_of_list(name, *, reason=None): - return xfail_jit( - reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting", - condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple), - ) - - -def is_list_of_ints(args_kwargs): - fill = args_kwargs.kwargs.get("fill") - return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill) - - -def xfail_jit_list_of_ints(name, *, reason=None): - return xfail_jit( - reason or f"Passing a list of integers for `{name}` is not supported when scripting", - condition=is_list_of_ints, - ) - - skip_dispatch_datapoint = TestMark( ("TestDispatchers", "test_dispatch_datapoint"), pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."), @@ -130,6 +111,13 @@ def xfail_jit_list_of_ints(name, *, reason=None): multi_crop_skips.append(skip_dispatch_datapoint) +def xfails_pil(reason, *, condition=None): + return [ + TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition) + for test_name in ["test_dispatch_pil", "test_pil_output_type"] + ] + + def fill_sequence_needs_broadcast(args_kwargs): (image_loader, *_), kwargs = args_kwargs try: @@ -143,11 +131,8 @@ def fill_sequence_needs_broadcast(args_kwargs): return image_loader.num_channels > 1 -xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark( - ("TestDispatchers", "test_dispatch_pil"), - pytest.mark.xfail( - reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger." - ), +xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil( + "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.", condition=fill_sequence_needs_broadcast, ) @@ -186,11 +171,9 @@ def fill_sequence_needs_broadcast(args_kwargs): }, pil_kernel_info=PILKernelInfo(F.affine_image_pil), test_marks=[ - xfail_dispatch_pil_if_fill_sequence_needs_broadcast, + *xfails_pil_if_fill_sequence_needs_broadcast, xfail_jit_python_scalar_arg("shear"), - xfail_jit_tuple_instead_of_list("fill"), - # TODO: check if this is a regression since it seems that should be supported if `int` is ok - xfail_jit_list_of_ints("fill"), + xfail_jit_python_scalar_arg("fill"), ], ), DispatcherInfo( @@ -213,9 +196,8 @@ def fill_sequence_needs_broadcast(args_kwargs): }, pil_kernel_info=PILKernelInfo(F.rotate_image_pil), test_marks=[ - xfail_jit_tuple_instead_of_list("fill"), - # TODO: check if this is a regression since it seems that should be supported if `int` is ok - xfail_jit_list_of_ints("fill"), + xfail_jit_python_scalar_arg("fill"), + *xfails_pil_if_fill_sequence_needs_broadcast, ], ), DispatcherInfo( @@ -248,21 +230,16 @@ def fill_sequence_needs_broadcast(args_kwargs): }, pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), test_marks=[ - TestMark( - ("TestDispatchers", "test_dispatch_pil"), - pytest.mark.xfail( - reason=( - "PIL kernel doesn't support sequences of length 1 for argument `fill` and " - "`padding_mode='constant'`, if the number of color channels is larger." - ) + *xfails_pil( + reason=( + "PIL kernel doesn't support sequences of length 1 for argument `fill` and " + "`padding_mode='constant'`, if the number of color channels is larger." ), condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs) and args_kwargs.kwargs.get("padding_mode", "constant") == "constant", ), - xfail_jit_tuple_instead_of_list("padding"), - xfail_jit_tuple_instead_of_list("fill"), - # TODO: check if this is a regression since it seems that should be supported if `int` is ok - xfail_jit_list_of_ints("fill"), + xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition), + xfail_jit_python_scalar_arg("padding"), ], ), DispatcherInfo( @@ -275,7 +252,8 @@ def fill_sequence_needs_broadcast(args_kwargs): }, pil_kernel_info=PILKernelInfo(F.perspective_image_pil), test_marks=[ - xfail_dispatch_pil_if_fill_sequence_needs_broadcast, + *xfails_pil_if_fill_sequence_needs_broadcast, + xfail_jit_python_scalar_arg("fill"), ], ), DispatcherInfo( @@ -287,6 +265,7 @@ def fill_sequence_needs_broadcast(args_kwargs): datapoints.Mask: F.elastic_mask, }, pil_kernel_info=PILKernelInfo(F.elastic_image_pil), + test_marks=[xfail_jit_python_scalar_arg("fill")], ), DispatcherInfo( F.center_crop, diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index dc9c3e57d7a..ce05c980a87 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -153,26 +153,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None): ) -def xfail_jit_tuple_instead_of_list(name, *, reason=None): - reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting" - return xfail_jit( - reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting", - condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple), - ) - - -def is_list_of_ints(args_kwargs): - fill = args_kwargs.kwargs.get("fill") - return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill) - - -def xfail_jit_list_of_ints(name, *, reason=None): - return xfail_jit( - reason or f"Passing a list of integers for `{name}` is not supported when scripting", - condition=is_list_of_ints, - ) - - KERNEL_INFOS = [] @@ -450,21 +430,21 @@ def _full_affine_params(**partial_params): ] -def get_fills(*, num_channels, dtype, vector=True): +def get_fills(*, num_channels, dtype): yield None - max_value = get_max_value(dtype) - # This intentionally gives us a float and an int scalar fill value - yield max_value / 2 - yield max_value + int_value = get_max_value(dtype) + float_value = int_value / 2 + yield int_value + yield float_value - if not vector: - return + for vector_type in [list, tuple]: + yield vector_type([int_value]) + yield vector_type([float_value]) - if dtype.is_floating_point: - yield [0.1 + c / 10 for c in range(num_channels)] - else: - yield [12.0 + c for c in range(num_channels)] + if num_channels > 1: + yield vector_type(float_value * c / 10 for c in range(num_channels)) + yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels)) def float32_vs_uint8_fill_adapter(other_args, kwargs): @@ -644,9 +624,7 @@ def sample_inputs_affine_video(): closeness_kwargs=pil_reference_pixel_difference(10, mae=True), test_marks=[ xfail_jit_python_scalar_arg("shear"), - xfail_jit_tuple_instead_of_list("fill"), - # TODO: check if this is a regression since it seems that should be supported if `int` is ok - xfail_jit_list_of_ints("fill"), + xfail_jit_python_scalar_arg("fill"), ], ), KernelInfo( @@ -873,9 +851,7 @@ def sample_inputs_rotate_video(): float32_vs_uint8=True, closeness_kwargs=pil_reference_pixel_difference(1, mae=True), test_marks=[ - xfail_jit_tuple_instead_of_list("fill"), - # TODO: check if this is a regression since it seems that should be supported if `int` is ok - xfail_jit_list_of_ints("fill"), + xfail_jit_python_scalar_arg("fill"), ], ), KernelInfo( @@ -1122,12 +1098,14 @@ def reference_inputs_pad_image_tensor(): for image_loader, params in itertools.product( make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS ): - # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it? for fill in get_fills( num_channels=image_loader.num_channels, dtype=image_loader.dtype, - vector=params["padding_mode"] == "constant", ): + # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it? + if isinstance(fill, (list, tuple)): + continue + yield ArgsKwargs(image_loader, fill=fill, **params) @@ -1195,6 +1173,16 @@ def reference_inputs_pad_bounding_box(): ) +def pad_xfail_jit_fill_condition(args_kwargs): + fill = args_kwargs.kwargs.get("fill") + if not isinstance(fill, (list, tuple)): + return False + elif isinstance(fill, tuple): + return True + else: # isinstance(fill, list): + return all(isinstance(f, int) for f in fill) + + KERNEL_INFOS.extend( [ KernelInfo( @@ -1205,10 +1193,10 @@ def reference_inputs_pad_bounding_box(): float32_vs_uint8=float32_vs_uint8_fill_adapter, closeness_kwargs=float32_vs_uint8_pixel_difference(), test_marks=[ - xfail_jit_tuple_instead_of_list("padding"), - xfail_jit_tuple_instead_of_list("fill"), - # TODO: check if this is a regression since it seems that should be supported if `int` is ok - xfail_jit_list_of_ints("fill"), + xfail_jit_python_scalar_arg("padding"), + xfail_jit( + "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition + ), ], ), KernelInfo( @@ -1217,7 +1205,7 @@ def reference_inputs_pad_bounding_box(): reference_fn=reference_pad_bounding_box, reference_inputs_fn=reference_inputs_pad_bounding_box, test_marks=[ - xfail_jit_tuple_instead_of_list("padding"), + xfail_jit_python_scalar_arg("padding"), ], ), KernelInfo( @@ -1261,8 +1249,11 @@ def reference_inputs_perspective_image_tensor(): F.InterpolationMode.BILINEAR, ], ): - # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it? for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): + # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it? + if isinstance(fill, (list, tuple)): + continue + yield ArgsKwargs( image_loader, startpoints=None, @@ -1327,6 +1318,7 @@ def sample_inputs_perspective_video(): **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5), **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5), }, + test_marks=[xfail_jit_python_scalar_arg("fill")], ), KernelInfo( F.perspective_bounding_box, @@ -1418,6 +1410,7 @@ def sample_inputs_elastic_video(): **float32_vs_uint8_pixel_difference(6, mae=True), **cuda_vs_cpu_pixel_difference(), }, + test_marks=[xfail_jit_python_scalar_arg("fill")], ), KernelInfo( F.elastic_bounding_box, diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py index 7c44a9e4b26..e04a965d9fc 100644 --- a/torchvision/prototype/datapoints/_bounding_box.py +++ b/torchvision/prototype/datapoints/_bounding_box.py @@ -118,7 +118,7 @@ def resized_crop( def pad( self, padding: Union[int, Sequence[int]], - fill: FillTypeJIT = None, + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> BoundingBox: output, spatial_size = self._F.pad_bounding_box( diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index d75a2211071..3738d2a8124 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -12,7 +12,7 @@ D = TypeVar("D", bound="Datapoint") FillType = Union[int, float, Sequence[int], Sequence[float], None] -FillTypeJIT = Union[int, float, List[float], None] +FillTypeJIT = Optional[List[float]] class Datapoint(torch.Tensor): @@ -169,8 +169,8 @@ def resized_crop( def pad( self, - padding: Union[int, List[int]], - fill: FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> Datapoint: return self diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index bbd06de707a..8f3092fa1e7 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -103,8 +103,8 @@ def resized_crop( def pad( self, - padding: Union[int, List[int]], - fill: FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> Image: output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode) diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py index dec26f80af1..a1870fa4b20 100644 --- a/torchvision/prototype/datapoints/_mask.py +++ b/torchvision/prototype/datapoints/_mask.py @@ -83,8 +83,8 @@ def resized_crop( def pad( self, - padding: Union[int, List[int]], - fill: FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> Mask: output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 2f628f2efc4..0e5ff7a17b8 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -102,8 +102,8 @@ def resized_crop( def pad( self, - padding: Union[int, List[int]], - fill: FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> Video: output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index b8c8d10ae1d..c4708cc57bd 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -270,7 +270,7 @@ def __init__( def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) + return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] class RandomZoomOut(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index b5ec05669e9..f2d818b1326 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -60,10 +60,9 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT: if fill is None: return fill - # This cast does Sequence -> List[float] to please mypy and torch.jit.script if not isinstance(fill, (int, float)): fill = [float(v) for v in list(fill)] - return fill + return fill # type: ignore[return-value] def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]: diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 5fa6cbb4873..7fa0736ccb6 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -432,7 +432,7 @@ def _apply_grid_transform( if fill is not None: float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3) mask = mask.expand_as(float_img) - fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] + fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type] fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1) if mode == "nearest": bool_mask = mask < 0.5 @@ -968,8 +968,8 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: def pad_image_tensor( image: torch.Tensor, - padding: Union[int, List[int]], - fill: datapoints.FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> torch.Tensor: # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses @@ -1069,14 +1069,14 @@ def _pad_with_vector_fill( def pad_mask( mask: torch.Tensor, - padding: Union[int, List[int]], + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", - fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: if fill is None: fill = 0 - if isinstance(fill, list): + if isinstance(fill, (tuple, list)): raise ValueError("Non-scalar fill value is not supported") if mask.ndim < 3: @@ -1097,7 +1097,7 @@ def pad_bounding_box( bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int], - padding: Union[int, List[int]], + padding: List[int], padding_mode: str = "constant", ) -> Tuple[torch.Tensor, Tuple[int, int]]: if padding_mode not in ["constant"]: @@ -1122,8 +1122,8 @@ def pad_bounding_box( def pad_video( video: torch.Tensor, - padding: Union[int, List[int]], - fill: datapoints.FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> torch.Tensor: return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode) @@ -1131,8 +1131,8 @@ def pad_video( def pad( inpt: datapoints.InputTypeJIT, - padding: Union[int, List[int]], - fill: datapoints.FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting():