Skip to content

Commit 7feed37

Browse files
fix: compile symbolic affects after mtkcompile in complete
1 parent 4ff2c09 commit 7feed37

File tree

4 files changed

+170
-119
lines changed

4 files changed

+170
-119
lines changed

src/systems/abstractsystem.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,12 @@ function complete(
646646
if add_initial_parameters
647647
sys = add_initialization_parameters(sys; split)
648648
end
649+
if has_continuous_events(sys) && is_time_dependent(sys)
650+
@set! sys.continuous_events = complete.(get_continuous_events(sys); iv = get_iv(sys), alg_eqs = [alg_equations(sys); observed(sys)])
651+
end
652+
if has_discrete_events(sys) && is_time_dependent(sys)
653+
@set! sys.discrete_events = complete.(get_discrete_events(sys); iv = get_iv(sys), alg_eqs = [alg_equations(sys); observed(sys)])
654+
end
649655
end
650656
if split && has_index_cache(sys)
651657
@set! sys.index_cache = IndexCache(sys)

src/systems/callbacks.jl

Lines changed: 138 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@ function has_functional_affect(cb)
44
affects(cb) isa ImperativeAffect
55
end
66

7+
struct SymbolicAffect{A, K}
8+
affect::A
9+
kwargs::K
10+
end
11+
12+
SymbolicAffect(affect::Vector{Equation}; kwargs...) = SymbolicAffect(affect, kwargs)
13+
SymbolicAffect(affect::SymbolicAffect; kwargs...) = SymbolicAffect(affect.affect; affect.kwargs..., kwargs...)
14+
SymbolicAffect(affect; kwargs...) = affect
15+
716
struct AffectSystem
817
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
918
system::AbstractSystem
@@ -15,6 +24,71 @@ struct AffectSystem
1524
discretes::Vector
1625
end
1726

27+
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
28+
AffectSystem(spec.affect; iv = something(iv, get(spec.kwargs, :iv, nothing), Some(nothing)), alg_eqs = vcat(get(spec.kwargs, :alg_eqs, Equation[]), alg_eqs), kwargs...)
29+
end
30+
31+
function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
32+
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
33+
isempty(affect) && return nothing
34+
if isnothing(iv)
35+
iv = t_nounits
36+
@warn "No independent variable specified. Defaulting to t_nounits."
37+
end
38+
39+
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
40+
discrete_parameters = unwrap.(discrete_parameters)
41+
42+
for p in discrete_parameters
43+
occursin(unwrap(iv), unwrap(p)) ||
44+
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
45+
end
46+
47+
dvs = OrderedSet()
48+
params = OrderedSet()
49+
_varsbuf = Set()
50+
for eq in affect
51+
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
52+
symbolic_type(eq.lhs) === NotSymbolic())
53+
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
54+
end
55+
collect_vars!(dvs, params, eq, iv; op = Pre)
56+
empty!(_varsbuf)
57+
vars!(_varsbuf, eq; op = Pre)
58+
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
59+
union!(params, _varsbuf)
60+
diffvs = collect_applied_operators(eq, Differential)
61+
union!(dvs, diffvs)
62+
end
63+
for eq in alg_eqs
64+
collect_vars!(dvs, params, eq, iv)
65+
end
66+
pre_params = filter(haspre value, params)
67+
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
68+
discretes = map(tovar, discrete_parameters)
69+
dvs = collect(dvs)
70+
_dvs = map(default_toterm, dvs)
71+
72+
rev_map = Dict(zip(discrete_parameters, discretes))
73+
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
74+
affect = Symbolics.fast_substitute(affect, subs)
75+
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)
76+
77+
@named affectsys = System(
78+
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
79+
collect(union(pre_params, sys_params)); is_discrete = true)
80+
affectsys = mtkcompile(affectsys; fully_determined = nothing)
81+
# get accessed parameters p from Pre(p) in the callback parameters
82+
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
83+
union!(accessed_params, sys_params)
84+
85+
# add scalarized unknowns to the map.
86+
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
87+
88+
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
89+
collect(discrete_parameters))
90+
end
91+
1892
system(a::AffectSystem) = a.system
1993
discretes(a::AffectSystem) = a.discretes
2094
unknowns(a::AffectSystem) = a.unknowns
@@ -159,40 +233,40 @@ will run as soon as the solver starts, while finalization affects will be execut
159233
"""
160234
struct SymbolicContinuousCallback <: AbstractCallback
161235
conditions::Vector{Equation}
162-
affect::Union{Affect, Nothing}
163-
affect_neg::Union{Affect, Nothing}
164-
initialize::Union{Affect, Nothing}
165-
finalize::Union{Affect, Nothing}
236+
affect::Union{Affect, SymbolicAffect, Nothing}
237+
affect_neg::Union{Affect, SymbolicAffect, Nothing}
238+
initialize::Union{Affect, SymbolicAffect, Nothing}
239+
finalize::Union{Affect, SymbolicAffect, Nothing}
166240
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
167241
reinitializealg::SciMLBase.DAEInitializationAlgorithm
242+
end
168243

169-
function SymbolicContinuousCallback(
170-
conditions::Union{Equation, Vector{Equation}},
171-
affect = nothing;
172-
affect_neg = affect,
173-
initialize = nothing,
174-
finalize = nothing,
175-
rootfind = SciMLBase.LeftRootFind,
176-
reinitializealg = nothing,
177-
kwargs...)
178-
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
179-
180-
if isnothing(reinitializealg)
181-
if any(a -> a isa ImperativeAffect,
182-
[affect, affect_neg, initialize, finalize])
183-
reinitializealg = SciMLBase.CheckInit()
184-
else
185-
reinitializealg = SciMLBase.NoInit()
186-
end
244+
function SymbolicContinuousCallback(
245+
conditions::Union{Equation, Vector{Equation}},
246+
affect = nothing;
247+
affect_neg = affect,
248+
initialize = nothing,
249+
finalize = nothing,
250+
rootfind = SciMLBase.LeftRootFind,
251+
reinitializealg = nothing,
252+
kwargs...)
253+
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
254+
255+
if isnothing(reinitializealg)
256+
if any(a -> a isa ImperativeAffect,
257+
[affect, affect_neg, initialize, finalize])
258+
reinitializealg = SciMLBase.CheckInit()
259+
else
260+
reinitializealg = SciMLBase.NoInit()
187261
end
262+
end
188263

189-
new(conditions, make_affect(affect; kwargs...),
190-
make_affect(affect_neg; kwargs...),
191-
make_affect(initialize; kwargs...), make_affect(
192-
finalize; kwargs...),
193-
rootfind, reinitializealg)
194-
end # Default affect to nothing
195-
end
264+
SymbolicContinuousCallback(conditions, SymbolicAffect(affect; kwargs...),
265+
SymbolicAffect(affect_neg; kwargs...),
266+
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
267+
finalize; kwargs...),
268+
rootfind, reinitializealg)
269+
end # Default affect to nothing
196270

197271
function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
198272
SymbolicContinuousCallback(p[1], p[2], args...; kwargs...)
@@ -207,72 +281,18 @@ function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
207281
end
208282
end
209283

284+
function complete(cb::SymbolicContinuousCallback; kwargs...)
285+
SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...),
286+
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
287+
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg)
288+
end
289+
290+
make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...)
210291
make_affect(affect::Nothing; kwargs...) = nothing
211292
make_affect(affect::Tuple; kwargs...) = ImperativeAffect(affect...)
212293
make_affect(affect::NamedTuple; kwargs...) = ImperativeAffect(; affect...)
213294
make_affect(affect::Affect; kwargs...) = affect
214295

215-
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
216-
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
217-
isempty(affect) && return nothing
218-
if isnothing(iv)
219-
iv = t_nounits
220-
@warn "No independent variable specified. Defaulting to t_nounits."
221-
end
222-
223-
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
224-
discrete_parameters = unwrap.(discrete_parameters)
225-
226-
for p in discrete_parameters
227-
occursin(unwrap(iv), unwrap(p)) ||
228-
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
229-
end
230-
231-
dvs = OrderedSet()
232-
params = OrderedSet()
233-
_varsbuf = Set()
234-
for eq in affect
235-
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
236-
symbolic_type(eq.lhs) === NotSymbolic())
237-
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
238-
end
239-
collect_vars!(dvs, params, eq, iv; op = Pre)
240-
empty!(_varsbuf)
241-
vars!(_varsbuf, eq; op = Pre)
242-
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
243-
union!(params, _varsbuf)
244-
diffvs = collect_applied_operators(eq, Differential)
245-
union!(dvs, diffvs)
246-
end
247-
for eq in alg_eqs
248-
collect_vars!(dvs, params, eq, iv)
249-
end
250-
pre_params = filter(haspre value, params)
251-
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
252-
discretes = map(tovar, discrete_parameters)
253-
dvs = collect(dvs)
254-
_dvs = map(default_toterm, dvs)
255-
256-
rev_map = Dict(zip(discrete_parameters, discretes))
257-
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
258-
affect = Symbolics.fast_substitute(affect, subs)
259-
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)
260-
261-
@named affectsys = System(
262-
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
263-
collect(union(pre_params, sys_params)); is_discrete = true)
264-
affectsys = mtkcompile(affectsys; fully_determined = nothing)
265-
# get accessed parameters p from Pre(p) in the callback parameters
266-
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
267-
union!(accessed_params, sys_params)
268-
269-
# add scalarized unknowns to the map.
270-
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
271-
272-
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
273-
collect(discrete_parameters))
274-
end
275-
276296
function make_affect(affect; kwargs...)
277297
error("Malformed affect $(affect). This should be a vector of equations or a tuple specifying a functional affect.")
278298
end
@@ -374,30 +394,30 @@ Arguments:
374394
"""
375395
struct SymbolicDiscreteCallback <: AbstractCallback
376396
conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}}
377-
affect::Union{Affect, Nothing}
378-
initialize::Union{Affect, Nothing}
379-
finalize::Union{Affect, Nothing}
397+
affect::Union{Affect, SymbolicAffect, Nothing}
398+
initialize::Union{Affect, SymbolicAffect, Nothing}
399+
finalize::Union{Affect, SymbolicAffect, Nothing}
380400
reinitializealg::SciMLBase.DAEInitializationAlgorithm
401+
end
381402

