
Add scale option to ToDtype. Remove ConvertDtype. #7759

Merged
NicolasHug merged 14 commits into pytorch:main from NicolasHug:dtypeeeeeeeeeeeee
Jul 27, 2023

Conversation

@NicolasHug
Member

@NicolasHug NicolasHug commented Jul 25, 2023

Closes #7756

this PR:

  • adds a scale parameter to ToDtype, which only affects images or videos
  • removes ConvertDtype and keeps ConvertImageDtype. ConvertImageDtype now only supports images, not videos.
  • removes all dispatchers / kernels associated with convert_.* and replaces them with to_dtype* dispatchers / kernels.
  • When passing ToDtype(torch.float32), i.e. when the dtype parameter is a plain dtype rather than a dict, we only convert images and videos. This preserves BC with ConvertImageDtype and reduces adoption friction.
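The value-range mapping that scale=True implies can be sketched in plain Python for a single uint8 pixel (a hedged illustration only; to_float_pixel is a made-up helper, not a torchvision function):

```python
# Illustrative sketch: scale=True maps an integer pixel to [0, 1] by
# dividing by the integer dtype's maximum (255 for uint8); scale=False
# is a plain cast with no value rescaling. Hypothetical helper.
def to_float_pixel(value: int, scale: bool, max_value: int = 255) -> float:
    if scale:
        return value / max_value
    return float(value)

print(to_float_pixel(255, scale=True))   # 1.0
print(to_float_pixel(255, scale=False))  # 255.0
```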

(I may have removed too much stuff in the existing tests; we'll see with the CI)

cc @vfdev-5

@pytorch-bot

pytorch-bot bot commented Jul 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7759

Note: Links to docs will display an error until the docs builds have been completed.

❌ 23 New Failures

As of commit dd903a8:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@pmeier pmeier left a comment


I have a bunch of comments, but overall looks solid.


-def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+    if not scale:
Contributor


Should this come after the dtype check? Functionally, this is a no-op here in case image.dtype == dtype, but it is still an unnecessary call.

if image.dtype == dtype:
    return image
elif not scale:
    return image.to(dtype)

makes the behavior a little more clear and should be minimally more performant.

Member Author


I'll do it, but TBH I don't really follow the reasoning behind the change.

Member Author


The elif part actually adds confusion for me, because it suggests these 2 blocks are related when in reality they're really not.

Contributor


Re-reading the original implementation: I was wrong. Since we already return'ed in both branches, this should have no effect on the performance. Thus, feel free to revert to what you had.

Still, I would put the dtype check at the top, because that is the "more important" one. Again, just style / personal preference. Ignore if you feel different.

The elif part actually adds confusion to me because it suggests these 2 blocks are related when in reality they're really not

Fair enough. Happy with a second if as well.
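The structure the thread converges on (dtype check first, then an independent second if rather than an elif) can be sketched without torch, using a stand-in image object; all names below are illustrative, not the merged torchvision code:

```python
# Stand-in for a tensor: only tracks a dtype string. Not torch.
class FakeImage:
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        # plain cast, no value rescaling
        return FakeImage(dtype)


def to_dtype_sketch(image, dtype, scale=False):
    # dtype check first: nothing to do if the dtype already matches
    if image.dtype == dtype:
        return image
    # a second, independent if (not elif), as agreed in the thread
    if not scale:
        return image.to(dtype)
    raise NotImplementedError("value-range rescaling omitted from this sketch")
```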

check_cuda_vs_cpu=True,
check_scripted_vs_eager=True,
check_batched_vs_unbatched=True,
expect_same_dtype=True,
Contributor


Not against it, but are we expecting more kernels to set this to False?

Member Author


I don't think so (more precisely: I don't know). We still want the rest of the checks to be done for ToDtype though. Is there a better way than to add a parameter?

Contributor


Not really. Since this will be the only kernel that ever needs this, we could implement a custom check_kernel inside the test class: basically copy-paste the current function, add the new parameter, and strip everything that is more generic but not needed in this specific case. This would keep the check_kernel signature clean, since it already has quite a few parameters.

Kinda torn on this. Up to you. I'm ok with both.

Member Author


I have a slight preference for adding the parameter, because otherwise we'd have to change both implementations of check_kernel if we ever needed to update it.
(looks like I'm advocating for this single entry point after all :p )

def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]], scale: bool = False) -> None:
    super().__init__()
    if not isinstance(dtype, dict):
        dtype = _get_defaultdict(dtype)
Contributor


By removing this behaviour, do we want to move away from defaultdict(lambda: dtype)?

Member Author

@NicolasHug NicolasHug Jul 26, 2023


Yes, this is one of the goals #7756 (comment). It basically just gets replaced by {"others": dtype}:

  • no need to know what a defaultdict is
  • no need to import defaultdict
  • no need to know what a lambda is
  • no need to understand how defaultdict and lambda interact
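The replacement can be illustrated with a small lookup sketch. The resolve helper below is a hypothetical rendering of the "others" fallback, not the actual torchvision code:

```python
from collections import defaultdict

# Old pattern: a defaultdict whose factory returns the fallback dtype.
old_spec = defaultdict(lambda: "float32", {"BoundingBoxes": None})

# New pattern from the thread: a plain dict with an "others" key.
new_spec = {"BoundingBoxes": None, "others": "float32"}


def resolve(spec, key):
    # Hypothetical resolution of the "others" fallback.
    return spec[key] if key in spec else spec.get("others")


# Both specs answer the same queries:
assert old_spec["Image"] == resolve(new_spec, "Image") == "float32"
assert old_spec["BoundingBoxes"] is resolve(new_spec, "BoundingBoxes") is None
```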

Contributor


But shouldn't this be done in a separate PR, and everywhere the pattern is used?

Member Author

@NicolasHug NicolasHug Jul 26, 2023


If this gets accepted here I was going to submit a follow-up PR to do the same changes for the fill parameters (which is probably much more work). Are there other places where this pattern is used?

Contributor


OK, sounds good to me.

PermuteDimensions, TransposeDimensions and _setup_fill_arg

@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
@pytest.mark.parametrize("as_dict", (True, False))
Member Author


I added this because when the dtype parameter isn't a dict and the input is a single bbox or a single mask (as is the case here), the input just gets passed through. I'm converting the dtype into a dict so that we also test the rest of the code paths for those single bboxes and masks. Not incredibly critical, just for coverage.
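What the as_dict parametrization does can be sketched as follows; make_dtype_arg is a hypothetical test helper, not the actual test code:

```python
# Hypothetical helper mirroring the as_dict parametrization: wrap a
# plain dtype into a per-type dict so single-bbox/mask inputs hit the
# dict code path instead of being passed through.
def make_dtype_arg(output_dtype, as_dict, input_type):
    return {input_type: output_dtype} if as_dict else output_dtype


assert make_dtype_arg("float32", True, "BoundingBoxes") == {"BoundingBoxes": "float32"}
assert make_dtype_arg("float32", False, "BoundingBoxes") == "float32"
```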

Contributor

@pmeier pmeier left a comment


Some comments, but nothing blocking. LGTM if CI is green. Thanks Nicolas!


return inpt.max() <= 1

H, W = 10, 10
sample = {
Contributor


Should we also throw a video in there?

Member Author


We do (from the parametrization), although I kept img as a name, which I agree is confusing. I'll use inpt instead.

Contributor


Thanks, I missed that.

make_video,
),
)
def test_behaviour(self, make_input):
Contributor


This test is huge. Should we maybe split it into multiple ones?

NicolasHug and others added 2 commits July 27, 2023 10:14
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
@NicolasHug
Member Author

Tests are green in 92f2588, except for the linter, which I just fixed. Merging, thanks for the reviews!

@NicolasHug NicolasHug merged commit 1402eb8 into pytorch:main Jul 27, 2023
@github-actions

Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Aug 25, 2023
Reviewed By: matteobettini

Differential Revision: D48642282

fbshipit-source-id: 95a2eea16407f17e1ebeb386cd5e2618a105450f

Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>


Development

Successfully merging this pull request may close these issues.

Dtype and scale conversion in V2

4 participants