-
Notifications
You must be signed in to change notification settings - Fork 7.1k
handle inplace operations in _Feature.__torch_function__ #6671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Conflicts: torchvision/prototype/features/_feature.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
torch.testing.assert_close(inpt - 0.5, output) | ||
|
||
inpt = make_image(color_space=features.ColorSpace.RGB) | ||
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) | ||
assert isinstance(output, torch.Tensor) | ||
assert type(output) is torch.Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test was wrong and this is why it didn't fail before fixing normalize
: since features.Image
is a subclass of torch.Tensor
the check passes. We need to check for type identity here.
@@ -907,15 +907,15 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s | |||
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") | |||
|
|||
|
|||
def test_midlevel_normalize_output_type(): | |||
def test_normalize_output_type(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed since
- we agreed to use the kernel / dispatcher scheme rather than low / mid level kernels.
- using
normalize
in the name is sufficient, since the kernel is callednormalize_image_tensor
tensor.sub_(mean).div_(std) | ||
return tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Behavior is equivalent for plain tensors, but gives us the unwrapping behavior we want from this change. With this in place, explicit unwrapping in normalize
is no longer needed. Let me know if I should keep it there. Runtime is the same, because here we would implicitly unwrap inside __torch_function__
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initially I thought this was a styling change. Then I realize you do this to avoid returning tensor
(which might be an image) and instead return the unwrapped torch.Tensor
. Might worth adding a short comment in a follow up PR.
# Image instance after normalization is not Image anymore due to unknown data range | ||
# Thus we return Tensor for input Image | ||
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) | ||
correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, features.Image) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gives is the correct error behavior in eager mode. @vfdev-5 noted that before, we could pass bounding boxes or the like to normalize
and it would only fail in the computation rather than being caught before.
Hey @pmeier! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
) Summary: * prevent feature wrapping for inplace ops * cleanup * mypy * refactor __torch_function__ to be more concise * avoid double lookup * fix normalize * refactor normalize * mypy Reviewed By: datumbox Differential Revision: D40138745 fbshipit-source-id: 4c28b1a4a8ebbeef6b47de66eb7c41dbcf1e5908
Fixes #6669. At least as far as it is possible. We can only unwrap the result of an inplace operation, i.e. what is returned by it. We cannot change the type inplace and thus we can't prevent the following:
As you can see
output
is properly unwrapped, but we can't do that forimage
although we have no idea if the operation "invalidated" the type.