382-
function SymbolicDiscreteCallback(
383-
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
384-
initialize = nothing, finalize = nothing,
385-
reinitializealg = nothing, kwargs...)
386-
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
387-
388-
if isnothing(reinitializealg)
389-
if any(a -> a isa ImperativeAffect,
390-
[affect, initialize, finalize])
391-
reinitializealg = SciMLBase.CheckInit()
392-
else
393-
reinitializealg = SciMLBase.NoInit()
394-
end
403+
function SymbolicDiscreteCallback(
404+
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
405+
initialize = nothing, finalize = nothing,
406+
reinitializealg = nothing, kwargs...)
407+
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
408+
409+
if isnothing(reinitializealg)
410+
if any(a -> a isa ImperativeAffect,
411+
[affect, initialize, finalize])
412+
reinitializealg = SciMLBase.CheckInit()
413+
else
414+
reinitializealg = SciMLBase.NoInit()
395415
end
396-
new(c, make_affect(affect; kwargs...),
397-
make_affect(initialize; kwargs...),
398-
make_affect(finalize; kwargs...), reinitializealg)
399-
end # Default affect to nothing
400-
end
416+
end
417+
SymbolicDiscreteCallback(c, SymbolicAffect(affect; kwargs...),
418+
SymbolicAffect(initialize; kwargs...),
419+
SymbolicAffect(finalize; kwargs...), reinitializealg)
420+
end # Default affect to nothing
401421

