[ort_fusuion] Support fp16 in rms_norm fusion #2491
Conversation
Codecov Report
✅ All modified and coverable lines are covered by tests.

Coverage Diff (main vs. #2491):
            main     #2491    +/-
Coverage   70.08%   70.09%
Files         212      212
Lines       25646    25647     +1
Branches     2573     2573
Hits        17973    17976     +3
Misses       6783     6781     -2
Partials      890      890

View full report in Codecov by Sentry.
normalized = op.Mul(x, reciprocal_rms)
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
# To support float16, handle both the case where the scale is cast and the case where it is not.
scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale])
I don't think this should be to=compute_dtype ... it may be better to make it a different variable, say scale_cast_type. And check for correctness in the check condition below. It should basically be the target-dtype, since it is the type of the final output ... the compute-dtype could be something different.
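To make the suggestion concrete, here is a minimal, self-contained sketch of that check. The name `scale_cast_dtype` and the helper function are hypothetical, not the onnxscript rewriter API, and dtypes are shown as strings instead of TensorProto enums; the point is just that if `scale` is matched with a Cast, the cast target should equal the output (target) dtype rather than the compute dtype.

```python
def scale_cast_is_valid(scale_cast_dtype: str | None, target_dtype: str) -> bool:
    """Hypothetical check: validate the dtype that `scale` is cast to."""
    if scale_cast_dtype is None:
        # scale was matched without a Cast; nothing to validate.
        return True
    # When scale is cast, it should be cast to the output (target) dtype,
    # which may differ from the float32 compute dtype.
    return scale_cast_dtype == target_dtype


assert scale_cast_is_valid(None, "float16")
assert scale_cast_is_valid("float16", "float16")
assert not scale_cast_is_valid("float32", "float16")
```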
But we need to check the spec of SimplifiedLayerNorm in ORT to ensure that we are providing a scale value that has a consistent type.
But I am confused about the issue you ran into ... I think the fusion should have happened anyway ... are you trying to eliminate any redundant type cast or something?
Specifically: what are the different types in the failing case? What is the input type, computation type, and output type? Is the scale type float32? Is it being cast to fp16?
Link to ORT's op definition for reference: https://github.com/microsoft/onnxruntime/blob/cb0c5e9001cd3510ceb25173453373e4f1c7ab09/onnxruntime/core/graph/contrib_ops/contrib_defs.cc#L3079 ...
While reviewing the RMSNormalization pattern match, I realized I forgot to answer this. The issue I ran into is that the model explicitly converts the dtype to float32 for the calculations and then casts the result back to whatever the original input dtype was.
In RMSNorm there is a compute_type and a target_type: the computation runs in compute_type and the result is converted back to target_type after RMSNorm.
A typical example is the RMSNorm class used in LLMs, e.g. GPT-OSS: https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L54
Therefore, we need to take op.Cast into account in the pattern.
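For reference, a minimal sketch of that pattern, modeled on the typical Hugging Face RMSNorm implementation (simplified and not copied verbatim from the linked GPT-OSS file):

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm that upcasts to float32 for the math, then casts back."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype                 # e.g. float16 (target_type)
        hidden_states = hidden_states.to(torch.float32)   # compute_type
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the original dtype before applying the scale.
        return self.weight * hidden_states.to(input_dtype)
```

When a model like this is exported with float16 inputs, the graph contains a Cast(fp16→fp32) before the normalization math and a Cast(fp32→fp16) after it, which is why the pattern wraps those Cast nodes in pattern.OrValue so both the float32-only and the float16 round-trip shapes are matched.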