Add scale option to ToDtype. Remove ConvertDtype. #7759
```diff
@@ -1,4 +1,5 @@
 import contextlib
+import decimal
 import inspect
 import math
 import re
```
@@ -27,6 +28,7 @@ | |
| set_rng_seed, | ||
| ) | ||
| from torch.testing import assert_close | ||
| from torch.utils._pytree import tree_map | ||
| from torchvision import datapoints | ||
|
|
||
| from torchvision.transforms._functional_tensor import _max_value as get_max_value | ||
|
|
@@ -68,7 +70,8 @@ def _script(fn): | |
| try: | ||
| return torch.jit.script(fn) | ||
| except Exception as error: | ||
| raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error | ||
| name = getattr(fn, "__name__", fn.__class__.__name__) | ||
| raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error | ||
|
|
||
|
|
||
| def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs): | ||
|
|
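The `getattr` fallback is presumably needed because `_script` is now also called on transform *instances*, which, unlike plain functions, have no `__name__` (my reading of the change; the diff does not say so explicitly):

```python
# Minimal illustration of the assumed rationale for the getattr fallback.
def fn():
    pass

class ToDtypeLike:  # stand-in for a transform instance passed to _script
    pass

print(fn.__name__)                                       # 'fn'
obj = ToDtypeLike()
print(getattr(obj, "__name__", obj.__class__.__name__))  # 'ToDtypeLike'
```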
@@ -125,6 +128,7 @@ def check_kernel( | |
| check_cuda_vs_cpu=True, | ||
| check_scripted_vs_eager=True, | ||
| check_batched_vs_unbatched=True, | ||
| expect_same_dtype=True, | ||
|
**Contributor:** Not against it, but are we expecting more kernels to set this to `False`?

**Member (Author):** I don't think so (more precisely: I don't know). We still want the rest of the checks to be done for `ToDtype`, though. Is there a better way than to add a parameter?

**Contributor:** Not really. Since this will be the only kernel that ever needs this, we could implement a custom check. Kinda torn on this. Up to you. I'm ok with both.

**Member (Author):** I have a slight preference for adding the parameter, because otherwise we'd have to change both implementations of …
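For concreteness, a rough sketch of the "custom check" alternative floated above (hypothetical helper name and shape; it is not part of the PR). Adding the parameter instead keeps these assertions in one place:

```python
# Hypothetical sketch of the rejected alternative: a ToDtype-specific helper
# that repeats check_kernel's assertions but checks the *target* dtype, since
# changing the dtype is exactly what this kernel is supposed to do.
def check_to_dtype_kernel(kernel, input, *, dtype, scale):
    output = kernel(input, dtype=dtype, scale=scale)
    assert output.dtype == dtype          # not input.dtype, unlike check_kernel
    assert output.device == input.device  # the remaining shared checks still apply
    return output
```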
```diff
@@ -137,7 +141,8 @@ def check_kernel(
     # check that no inplace operation happened
     assert input._version == initial_input_version

-    assert output.dtype == input.dtype
+    if expect_same_dtype:
+        assert output.dtype == input.dtype
     assert output.device == input.device

     if check_cuda_vs_cpu:
```
@@ -274,7 +279,7 @@ def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type): | |
| def _check_transform_v1_compatibility(transform, input): | ||
| """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static | ||
| ``get_params`` method, is scriptable, and the scripted version can be called without error.""" | ||
| if not hasattr(transform, "_v1_transform_cls"): | ||
| if transform._v1_transform_cls is None: | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return | ||
|
|
||
| if type(input) is not torch.Tensor: | ||
|
|
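For context, a minimal sketch (hypothetical class names, not from the PR) of why the `hasattr` guard no longer discriminates: once the base transform declares `_v1_transform_cls` as a class attribute defaulting to `None`, every transform *has* the attribute, so the guard must compare against `None`:

```python
# Hypothetical illustration, assuming _v1_transform_cls is now declared on the
# base class with a default of None (inferred from the diff above).
class Transform:
    _v1_transform_cls = None  # transforms without a v1 counterpart keep the default

class SomeNewTransform(Transform):
    pass

t = SomeNewTransform()
print(hasattr(t, "_v1_transform_cls"))  # True -> the old guard never triggers
print(t._v1_transform_cls is None)      # True -> the new guard correctly skips v1 checks
```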
@@ -1634,3 +1639,111 @@ def test_transform_negative_degrees_error(self): | |
| def test_transform_unknown_fill_error(self): | ||
| with pytest.raises(TypeError, match="Got inappropriate fill arg"): | ||
| transforms.RandomAffine(degrees=0, fill="fill") | ||
|
|
||
|
|
||
| class TestToDtype: | ||
| @pytest.mark.parametrize( | ||
| ("kernel", "make_input"), | ||
| [ | ||
| (F.to_dtype_image_tensor, make_image_tensor), | ||
| (F.to_dtype_image_tensor, make_image), | ||
| (F.to_dtype_video, make_video), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) | ||
| @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) | ||
| @pytest.mark.parametrize("device", cpu_and_cuda()) | ||
| @pytest.mark.parametrize("scale", (True, False)) | ||
| def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale): | ||
| check_kernel( | ||
| kernel, | ||
| make_input(dtype=input_dtype, device=device), | ||
| expect_same_dtype=False, | ||
NicolasHug marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| dtype=output_dtype, | ||
| scale=scale, | ||
| ) | ||
|
|
||
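As context for the parametrization above, a small sketch of the behavior this PR introduces (per the PR title, the value scaling previously done by the removed `ConvertDtype` becomes opt-in via `scale`; import path and exact values assumed from the reference implementation further down):

```python
import torch
from torchvision.transforms.v2 import functional as F  # module path assumed

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

scaled = F.to_dtype(img, dtype=torch.float32, scale=True)   # values mapped to [0.0, 1.0]
cast = F.to_dtype(img, dtype=torch.float32, scale=False)    # plain cast, values stay 0..255
```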
```diff
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.to_dtype_image_tensor, make_image_tensor),
+            (F.to_dtype_image_tensor, make_image),
+            (F.to_dtype_video, make_video),
+        ],
+    )
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("scale", (True, False))
+    def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale):
+        check_dispatcher(
+            F.to_dtype,
+            kernel,
+            make_input(dtype=input_dtype, device=device),
+            check_dispatch=False,  # TODO: the check would pass if we were to use the non-datapoint dependent logic of _check_dispatcher_dispatch ¯\_(ツ)_/¯
+            dtype=output_dtype,
+            scale=scale,
+        )
+
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_bounding_box, make_segmentation_mask, make_video],
+    )
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("scale", (True, False))
+    def test_transform(self, make_input, input_dtype, output_dtype, device, scale):
+        input = make_input(dtype=input_dtype, device=device)
+        output_dtype = {"others": torch.float32}
+
+        check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale)
```
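The `{"others": torch.float32}` dict above exercises `ToDtype`'s mapping form. A hedged sketch of how such a mapping is typically used (keys and semantics assumed from the test and the v2 transforms API, not spelled out in this diff):

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms  # import path assumed

# Assumed semantics: per-datapoint-type target dtypes, with "others" as the
# catch-all for any input type not listed explicitly.
transform = transforms.ToDtype(
    dtype={datapoints.Image: torch.float32, "others": None},  # None: leave dtype unchanged
    scale=True,
)
```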
```diff
+    def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
+        input_dtype = image.dtype
+        output_dtype = dtype
+
+        if not scale:
+            return image.to(dtype)
+
+        if output_dtype == input_dtype:
+            return image
+
+        def fn(value):
+            if input_dtype.is_floating_point:
+                if output_dtype.is_floating_point:
+                    return value
+                else:
+                    return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
+            else:
+                input_max_value = torch.iinfo(input_dtype).max
+
+                if output_dtype.is_floating_point:
+                    return float(decimal.Decimal(value) / input_max_value)
+                else:
+                    output_max_value = torch.iinfo(output_dtype).max
+
+                    if input_max_value > output_max_value:
+                        factor = (input_max_value + 1) // (output_max_value + 1)
+                        return value / factor
+                    else:
+                        factor = (output_max_value + 1) // (input_max_value + 1)
+                        return value * factor
+
+        return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)
```
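To make the reference's three scaling branches concrete, a small worked example mirroring the arithmetic in the code above:

```python
import torch

# float -> uint8: multiply by the target's max value (255 for uint8)
print(int(0.5 * 255))   # 127

# uint8 -> float: divide by the source's max value
print(128 / 255)        # ~0.502

# int -> int upscale: multiply by factor = (output_max + 1) // (input_max + 1),
# e.g. uint8 -> int32:
factor = (torch.iinfo(torch.int32).max + 1) // (torch.iinfo(torch.uint8).max + 1)
print(factor)           # 8388608, i.e. 2**23
```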
```diff
+
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("scale", (True, False))
+    def test_against_ref(self, input_dtype, output_dtype, device, scale):
+        input = make_image(dtype=input_dtype, device=device)
+
+        out = F.to_dtype(input, dtype=output_dtype, scale=scale)
+        expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
+
+        if output_dtype is torch.uint8 and scale:
+            # TODO: is this actually normal? Why wasn't this a problem before?
+            torch.testing.assert_close(out, expected, atol=1, rtol=0)
+        else:
+            torch.testing.assert_close(out, expected)
```
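One plausible reading of the `atol=1` tolerance and the TODO above (an assumption, not stated in the diff): the reference truncates via `int(Decimal(...) * max)`, while the kernel may round, so uint8 results can disagree by one step:

```python
# Illustration of the assumed truncation-vs-rounding mismatch behind atol=1.
y = 0.78 * 255           # 198.9...
print(int(y), round(y))  # 198 199 -> off by exactly one uint8 step
```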