402422
function SymbolicDiscreteCallback(p::Pair, args...; kwargs...)
403423
SymbolicDiscreteCallback(p[1], p[2], args...; kwargs...)
@@ -412,6 +432,10 @@ function SymbolicDiscreteCallback(cb::Tuple, args...; kwargs...)
412432
end
413433
end
414434

435+
function complete(cb::SymbolicDiscreteCallback; kwargs...)
436+
SymbolicDiscreteCallback(cb.conditions, make_affect(cb.affect; kwargs...), make_affect(cb.initialize; kwargs...), make_affect(cb.finalize; kwargs...), cb.reinitializealg)
437+
end
438+
415439
function is_timed_condition(condition::T) where {T}
416440
if T === Num
417441
false
@@ -1060,12 +1084,8 @@ end
10601084
"""
10611085
Process the symbolic events of a system.
10621086
"""
1063-
function create_symbolic_events(cont_events, disc_events, sys_eqs, iv)
1064-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
1065-
sys_eqs)
1066-
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback,
1067-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1068-
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback,
1069-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1087+
function create_symbolic_events(cont_events, disc_events)
1088+
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback)
1089+
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback)
10701090
cont_callbacks, disc_callbacks
10711091
end

src/systems/system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
389389
end
390390
continuous_events,
391391
discrete_events = create_symbolic_events(
392-
continuous_events, discrete_events, eqs, iv)
392+
continuous_events, discrete_events)
393393

394394
if iv === nothing && (!isempty(continuous_events) || !isempty(discrete_events))
395395
throw(EventsInTimeIndependentSystemError(continuous_events, discrete_events))

test/symbolic_events.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,3 +1378,28 @@ end
13781378
@test SciMLBase.successful_retcode(sol)
13791379
@test sol[x, end]1.0 atol=1e-6
13801380
end
1381+
1382+
@testset "Symbolic affects are compiled in `complete`" begin
1383+
@parameters g
1384+
@variables x(t) [state_priority = 10.0] y(t) [guess = 1.0]
1385+
@variables λ(t) [guess = 1.0]
1386+
eqs = [D(D(x)) ~ λ * x
1387+
D(D(y)) ~ λ * y - g
1388+
x^2 + y^2 ~ 1]
1389+
cevts = [[x ~ 0.0] => [D(x) ~ Pre(D(x)) + 1sign(Pre(D(x)))]]
1390+
@named pend = System(eqs, t; continuous_events = cevts)
1391+
1392+
scc = only(continuous_events(pend))
1393+
@test scc.affect isa ModelingToolkit.SymbolicAffect
1394+
1395+
pend = mtkcompile(pend)
1396+
1397+
scc = only(continuous_events(pend))
1398+
@test scc.affect isa ModelingToolkit.AffectSystem
1399+
@test length(ModelingToolkit.all_equations(scc.affect)) == 5 # 1 affect, 3 algebraic, 1 observed
1400+
1401+
u0 = [x => -1/2, D(x) => 1/2, g => 1]
1402+
prob = ODEProblem(pend, u0, (0.0, 5.0))
1403+
sol = solve(prob, FBDF())
1404+
@test SciMLBase.successful_retcode(sol)
1405+
end

0 commit comments

Comments
 (0)