
Fix some annotations in transforms v2 for JIT v1 compatibility #7252


Merged: 8 commits, Feb 15, 2023
67 changes: 23 additions & 44 deletions test/prototype_transforms_dispatcher_infos.py
@@ -3,7 +3,7 @@
import pytest
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from torchvision.prototype import datapoints

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
)


def xfail_jit_tuple_instead_of_list(name, *, reason=None):
Review comment (Collaborator Author): We actually observed that our functionals didn't work for tuples, but did not check whether v1 enforces this. Since the behavior is now aligned, we can also remove this helper, as it is no longer in use.

return xfail_jit(
reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
)


def is_list_of_ints(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)


def xfail_jit_list_of_ints(name, *, reason=None):
return xfail_jit(
reason or f"Passing a list of integers for `{name}` is not supported when scripting",
condition=is_list_of_ints,
)


skip_dispatch_datapoint = TestMark(
("TestDispatchers", "test_dispatch_datapoint"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
@@ -130,6 +111,13 @@ def xfail_jit_list_of_ints(name, *, reason=None):
multi_crop_skips.append(skip_dispatch_datapoint)


def xfails_pil(reason, *, condition=None):
return [
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
for test_name in ["test_dispatch_pil", "test_pil_output_type"]
]


def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
try:
@@ -143,11 +131,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
return image_loader.num_channels > 1


xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
),
xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
"PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
condition=fill_sequence_needs_broadcast,
)

@@ -186,11 +171,9 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
Review comment (Collaborator Author): We were on the right track 🤦

xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
],
),
DispatcherInfo(
@@ -213,9 +196,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
*xfails_pil_if_fill_sequence_needs_broadcast,
],
),
DispatcherInfo(
@@ -248,21 +230,16 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
)
*xfails_pil(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
),
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
xfail_jit_python_scalar_arg("padding"),
],
),
DispatcherInfo(
@@ -275,7 +252,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("fill"),
],
),
DispatcherInfo(
@@ -287,6 +265,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
datapoints.Mask: F.elastic_mask,
},
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
DispatcherInfo(
F.center_crop,
83 changes: 38 additions & 45 deletions test/prototype_transforms_kernel_infos.py
@@ -153,26 +153,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
)


def xfail_jit_tuple_instead_of_list(name, *, reason=None):
reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
return xfail_jit(
reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
)


def is_list_of_ints(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)


def xfail_jit_list_of_ints(name, *, reason=None):
return xfail_jit(
reason or f"Passing a list of integers for `{name}` is not supported when scripting",
condition=is_list_of_ints,
)


KERNEL_INFOS = []


@@ -450,21 +430,21 @@ def _full_affine_params(**partial_params):
]


def get_fills(*, num_channels, dtype, vector=True):
def get_fills(*, num_channels, dtype):
Review comment (Collaborator Author): We now make sure that we get all possible fill types.

yield None

max_value = get_max_value(dtype)
# This intentionally gives us a float and an int scalar fill value
yield max_value / 2
yield max_value
int_value = get_max_value(dtype)
float_value = int_value / 2
yield int_value
yield float_value

if not vector:
return
for vector_type in [list, tuple]:
yield vector_type([int_value])
yield vector_type([float_value])

if dtype.is_floating_point:
yield [0.1 + c / 10 for c in range(num_channels)]
else:
yield [12.0 + c for c in range(num_channels)]
if num_channels > 1:
yield vector_type(float_value * c / 10 for c in range(num_channels))
yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))
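
For reference, a rough sketch of what the reworked helper yields for a 3-channel uint8 image; this is an illustration read off the code above, assuming `get_max_value(torch.uint8)` is 255, not the tests' actual output:

```python
# Illustration only: enumerate the fills produced by the new helper.
fills = list(get_fills(num_channels=3, dtype=torch.uint8))
# Roughly, in order:
#   None, 255, 127.5                                      # no fill, int scalar, float scalar
#   [255], [127.5], [0.0, 12.75, 25.5], [255, 0, 255]     # list fills (single-element and per-channel)
#   (255,), (127.5,), (0.0, 12.75, 25.5), (255, 0, 255)   # the same fills again as tuples
```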


def float32_vs_uint8_fill_adapter(other_args, kwargs):
@@ -644,9 +624,7 @@ def sample_inputs_affine_video():
closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
test_marks=[
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
],
),
KernelInfo(
@@ -873,9 +851,7 @@ def sample_inputs_rotate_video():
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
],
),
KernelInfo(
@@ -1122,12 +1098,14 @@ def reference_inputs_pad_image_tensor():
for image_loader, params in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
for fill in get_fills(
num_channels=image_loader.num_channels,
dtype=image_loader.dtype,
vector=params["padding_mode"] == "constant",
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
if isinstance(fill, (list, tuple)):
continue

yield ArgsKwargs(image_loader, fill=fill, **params)


@@ -1195,6 +1173,16 @@ def reference_inputs_pad_bounding_box():
)


def pad_xfail_jit_fill_condition(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
if not isinstance(fill, (list, tuple)):
return False
elif isinstance(fill, tuple):
return True
else: # isinstance(fill, list):
return all(isinstance(f, int) for f in fill)
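
As a quick illustration of which `fill` values the condition above matches (the `SimpleNamespace` stand-in is a hypothetical shortcut; the real tests pass `ArgsKwargs` objects):

```python
from types import SimpleNamespace

def _matches(fill):
    # Minimal stand-in for an ArgsKwargs object with only the `fill` keyword set.
    return pad_xfail_jit_fill_condition(SimpleNamespace(kwargs={"fill": fill}))

assert not _matches(1)            # scalar fills do not trigger the JIT xfail
assert not _matches([0.5, 0.5])   # a list of floats is the supported vector fill
assert _matches((0.5, 0.5))       # tuples are not supported when scripting
assert _matches([1, 2])           # neither is a list of ints
```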


KERNEL_INFOS.extend(
[
KernelInfo(
@@ -1205,10 +1193,10 @@ def reference_inputs_pad_bounding_box():
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("padding"),
xfail_jit(
"F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
),
],
),
KernelInfo(
Expand All @@ -1217,7 +1205,7 @@ def reference_inputs_pad_bounding_box():
reference_fn=reference_pad_bounding_box,
reference_inputs_fn=reference_inputs_pad_bounding_box,
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_python_scalar_arg("padding"),
],
),
KernelInfo(
@@ -1261,8 +1249,11 @@ def reference_inputs_perspective_image_tensor():
F.InterpolationMode.BILINEAR,
],
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
if isinstance(fill, (list, tuple)):
continue

yield ArgsKwargs(
image_loader,
startpoints=None,
@@ -1327,6 +1318,7 @@ def sample_inputs_perspective_video():
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
},
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
KernelInfo(
F.perspective_bounding_box,
@@ -1418,6 +1410,7 @@ def sample_inputs_elastic_video():
**float32_vs_uint8_pixel_difference(6, mae=True),
**cuda_vs_cpu_pixel_difference(),
},
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
KernelInfo(
F.elastic_bounding_box,
2 changes: 1 addition & 1 deletion torchvision/prototype/datapoints/_bounding_box.py
@@ -118,7 +118,7 @@ def resized_crop(
def pad(
self,
padding: Union[int, Sequence[int]],
fill: FillTypeJIT = None,
fill: Optional[Union[int, float, List[float]]] = None,
Review comment (Collaborator Author): ... but keep it for F.pad

padding_mode: str = "constant",
) -> BoundingBox:
output, spatial_size = self._F.pad_bounding_box(
6 changes: 3 additions & 3 deletions torchvision/prototype/datapoints/_datapoint.py
@@ -12,7 +12,7 @@

D = TypeVar("D", bound="Datapoint")
FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None]
FillTypeJIT = Optional[List[float]]
Review comment (Collaborator Author): Revert this to what we had in v1 ...
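
To make the JIT constraint concrete, here is a minimal sketch of the v1-style annotation that `FillTypeJIT` is reverted to; the function name and body are made up for illustration, only the `Optional[List[float]]` signature matters (the broader scalar union is kept only on the datapoint `pad` methods, per the earlier comment):

```python
from typing import List, Optional

import torch

@torch.jit.script
def apply_fill(img: torch.Tensor, fill: Optional[List[float]] = None) -> torch.Tensor:
    # Hypothetical kernel-style function using the v1-compatible `fill` annotation.
    if fill is None:
        return img
    return img + torch.tensor(fill).view(-1, 1, 1)

img = torch.zeros(3, 4, 4)
apply_fill(img, [0.5, 0.25, 0.0])  # a list of floats is what the scripted signature expects
```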



class Datapoint(torch.Tensor):
@@ -169,8 +169,8 @@ def resized_crop(

def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Datapoint:
return self
4 changes: 2 additions & 2 deletions torchvision/prototype/datapoints/_image.py
@@ -103,8 +103,8 @@ def resized_crop(

def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
4 changes: 2 additions & 2 deletions torchvision/prototype/datapoints/_mask.py
@@ -83,8 +83,8 @@ def resized_crop(

def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
4 changes: 2 additions & 2 deletions torchvision/prototype/datapoints/_video.py
@@ -102,8 +102,8 @@ def resized_crop(

def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Video:
output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -270,7 +270,7 @@ def __init__(

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]


class RandomZoomOut(_RandomApplyTransform):
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/_utils.py
@@ -60,10 +60,9 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
if fill is None:
return fill

# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill
return fill # type: ignore[return-value]
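
Read off the code above, a small illustration of what the helper now returns: `None` and scalar fills pass through unchanged (hence the `type: ignore`, since the return annotation is now the narrower `Optional[List[float]]`), while any other sequence is normalized to a list of floats:

```python
assert _convert_fill_arg(None) is None
assert _convert_fill_arg(2) == 2                          # scalars pass through as-is
assert _convert_fill_arg((1, 2, 3)) == [1.0, 2.0, 3.0]    # tuples become List[float]
assert _convert_fill_arg([0, 255]) == [0.0, 255.0]        # int lists become float lists
```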


def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]: