Skip to content

Add topk min function for trace and onnx #5310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 23, 2022

Conversation

shubhambhokare1
Copy link
Contributor

@shubhambhokare1 shubhambhokare1 commented Jan 28, 2022

ONNX spec requires the k-value to be less than or equal to the number of inputs along provided dim. Certain models use the number of elements along a particular axis instead of K if K exceeds the number of elements along that axis. Previously, python's min() function was used to determine whether to use the provided k-value or the specified dim axis value.

However in cases where the model is being exported in tracing mode, python min() is static causing the model to be traced incorrectly and eventually fail at the topk node. In order to avoid this situation, in tracing mode, torch.min() is used instead.

@facebook-github-bot
Copy link

facebook-github-bot commented Jan 28, 2022

💊 CI failures summary and remediations

As of commit 859aec2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@datumbox
Copy link
Contributor

@shubhambhokare1 Thanks for the contribution. TorchVision's detection models do not properly support tracing. Instead we recommend users to JIT-script them.

@fmassa @prabhat00155 Could you weight in on this one as you are more knowledgeable on ONNX? I also find it strange that despite not properly supporting tracing, our RPN code has some kind of mitigation about it. Any thoughts on why this is the case?

@prabhat00155
Copy link
Contributor

Thanks @shubhambhokare1 for the PR, I have a couple of comments.

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR is toasted, possibly due to a bad synch. Worth force pushing to an earlier version or closing and opening an new one.

@datumbox datumbox dismissed their stale review February 16, 2022 19:39

The user force pushed the branch and resolved the main issue.

@shubhambhokare1
Copy link
Contributor Author

The PR is toasted, possibly due to a bad synch. Worth force pushing to an earlier version or closing and opening an new one.

Think it should be fixed now

Copy link
Contributor

@prabhat00155 prabhat00155 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prabhat00155 prabhat00155 merged commit 86a14cb into pytorch:main Feb 23, 2022
@github-actions
Copy link

Hey @prabhat00155!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

"""
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
return min_kval # type: ignore[arg-type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prabhat00155 @shubhambhokare1 There might be performance implications here that we need to benchmark thoroughly. If we have such benchmarks we should post them in this PR to make them easily accessible.

If we don't I propose either to revert the PR until we measure it. Another alternative is to do the onnx friendly approach only when we are tracing and do the previous simpler estimation in all other cases.

facebook-github-bot pushed a commit that referenced this pull request Feb 25, 2022
Summary:
* Add topk minimizer function to _utils

* Apply ufmt formatting

* Apply min function for tracing and scripting

* Add type ignore to avoid cast

* fix flake

* Fix python_type_check

Reviewed By: jdsgomes

Differential Revision: D34475307

fbshipit-source-id: 0dfa9d59a0fc85d43247ab0f4024deda9640609c

Co-authored-by: Prabhat Roy <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants