Add topk min function for trace and onnx #5310
Conversation
💊 Dr. CI summary: as of commit 859aec2, there are no CI failures — looks good so far. (This comment was automatically generated by Dr. CI; please report bugs/suggestions to the internal Dr. CI Users group.)
@shubhambhokare1 Thanks for the contribution. TorchVision's detection models do not properly support tracing; instead we recommend that users JIT-script them. @fmassa @prabhat00155 Could you weigh in on this one, as you are more knowledgeable on ONNX? I also find it strange that, despite not properly supporting tracing, our RPN code has some kind of mitigation for it. Any thoughts on why this is the case?
Thanks @shubhambhokare1 for the PR, I have a couple of comments.
The PR is toasted, possibly due to a bad sync. Worth force-pushing to an earlier version, or closing it and opening a new one.
(force-pushed 45756ed → d64af49)
The user force pushed the branch and resolved the main issue.
Think it should be fixed now.
(force-pushed d64af49 → c96a819)
(force-pushed 3b11dfa → b02c5e6)
Thanks @shubhambhokare1!
Hey @prabhat00155! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
""" | ||
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0) | ||
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0)) | ||
return min_kval # type: ignore[arg-type] |
@prabhat00155 @shubhambhokare1 There might be performance implications here that we need to benchmark thoroughly. If we have such benchmarks we should post them in this PR to make them easily accessible.
If we don't have them, I propose reverting the PR until we can measure it. Another alternative is to use the ONNX-friendly approach only when we are tracing, and keep the previous, simpler estimation in all other cases.
Summary:
* Add topk minimizer function to _utils
* Apply ufmt formatting
* Apply min function for tracing and scripting
* Add type ignore to avoid cast
* Fix flake
* Fix python_type_check

Reviewed By: jdsgomes
Differential Revision: D34475307
fbshipit-source-id: 0dfa9d59a0fc85d43247ab0f4024deda9640609c
Co-authored-by: Prabhat Roy <[email protected]>
The ONNX spec requires the k-value to be less than or equal to the number of inputs along the provided dim. Certain models use the number of elements along a particular axis instead of k when k exceeds that axis's size. Previously, Python's min() function was used to determine whether to use the provided k-value or the size of the specified axis.
However, when the model is exported in tracing mode, Python's min() is evaluated statically, causing the model to be traced incorrectly and eventually fail at the topk node. To avoid this, torch.min() is used instead in tracing mode.