Skip to content

Mooncake overlay for erroring out before Mooncake.DerivedRule construction #1169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2025

Conversation

AstitvaAggarwal
Copy link
Contributor

@AstitvaAggarwal AstitvaAggarwal commented Jul 9, 2025

Closes SciML/SciMLSensitivity.jl#1230 and chalk-lab/Mooncake.jl#587.

This PR handles cases where:
DiffEqBase._concrete_solve_adjoint must error out when using ReverseDiffAdjoint/TrackerAdjoint while differentiating via Mooncake. This error was already handled by SciMLSensitivity but was not getting hit.
Therefore calling Mooncake.@mooncake_overlay for DiffEqBase.set_mooncakeoriginator_if_mooncake is required otherwise the Mooncake.DerivedRule which contains primal typechecks fails: as Tracker for example adds tags such as Tracker.TrackerReal{Float64} around Float64's to the forward pass primals.

The previous PR handled:
any other case (eg: not using ReverseDiffAdjoint/TrackerAdjoint) when it is required to use a Mooncake.rrule!! for DiffEqBase.set_mooncakeoriginator_if_mooncake in a Mooncake.DerivedRule.

julia> using OrdinaryDiffEq, SciMLSensitivity, Mooncake, Test
       mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]

       odef(du, u, p, t) = du .= u .* p
       const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])

       struct senseloss{T}
           sense::T
       end
       function (f::senseloss)(u0p)
           sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12,
               reltol = 1e-12, saveat = 0.1, sensealg = f.sense))
       end
       function loss(u0p)
           sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12, reltol = 1e-12,
               saveat = 0.1))
       end
       u0p = [2.0, 3.0]
WARNING: redefinition of constant Main.prob. This may fail, cause incorrect answers, or produce other errors.

julia> @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
Test Passed
      Thrown: SciMLSensitivity.MooncakeTrackedRealError

@ChrisRackauckas ChrisRackauckas merged commit 367b691 into SciML:master Jul 9, 2025
38 of 47 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Mooncake gives the wrong aggregator and thus does not give contextualized error messages
2 participants