diff --git a/.gitignore b/.gitignore index 0355d91aa..fe8149f5c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ *.jl.*.mem Manifest.toml .DS_Store +.vscode diff --git a/Project.toml b/Project.toml index c7ffc524c..58dd3077c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.122.2" +version = "6.123.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 3ea3c5667..e7d581dfe 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -161,7 +161,8 @@ export initialize!, finalize! export SensitivityADPassThrough -export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition +export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition, + NLSolveSafeTerminationResult export KeywordArgError, KeywordArgWarn, KeywordArgSilent diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 8e22f1a14..48ed234f0 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -28,6 +28,22 @@ struct NLSolveSafeTerminationOptions{T1, T2, T3} min_max_factor::T3 end +TruncatedStacktraces.@truncate_stacktrace NLSolveSafeTerminationOptions + +Base.@kwdef mutable struct NLSolveSafeTerminationResult{T} + best_objective_value::T = Inf64 + best_objective_value_iteration::Int = 0 + return_code::NLSolveSafeTerminationReturnCode.T = NLSolveSafeTerminationReturnCode.Failure +end + +# Remove once support for AbstractDict has been dropped +function __setproperty!(n::NLSolveSafeTerminationResult, prop::Symbol, value) + setproperty!(n, prop, value) +end +function __setproperty!(d::AbstractDict, prop::Symbol, value) + d[prop] = value +end + const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.NLSolveDefault, NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel, @@ -89,6 +105,8 @@ struct NLSolveTerminationCondition{mode, T, safe_termination_options::S end +TruncatedStacktraces.@truncate_stacktrace NLSolveTerminationCondition 1 + function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode} print(io, "NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)") @@ -116,8 +134,14 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options) end -function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing}) +function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, + NLSolveSafeTerminationResult, + Nothing}) mode = get_termination_mode(cond) + if storage isa AbstractDict + Base.depwarn("`storage` of type ($(typeof(storage)) <: AbstractDict) has been deprecated. Pass in a `NLSolveSafeTerminationResult` instance instead", + :NLSolveTerminationCondition) + end # We need both the dispatches to support solvers that don't use the integrator # interface like SimpleNonlinearSolve if mode in BASIC_TERMINATION_MODES @@ -144,8 +168,8 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier if mode ∈ SAFE_BEST_TERMINATION_MODES - storage[:best_objective_value] = aType(Inf) - storage[:best_objective_value_iteration] = 0 + __setproperty!(storage, :best_objective_value, aType(Inf)) + __setproperty!(storage, :best_objective_value_iteration, 0) end if mode ∈ SAFE_BEST_TERMINATION_MODES @@ -158,14 +182,15 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth if mode ∈ SAFE_BEST_TERMINATION_MODES if objective < storage[:best_objective_value] - storage[:best_objective_value] = objective - storage[:best_objective_value_iteration] = nstep + 1 + __setproperty!(storage, :best_objective_value, objective) + __setproperty!(storage, :best_objective_value_iteration, nstep + 1) end end # Main Termination Criteria if objective <= criteria - storage[:return_code] = NLSolveSafeTerminationReturnCode.Success + __setproperty!(storage, :return_code, + NLSolveSafeTerminationReturnCode.Success) return true end @@ -181,7 +206,8 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth if maximum(last_k_values) < typeof(criteria)(cond.safe_termination_options.min_max_factor) * minimum(last_k_values) - storage[:return_code] = NLSolveSafeTerminationReturnCode.PatienceTermination + __setproperty!(storage, :return_code, + NLSolveSafeTerminationReturnCode.PatienceTermination) return true end end @@ -189,11 +215,12 @@ function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Noth # Protective break if objective >= objective_values[1] * protective_threshold * length(du) - storage[:return_code] = NLSolveSafeTerminationReturnCode.ProtectiveTermination + __setproperty!(storage, :return_code, + NLSolveSafeTerminationReturnCode.ProtectiveTermination) return true end - storage[:return_code] = NLSolveSafeTerminationReturnCode.Failure + __setproperty!(storage, :return_code, NLSolveSafeTerminationReturnCode.Failure) return false end return _termination_condition_closure_safe