Preserve Datapoint subclasses instead of returning tensors#7807
NicolasHug wants to merge 12 commits into pytorch:main
```python
with pytest.raises(TypeError, match="unsupported operand type"):
    img + mask
```
Users want to do that? Fine, but then they'll need to explicitly say what type they want as output by converting one of the operands to a tensor. We shouldn't assume anything on their behalf and (surprisingly) return a pure tensor.
EDIT: as @pmeier pointed out offline, this is in fact the same behaviour as on main - nothing new
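A minimal sketch of this behaviour, using two hypothetical stand-in subclasses rather than the real torchvision `Image` and `Mask` datapoints (with torch's default `__torch_function__`, mixing two unrelated `Tensor` subclasses already raises `TypeError`, which is the behaviour the test above pins down):

```python
import torch

# Hypothetical stand-ins for two unrelated Datapoint subclasses.
class Image(torch.Tensor):
    pass

class Mask(torch.Tensor):
    pass

img = torch.rand(3, 4, 4).as_subclass(Image)
mask = torch.zeros(4, 4, dtype=torch.bool).as_subclass(Mask)

# Mixing two unrelated subclasses is ambiguous: neither subclass can
# claim the output type, so the operation raises TypeError.
try:
    img + mask
    raised = False
except TypeError:
    raised = True

# Converting one operand to a plain tensor makes the intent explicit:
# the remaining subclass wins and is preserved in the output.
out = img + mask.as_subclass(torch.Tensor)
```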
```diff
-output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
+output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(
+    in_boxes.as_subclass(torch.Tensor), format, top, left, height, width, size
+)
```
This (and similar changes below) was needed because `in_boxes` is now still a `BoundingBoxes` instance, and `resized_crop_bounding_boxes` expects a plain tensor (there is an error message along the lines of "if you pass a bounding box, don't pass the format").
```diff
 # Copy-paste masks:
 masks = masks * inverse_paste_alpha_mask
-non_all_zero_masks = masks.sum((-1, -2)) > 0
+non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor)
```
There were 2 other similar failures (below). The reason for the error is that `masks.sum((-1, -2)) > 0` is still a `Mask` object, and we can't use `Mask`s as indices (line below).
This is the only kind of instance that I identified as potentially weird / confusing. But the error message is good enough to figure out the fix.
(In contrast, unwrapping all the time is likely to cause a lot more surprises and forces users to re-wrap all the time).
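A hedged sketch of this pattern, with a hypothetical stand-in `Mask` subclass rather than the real torchvision datapoint. With a plain subclass like this the indexing itself would also work without unwrapping; the `.as_subclass(torch.Tensor)` call simply mirrors the fix applied in the diff above:

```python
import torch

# Stand-in for the torchvision Mask datapoint (hypothetical).
class Mask(torch.Tensor):
    pass

masks = torch.zeros(3, 4, 4).as_subclass(Mask)
masks[0, 1, 2] = 1.0  # only the first mask is non-empty

# The reduction and the comparison both preserve the subclass:
keep = masks.sum((-1, -2)) > 0
assert isinstance(keep, Mask)

# Unwrap to a plain tensor before using it as an index:
non_all_zero = keep.as_subclass(torch.Tensor)
selected = masks[non_all_zero]
```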
```python
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint


def test_operations():
```
This test is mostly there to illustrate the new behaviour. If we're OK with it, I'll refactor this test into something a little more polished.
Got superseded by #7825
This PR addresses the "subclass unwrapping" issue from #7319.

We now always preserve the Datapoint type when doing native operations like `img + 3` or `img + some_tensor`. This largely simplifies the `Datapoint` class implementation and avoids the potentially surprising "unwrapping" behaviour.

`BoundingBoxes` is the only class that needs special treatment, as it requires metadata, so it's the only class for which we override `__torch_function__`. Overall, the `Datapoint` logic is greatly simplified as it largely relies on the defaults from `torch.Tensor`.

Take a look at the newly-added `test_operations()` for an illustration of what is now possible.

Note: following #7807 (comment), the unwrapping / re-wrapping mechanism in our functionals is preserved for perf reasons only.