Skip to content

Use a mutable struct instead of Dict for Safe termination #903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
*.jl.*.mem
Manifest.toml
.DS_Store
.vscode
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <[email protected]>"]
version = "6.122.2"
version = "6.123.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
3 changes: 2 additions & 1 deletion src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ export initialize!, finalize!

export SensitivityADPassThrough

export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition,
NLSolveSafeTerminationResult

export KeywordArgError, KeywordArgWarn, KeywordArgSilent

Expand Down
45 changes: 36 additions & 9 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ struct NLSolveSafeTerminationOptions{T1, T2, T3}
min_max_factor::T3
end

TruncatedStacktraces.@truncate_stacktrace NLSolveSafeTerminationOptions

Base.@kwdef mutable struct NLSolveSafeTerminationResult{T}
best_objective_value::T = Inf64
best_objective_value_iteration::Int = 0
return_code::NLSolveSafeTerminationReturnCode.T = NLSolveSafeTerminationReturnCode.Failure
end

# Remove once support for AbstractDict has been dropped
function __setproperty!(n::NLSolveSafeTerminationResult, prop::Symbol, value)
setproperty!(n, prop, value)
end
function __setproperty!(d::AbstractDict, prop::Symbol, value)
d[prop] = value
end

const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
NLSolveTerminationMode.NLSolveDefault,
NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel,
Expand Down Expand Up @@ -89,6 +105,8 @@ struct NLSolveTerminationCondition{mode, T,
safe_termination_options::S
end

TruncatedStacktraces.@truncate_stacktrace NLSolveTerminationCondition 1

function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode}
print(io,
"NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)")
Expand Down Expand Up @@ -116,8 +134,14 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options)
end

function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing})
function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict,
NLSolveSafeTerminationResult,
Nothing})
mode = get_termination_mode(cond)
if storage isa AbstractDict
Base.depwarn("`storage` of type ($(typeof(storage)) <: AbstractDict) has been deprecated. Pass in a `NLSolveSafeTerminationResult` instance instead",
:NLSolveTerminationCondition)
end
# We need both the dispatches to support solvers that don't use the integrator
# interface like SimpleNonlinearSolve
if mode in BASIC_TERMINATION_MODES
Expand All @@ -144,8 +168,8 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier

if mode ∈ SAFE_BEST_TERMINATION_MODES
storage[:best_objective_value] = aType(Inf)
storage[:best_objective_value_iteration] = 0
__setproperty!(storage, :best_objective_value, aType(Inf))
__setproperty!(storage, :best_objective_value_iteration, 0)
end

if mode ∈ SAFE_BEST_TERMINATION_MODES
Expand All @@ -158,14 +182,15 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth

if mode ∈ SAFE_BEST_TERMINATION_MODES
if objective < storage[:best_objective_value]
storage[:best_objective_value] = objective
storage[:best_objective_value_iteration] = nstep + 1
__setproperty!(storage, :best_objective_value, objective)
__setproperty!(storage, :best_objective_value_iteration, nstep + 1)
end
end

# Main Termination Criteria
if objective <= criteria
storage[:return_code] = NLSolveSafeTerminationReturnCode.Success
__setproperty!(storage, :return_code,
NLSolveSafeTerminationReturnCode.Success)
return true
end

Expand All @@ -181,19 +206,21 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth
if maximum(last_k_values) <
typeof(criteria)(cond.safe_termination_options.min_max_factor) *
minimum(last_k_values)
storage[:return_code] = NLSolveSafeTerminationReturnCode.PatienceTermination
__setproperty!(storage, :return_code,
NLSolveSafeTerminationReturnCode.PatienceTermination)
return true
end
end
end

# Protective break
if objective >= objective_values[1] * protective_threshold * length(du)
storage[:return_code] = NLSolveSafeTerminationReturnCode.ProtectiveTermination
__setproperty!(storage, :return_code,
NLSolveSafeTerminationReturnCode.ProtectiveTermination)
return true
end

storage[:return_code] = NLSolveSafeTerminationReturnCode.Failure
__setproperty!(storage, :return_code, NLSolveSafeTerminationReturnCode.Failure)
return false
end
return _termination_condition_closure_safe
Expand Down