diff --git a/ext/DiffEqBaseChainRulesCoreExt.jl b/ext/DiffEqBaseChainRulesCoreExt.jl index 2ea7a4cf2..bfcfe8313 100644 --- a/ext/DiffEqBaseChainRulesCoreExt.jl +++ b/ext/DiffEqBaseChainRulesCoreExt.jl @@ -13,21 +13,21 @@ ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle) function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), kwargs...) DiffEqBase._solve_forward( prob, sensealg, u0, p, - set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; + originator, args...; kwargs...) end function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), kwargs...) DiffEqBase._solve_adjoint( prob, sensealg, u0, p, - set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...; + originator, args...; kwargs...) end diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index ad000e62e..29aff5271 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -2,8 +2,12 @@ module DiffEqBaseMooncakeExt using DiffEqBase, Mooncake using DiffEqBase: SciMLBase -using SciMLBase: ADOriginator, MooncakeOriginator -Mooncake.@from_rrule(Mooncake.MinimalCtx, +using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator +import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, + @from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx, + NoPullback + +@from_rrule(MinimalCtx, Tuple{ typeof(DiffEqBase.solve_up), DiffEqBase.AbstractDEProblem, @@ -15,7 +19,7 @@ Mooncake.@from_rrule(Mooncake.MinimalCtx, true,) # Dispatch for auto-alg -Mooncake.@from_rrule(Mooncake.MinimalCtx, +@from_rrule(MinimalCtx, Tuple{ typeof(DiffEqBase.solve_up), DiffEqBase.AbstractDEProblem, @@ -25,7 +29,16 @@ Mooncake.@from_rrule(Mooncake.MinimalCtx, }, true,) -Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} -Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator +@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} +@is_primitive MinimalCtx Tuple{ + typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator +} + +function rrule!!( + f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)}, + X::CoDual{SciMLBase.ChainRulesOriginator} +) + return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X) +end end diff --git a/src/solve.jl b/src/solve.jl index de102cc3b..633e0246d 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1174,14 +1174,29 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing, p = p !== nothing ? p : prob.p if wrap isa Val{true} - wrap_sol(solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...)) + wrap_sol(solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), + kwargs...)) else - solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...) + solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), + kwargs...) end end function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, - args...; kwargs...) + args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,