
Conversation

@titaiwangms
Contributor

In RMSNorm there are a compute_type and a target_type: the computation is run in compute_type and the result is cast back to target_type after the RMSNorm.

A typical example is the RMSNorm class in LLMs, such as in GPT-OSS: https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L54

Therefore, the fusion pattern needs to take op.Cast into account.
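
For reference, a minimal sketch of the pattern (loosely modeled on the transformers RMSNorm implementations; not the exact GPT-OSS source) looks like this:

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = hidden_states.dtype                # target_type (e.g. float16)
        hidden_states = hidden_states.to(torch.float32)   # compute_type
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the original dtype before applying the scale
        return self.weight * hidden_states.to(target_dtype)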

@codecov

codecov bot commented Aug 14, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.09%. Comparing base (7407431) to head (2180b0f).
⚠️ Report is 5 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2491   +/-   ##
=======================================
  Coverage   70.08%   70.09%           
=======================================
  Files         212      212           
  Lines       25646    25647    +1     
  Branches     2573     2573           
=======================================
+ Hits        17973    17976    +3     
+ Misses       6783     6781    -2     
  Partials      890      890           


titaiwangms enabled auto-merge (squash) August 14, 2025 23:09
titaiwangms merged commit 700bb1a into microsoft:main Aug 15, 2025
25 of 32 checks passed
normalized = op.Mul(x, reciprocal_rms)
# Match the normalized value with or without a trailing Cast back to the target dtype.
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
# To support float16, allow the scale to appear either cast or uncast.
scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale])
Collaborator
I don't think this should be to=compute_dtype ... it may be better to make it a different variable, say scale_cast_type. And check for correctness in the check condition below. It should basically be the target-dtype, since it is the type of the final output ... the compute-dtype could be something different.
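
Roughly the check I have in mind, as a sketch (the names scale_cast_type and check_scale_cast are illustrative, not the fusion rule's actual bindings or API):

def check_scale_cast(scale_was_cast: bool, scale_cast_type, target_dtype) -> bool:
    # If the pattern matched the op.Cast(scale, ...) branch, the cast should
    # produce the dtype of the final output (the target dtype), not necessarily
    # the compute dtype.
    if not scale_was_cast:
        return True
    return scale_cast_type == target_dtype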

Collaborator

But we need to check the spec of SimplifiedLayerNorm in ORT to ensure that we are providing a scale value that has a consistent type.

But I am confused about the issue you ran into ... I think the fusion should have happened anyway ... are you trying to eliminate any redundant type cast or something?

Collaborator

Specifically: what are the different types in the failing case? What is the input type, computation type, and output type? Is the scale type float32? Is it being cast to fp16?

Contributor Author

While reviewing the RMSNormalization pattern match, I realized I forgot to answer this. I ran into an issue where the model specifically turns the dtype to float32 for the computation and then casts it back to whatever its original input dtype was.

https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L54

