Introduce @skip_rewrite_func, @skip_rewrite_type to extend should_rewrite_call #1377

Merged
merged 11 commits into from
Jun 13, 2025

Conversation

mofeing
Collaborator

@mofeing mofeing commented Jun 6, 2025

Performance comparison

I'm using Muscle.binary_einsum to compare compile times. It is semantically equivalent to stablehlo.dot_general. What I found is that the compilation time disappears the second time a function is traced; notably, I get a 5-30x compile-time speedup on the second trace. The recompilation that disappears is most probably caused by Muscle.binary_einsum being a type-unstable function: AFAIK that makes it a dynamic call, which Reactant always rewrites. If so, this PR successfully avoids the rewrite in those cases.

Unfortunately, the first-trace compile time doesn't get lower, so the bottleneck is probably elsewhere.

julia> a = Tensor(ConcreteRArray(rand(3, 4)), [Index(:i), Index(:j)])
3×4 Tensor{Float64, 2, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}:
 0.00641709  0.247586  0.376454  0.75816
 0.896955    0.577077  0.947384  0.97931
 0.60413     0.567416  0.804769  0.489075

julia> b = Tensor(ConcreteRArray(rand(3, 4)), [Index(:i), Index(:j)])
3×4 Tensor{Float64, 2, ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}:
 0.196665   0.198938  0.421027  0.0856663
 0.0510834  0.430934  0.264092  0.625037
 0.488447   0.624371  0.28303   0.493439

julia> g(x,y) = binary_einsum(x,y)
g (generic function with 1 method)

Before

julia> @time @code_hlo optimize=false g(a, b);
 14.256694 seconds (52.07 M allocations: 2.618 GiB, 5.30% gc time, 99.76% compilation time: <1% of which was recompilation)

julia> @time @code_hlo optimize=false g(a, b);
  0.093615 seconds (5.86 k allocations: 462.086 KiB, 95.72% compilation time)

After

julia> @time @code_hlo optimize=false g(a, b);
 13.495342 seconds (52.50 M allocations: 2.642 GiB, 6.19% gc time, 99.80% compilation time: <1% of which was recompilation)

julia> @time @code_hlo optimize=false g(a, b);
  0.003218 seconds (3.22 k allocations: 359.414 KiB)

Open questions

  • If the top-level function of the traced call (e.g. f in @compile f(x)) is marked, the mark is not detected and the function is rewritten anyway. Before starting the rewrite, we should check whether the top-level function must be rewritten. Here is the proof that it's not working:
julia> @time @code_hlo optimize=false binary_einsum(a, b);
  1.452055 seconds (2.54 M allocations: 133.979 MiB, 1.04% gc time, 99.55% compilation time)

julia> @time @code_hlo optimize=false binary_einsum(a, b);
  0.020816 seconds (3.66 k allocations: 377.211 KiB, 84.59% compilation time)
  • Is there any way to mark only certain methods to skip the rewrite (rather than the whole function)? i.e. extend these changes to should_rewrite_invoke.
    • The only idea that comes to mind for implementing this is a separate Core.MethodTable that is used only to look up the methods that skip rewriting.
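
For readers unfamiliar with the approach, here is a minimal, self-contained sketch of opting a function or a callable type out of rewriting through dispatch on its type. It is an illustration only: `should_rewrite`, `kernel`, and `Contraction` are hypothetical names, and this is not Reactant's actual `should_rewrite_call` implementation nor the code generated by `@skip_rewrite_func` / `@skip_rewrite_type`.

```julia
# Hypothetical predicate: by default, every call is a candidate for rewriting.
# (Stand-in for a tracer hook; not Reactant's real definition.)
should_rewrite(@nospecialize(ft::Type)) = true

# A plain function we want the tracer to leave untouched (hypothetical).
kernel(x, y) = x * y

# Opt out a single function by adding a method for its singleton type.
should_rewrite(::Type{typeof(kernel)}) = false

# A callable struct (hypothetical); opt out every instance via its type.
struct Contraction
    scale::Float64
end
(c::Contraction)(x) = c.scale * x
should_rewrite(::Type{<:Contraction}) = false

# The tracer would query the predicate with the callee's type.
for f in (sum, kernel, Contraction(2.0))
    println(typeof(f), " => ", should_rewrite(typeof(f)))
end
```

Because the opt-out is keyed on the callee's type, it covers the whole function (all of its methods), which is why a per-method opt-out would need a different lookup mechanism.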

@mofeing mofeing requested a review from wsmoses June 6, 2025 17:41
Member

@wsmoses wsmoses left a comment

can you add a test to confirm?

mofeing

This comment was marked as outdated.

@mofeing mofeing marked this pull request as draft June 7, 2025 22:39
@mofeing mofeing changed the title Introduce @skip_rewrite to extend should_rewrite_call Introduce @skip_rewrite_func, @skip_rewrite_type to extend should_rewrite_call Jun 8, 2025
@mofeing
Collaborator Author

mofeing commented Jun 8, 2025

can you add a test to confirm?

done, i'm not sure how to test @skip_rewrite_type but @skip_rewrite_func is tested now

@mofeing mofeing requested a review from wsmoses June 8, 2025 18:20
@mofeing mofeing marked this pull request as ready for review June 9, 2025 08:39
@mofeing
Collaborator Author

mofeing commented Jun 13, 2025

Failed tests seem unrelated to the PR.

@mofeing mofeing merged commit a65da93 into main Jun 13, 2025
53 of 56 checks passed
@mofeing mofeing deleted the ss/mark-skip-rewrite-call branch June 13, 2025 12:27