Skip to content

Commit 0e86f49

Browse files
Merge pull request #903 from avik-pal/ap/termination_returns
Use a mutable struct instead of Dict for Safe termination
2 parents 3b6e39b + 0e1ce2a commit 0e86f49

File tree

4 files changed

+40
-11
lines changed

4 files changed

+40
-11
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
*.jl.*.mem
55
Manifest.toml
66
.DS_Store
7+
.vscode

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.122.2"
4+
version = "6.123.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/DiffEqBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ export initialize!, finalize!
161161

162162
export SensitivityADPassThrough
163163

164-
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition
164+
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition,
165+
NLSolveSafeTerminationResult
165166

166167
export KeywordArgError, KeywordArgWarn, KeywordArgSilent
167168

src/termination_conditions.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ struct NLSolveSafeTerminationOptions{T1, T2, T3}
2828
min_max_factor::T3
2929
end
3030

31+
TruncatedStacktraces.@truncate_stacktrace NLSolveSafeTerminationOptions
32+
33+
Base.@kwdef mutable struct NLSolveSafeTerminationResult{T}
34+
best_objective_value::T = Inf64
35+
best_objective_value_iteration::Int = 0
36+
return_code::NLSolveSafeTerminationReturnCode.T = NLSolveSafeTerminationReturnCode.Failure
37+
end
38+
39+
# Remove once support for AbstractDict has been dropped
40+
function __setproperty!(n::NLSolveSafeTerminationResult, prop::Symbol, value)
41+
setproperty!(n, prop, value)
42+
end
43+
function __setproperty!(d::AbstractDict, prop::Symbol, value)
44+
d[prop] = value
45+
end
46+
3147
const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
3248
NLSolveTerminationMode.NLSolveDefault,
3349
NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel,
@@ -89,6 +105,8 @@ struct NLSolveTerminationCondition{mode, T,
89105
safe_termination_options::S
90106
end
91107

108+
TruncatedStacktraces.@truncate_stacktrace NLSolveTerminationCondition 1
109+
92110
function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode}
93111
print(io,
94112
"NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)")
@@ -116,8 +134,14 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
116134
return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options)
117135
end
118136

119-
function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing})
137+
function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict,
138+
NLSolveSafeTerminationResult,
139+
Nothing})
120140
mode = get_termination_mode(cond)
141+
if storage isa AbstractDict
142+
Base.depwarn("`storage` of type ($(typeof(storage)) <: AbstractDict) has been deprecated. Pass in a `NLSolveSafeTerminationResult` instance instead",
143+
:NLSolveTerminationCondition)
144+
end
121145
# We need both the dispatches to support solvers that don't use the integrator
122146
# interface like SimpleNonlinearSolve
123147
if mode in BASIC_TERMINATION_MODES
@@ -144,8 +168,8 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
144168
patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier
145169

146170
if mode SAFE_BEST_TERMINATION_MODES
147-
storage[:best_objective_value] = aType(Inf)
148-
storage[:best_objective_value_iteration] = 0
171+
__setproperty!(storage, :best_objective_value, aType(Inf))
172+
__setproperty!(storage, :best_objective_value_iteration, 0)
149173
end
150174

151175
if mode SAFE_BEST_TERMINATION_MODES
@@ -158,14 +182,15 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
158182

159183
if mode SAFE_BEST_TERMINATION_MODES
160184
if objective < storage[:best_objective_value]
161-
storage[:best_objective_value] = objective
162-
storage[:best_objective_value_iteration] = nstep + 1
185+
__setproperty!(storage, :best_objective_value, objective)
186+
__setproperty!(storage, :best_objective_value_iteration, nstep + 1)
163187
end
164188
end
165189

166190
# Main Termination Criteria
167191
if objective <= criteria
168-
storage[:return_code] = NLSolveSafeTerminationReturnCode.Success
192+
__setproperty!(storage, :return_code,
193+
NLSolveSafeTerminationReturnCode.Success)
169194
return true
170195
end
171196

@@ -181,19 +206,21 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
181206
if maximum(last_k_values) <
182207
typeof(criteria)(cond.safe_termination_options.min_max_factor) *
183208
minimum(last_k_values)
184-
storage[:return_code] = NLSolveSafeTerminationReturnCode.PatienceTermination
209+
__setproperty!(storage, :return_code,
210+
NLSolveSafeTerminationReturnCode.PatienceTermination)
185211
return true
186212
end
187213
end
188214
end
189215

190216
# Protective break
191217
if objective >= objective_values[1] * protective_threshold * length(du)
192-
storage[:return_code] = NLSolveSafeTerminationReturnCode.ProtectiveTermination
218+
__setproperty!(storage, :return_code,
219+
NLSolveSafeTerminationReturnCode.ProtectiveTermination)
193220
return true
194221
end
195222

196-
storage[:return_code] = NLSolveSafeTerminationReturnCode.Failure
223+
__setproperty!(storage, :return_code, NLSolveSafeTerminationReturnCode.Failure)
197224
return false
198225
end
199226
return _termination_condition_closure_safe

0 commit comments

Comments
 (0)