Skip to content

Commit 4b06339

Browse files
Merge pull request #1167 from AstitvaAggarwal/develop
fix MooncakeOriginator
2 parents 134b63a + 131f51f commit 4b06339

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

ext/DiffEqBaseChainRulesCoreExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,21 @@ ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle)
1313

1414
function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
1515
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
16-
u0, p, args...;
16+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
1717
kwargs...)
1818
DiffEqBase._solve_forward(
1919
prob, sensealg, u0, p,
20-
set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
20+
originator, args...;
2121
kwargs...)
2222
end
2323

2424
function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem,
2525
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
26-
u0, p, args...;
26+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
2727
kwargs...)
2828
DiffEqBase._solve_adjoint(
2929
prob, sensealg, u0, p,
30-
set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
30+
originator, args...;
3131
kwargs...)
3232
end
3333

ext/DiffEqBaseMooncakeExt.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ module DiffEqBaseMooncakeExt
22

33
using DiffEqBase, Mooncake
44
using DiffEqBase: SciMLBase
5-
using SciMLBase: ADOriginator, MooncakeOriginator
6-
Mooncake.@from_rrule(Mooncake.MinimalCtx,
5+
using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator
6+
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
7+
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
8+
NoPullback
9+
10+
@from_rrule(MinimalCtx,
711
Tuple{
812
typeof(DiffEqBase.solve_up),
913
DiffEqBase.AbstractDEProblem,
@@ -15,7 +19,7 @@ Mooncake.@from_rrule(Mooncake.MinimalCtx,
1519
true,)
1620

1721
# Dispatch for auto-alg
18-
Mooncake.@from_rrule(Mooncake.MinimalCtx,
22+
@from_rrule(MinimalCtx,
1923
Tuple{
2024
typeof(DiffEqBase.solve_up),
2125
DiffEqBase.AbstractDEProblem,
@@ -25,7 +29,16 @@ Mooncake.@from_rrule(Mooncake.MinimalCtx,
2529
},
2630
true,)
2731

28-
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
29-
Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator
32+
@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
33+
@is_primitive MinimalCtx Tuple{
34+
typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator
35+
}
36+
37+
function rrule!!(
38+
f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)},
39+
X::CoDual{SciMLBase.ChainRulesOriginator}
40+
)
41+
return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X)
42+
end
3043

3144
end

src/solve.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,14 +1174,29 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
11741174
p = p !== nothing ? p : prob.p
11751175

11761176
if wrap isa Val{true}
1177-
wrap_sol(solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...))
1177+
wrap_sol(solve_up(prob,
1178+
sensealg,
1179+
u0,
1180+
p,
1181+
args...;
1182+
alias_u0 = alias_u0,
1183+
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
1184+
kwargs...))
11781185
else
1179-
solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...)
1186+
solve_up(prob,
1187+
sensealg,
1188+
u0,
1189+
p,
1190+
args...;
1191+
alias_u0 = alias_u0,
1192+
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
1193+
kwargs...)
11801194
end
11811195
end
11821196

11831197
function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
1184-
args...; kwargs...)
1198+
args...; originator = SciMLBase.ChainRulesOriginator(),
1199+
kwargs...)
11851200
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
11861201
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
11871202
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,

0 commit comments

Comments
 (0)