Fix some annotations in transforms v2 for JIT v1 compatibility #7252
Changes from all commits: 9dc3df3, d3de938, 8690d9e, a28313e, 558ad3d, c5bcda4, 073d972, 6c0e3e6
@@ -3,7 +3,7 @@
 import pytest
 import torchvision.prototype.transforms.functional as F
 from prototype_common_utils import InfoBase, TestMark
-from prototype_transforms_kernel_infos import KERNEL_INFOS
+from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
 from torchvision.prototype import datapoints

 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )


-def xfail_jit_tuple_instead_of_list(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
-        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
-    )
-
-
-def is_list_of_ints(args_kwargs):
-    fill = args_kwargs.kwargs.get("fill")
-    return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
-
-
-def xfail_jit_list_of_ints(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a list of integers for `{name}` is not supported when scripting",
-        condition=is_list_of_ints,
-    )
-
-
 skip_dispatch_datapoint = TestMark(
     ("TestDispatchers", "test_dispatch_datapoint"),
     pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
@@ -130,6 +111,13 @@ def xfail_jit_list_of_ints(name, *, reason=None):
 multi_crop_skips.append(skip_dispatch_datapoint)


+def xfails_pil(reason, *, condition=None):
+    return [
+        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
+        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
+    ]
+
+
 def fill_sequence_needs_broadcast(args_kwargs):
     (image_loader, *_), kwargs = args_kwargs
     try:
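For intuition, a rough sketch (not part of the PR) of what the new helper expands to, using a simplified stand-in for the suite's TestMark:

import pytest

def TestMark(test_id, mark, condition=None):
    # Simplified stand-in for the real TestMark from prototype_common_utils (assumption).
    return (test_id, mark, condition)

def xfails_pil(reason, *, condition=None):
    # Same logic as in the diff above.
    return [
        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
    ]

marks = xfails_pil("PIL kernel limitation")
assert [m[0][1] for m in marks] == ["test_dispatch_pil", "test_pil_output_type"]

This is why the call sites below switch from a single xfail_dispatch_pil_... mark to an unpacked *xfails_pil_... list: one reason now marks both PIL dispatcher tests.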
@@ -143,11 +131,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
     return image_loader.num_channels > 1


-xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
-    ("TestDispatchers", "test_dispatch_pil"),
-    pytest.mark.xfail(
-        reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
-    ),
+xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
+    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
     condition=fill_sequence_needs_broadcast,
 )
@@ -186,11 +171,9 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.affine_image_pil),
         test_marks=[
-            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("shear"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     DispatcherInfo(

Review comment: We were on the right track 🤦
@@ -213,9 +196,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
         test_marks=[
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
+            *xfails_pil_if_fill_sequence_needs_broadcast,
         ],
     ),
     DispatcherInfo(
@@ -248,21 +230,16 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
-            TestMark(
-                ("TestDispatchers", "test_dispatch_pil"),
-                pytest.mark.xfail(
-                    reason=(
-                        "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
-                        "`padding_mode='constant'`, if the number of color channels is larger."
-                    )
-                ),
+            *xfails_pil(
+                reason=(
+                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
+                    "`padding_mode='constant'`, if the number of color channels is larger."
+                ),
                 condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
                 and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
             ),
-            xfail_jit_tuple_instead_of_list("padding"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
             xfail_jit_python_scalar_arg("padding"),
         ],
     ),
     DispatcherInfo(
@@ -275,7 +252,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
         test_marks=[
-            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_pil_if_fill_sequence_needs_broadcast,
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     DispatcherInfo(
@@ -287,6 +265,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
             datapoints.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     DispatcherInfo(
         F.center_crop,
@@ -153,26 +153,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )


-def xfail_jit_tuple_instead_of_list(name, *, reason=None):
-    reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
-    return xfail_jit(
-        reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
-        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
-    )
-
-
-def is_list_of_ints(args_kwargs):
-    fill = args_kwargs.kwargs.get("fill")
-    return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
-
-
-def xfail_jit_list_of_ints(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a list of integers for `{name}` is not supported when scripting",
-        condition=is_list_of_ints,
-    )
-
-
 KERNEL_INFOS = []
@@ -450,21 +430,21 @@ def _full_affine_params(**partial_params):
 ]


-def get_fills(*, num_channels, dtype, vector=True):
+def get_fills(*, num_channels, dtype):
     yield None

-    max_value = get_max_value(dtype)
-    # This intentionally gives us a float and an int scalar fill value
-    yield max_value / 2
-    yield max_value
+    int_value = get_max_value(dtype)
+    float_value = int_value / 2
+    yield int_value
+    yield float_value

-    if not vector:
-        return
+    for vector_type in [list, tuple]:
+        yield vector_type([int_value])
+        yield vector_type([float_value])

-    if dtype.is_floating_point:
-        yield [0.1 + c / 10 for c in range(num_channels)]
-    else:
-        yield [12.0 + c for c in range(num_channels)]
+        if num_channels > 1:
+            yield vector_type(float_value * c / 10 for c in range(num_channels))
+            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))


 def float32_vs_uint8_fill_adapter(other_args, kwargs):

Review comment: We now make sure that we get all possible fill types.
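To make that concrete, here is a hedged walk-through (not part of the PR) of what the revised generator yields; it assumes the test suite's get_max_value helper returns torch.iinfo(dtype).max for integer dtypes:

import torch

def get_max_value(dtype):
    # Stand-in for the test-suite helper of the same name (assumption).
    return torch.iinfo(dtype).max

def get_fills(*, num_channels, dtype):
    # Same logic as the new generator in the diff above.
    yield None
    int_value = get_max_value(dtype)
    float_value = int_value / 2
    yield int_value
    yield float_value
    for vector_type in [list, tuple]:
        yield vector_type([int_value])
        yield vector_type([float_value])
        if num_channels > 1:
            yield vector_type(float_value * c / 10 for c in range(num_channels))
            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))

print(list(get_fills(num_channels=3, dtype=torch.uint8)))
# [None, 255, 127.5, [255], [127.5], [0.0, 12.75, 25.5], [255, 0, 255],
#  (255,), (127.5,), (0.0, 12.75, 25.5), (255, 0, 255)]

Every call site now sees scalar ints, scalar floats, and both list and tuple vector fills, which is what the new JIT xfail conditions key on.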
@@ -644,9 +624,7 @@ def sample_inputs_affine_video():
         closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
         test_marks=[
             xfail_jit_python_scalar_arg("shear"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     KernelInfo(
@@ -873,9 +851,7 @@ def sample_inputs_rotate_video():
         float32_vs_uint8=True,
         closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
         test_marks=[
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     KernelInfo(
@@ -1122,12 +1098,14 @@ def reference_inputs_pad_image_tensor():
     for image_loader, params in itertools.product(
         make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
     ):
-        # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
         for fill in get_fills(
             num_channels=image_loader.num_channels,
             dtype=image_loader.dtype,
-            vector=params["padding_mode"] == "constant",
         ):
+            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
+            if isinstance(fill, (list, tuple)):
+                continue
+
             yield ArgsKwargs(image_loader, fill=fill, **params)
@@ -1195,6 +1173,16 @@ def reference_inputs_pad_bounding_box():
         )


+def pad_xfail_jit_fill_condition(args_kwargs):
+    fill = args_kwargs.kwargs.get("fill")
+    if not isinstance(fill, (list, tuple)):
+        return False
+    elif isinstance(fill, tuple):
+        return True
+    else:  # isinstance(fill, list)
+        return all(isinstance(f, int) for f in fill)
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
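A quick sanity check of the condition above — a hypothetical mini-test in which SimpleNamespace stands in for the suite's real ArgsKwargs objects:

from types import SimpleNamespace

def pad_xfail_jit_fill_condition(args_kwargs):
    # Same logic as in the diff above.
    fill = args_kwargs.kwargs.get("fill")
    if not isinstance(fill, (list, tuple)):
        return False
    elif isinstance(fill, tuple):
        return True
    else:  # isinstance(fill, list)
        return all(isinstance(f, int) for f in fill)

for fill, expected in [
    (1, False),           # scalar fills are covered by xfail_jit_python_scalar_arg instead
    ((1.0, 2.0), True),   # any tuple fails under scripting
    ([1, 2], True),       # a list of ints fails under scripting
    ([1.0, 2.0], False),  # a list of floats is the one supported vector fill
]:
    assert pad_xfail_jit_fill_condition(SimpleNamespace(kwargs={"fill": fill})) is expected

This matches the xfail reason used at the call sites: F.pad only supports vector fills for list of floats.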
@@ -1205,10 +1193,10 @@ def reference_inputs_pad_bounding_box():
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs=float32_vs_uint8_pixel_difference(),
             test_marks=[
-                xfail_jit_tuple_instead_of_list("padding"),
-                xfail_jit_tuple_instead_of_list("fill"),
-                # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-                xfail_jit_list_of_ints("fill"),
+                xfail_jit_python_scalar_arg("padding"),
+                xfail_jit(
+                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
+                ),
             ],
         ),
         KernelInfo(
@@ -1217,7 +1205,7 @@ def reference_inputs_pad_bounding_box():
         reference_fn=reference_pad_bounding_box,
         reference_inputs_fn=reference_inputs_pad_bounding_box,
         test_marks=[
-            xfail_jit_tuple_instead_of_list("padding"),
+            xfail_jit_python_scalar_arg("padding"),
         ],
     ),
     KernelInfo(
@@ -1261,8 +1249,11 @@ def reference_inputs_perspective_image_tensor():
             F.InterpolationMode.BILINEAR,
         ],
     ):
-        # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
         for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
+            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
+            if isinstance(fill, (list, tuple)):
+                continue
+
             yield ArgsKwargs(
                 image_loader,
                 startpoints=None,
@@ -1327,6 +1318,7 @@ def sample_inputs_perspective_video():
             **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
             **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
         },
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     KernelInfo(
         F.perspective_bounding_box,
@@ -1418,6 +1410,7 @@ def sample_inputs_elastic_video():
             **float32_vs_uint8_pixel_difference(6, mae=True),
             **cuda_vs_cpu_pixel_difference(),
         },
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
    ),
    KernelInfo(
        F.elastic_bounding_box,
@@ -118,7 +118,7 @@ def resized_crop(
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: FillTypeJIT = None,
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> BoundingBox:
         output, spatial_size = self._F.pad_bounding_box(

Review comment: ... but keep it for
@@ -12,7 +12,7 @@

 D = TypeVar("D", bound="Datapoint")
 FillType = Union[int, float, Sequence[int], Sequence[float], None]
-FillTypeJIT = Union[int, float, List[float], None]
+FillTypeJIT = Optional[List[float]]


 class Datapoint(torch.Tensor):

Review comment: Revert this to what we had in v1 ...
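A minimal sketch (not torchvision code; it assumes current TorchScript argument-binding behavior) of what the v1-style Optional[List[float]] annotation permits when scripted, which is exactly what the new xfail_jit_python_scalar_arg("fill") marks encode:

import torch
from typing import List, Optional

def fill_len(fill: Optional[List[float]] = None) -> int:
    # Stand-in function reusing the reverted v1-style annotation.
    return 0 if fill is None else len(fill)

scripted = torch.jit.script(fill_len)
scripted(None)          # accepted
scripted([1.0, 2.0])    # accepted: a list of floats
# scripted(1)           # rejected at call time: a Python scalar is not Optional[List[float]]
# scripted((1.0, 2.0))  # rejected: a tuple is not a List under TorchScript
# scripted([1, 2])      # rejected: List[int] is not implicitly converted to List[float]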
@@ -169,8 +169,8 @@ def resized_crop(

     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Datapoint:
         return self
Review comment: We actually observed that our functionals didn't work for tuples, but missed checking whether v1 enforces this. Since we have now aligned the behavior, we can also remove this helper, as it is no longer in use.