Skip to content

Add termination conditions for NonlinearProblem and SSProblem #878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <[email protected]>"]
version = "6.118.0"
version = "6.118.1"

[deps]
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -34,31 +36,32 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DiffEqBaseZygoteExt = "Zygote"
DiffEqBaseReverseDiffExt = "ReverseDiff"
DiffEqBaseTrackerExt = "Tracker"
DiffEqBaseDistributionsExt = "Distributions"
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseMPIExt = "MPI"
DiffEqBaseMeasurementsExt = "Measurements"
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseReverseDiffExt = "ReverseDiff"
DiffEqBaseTrackerExt = "Tracker"
DiffEqBaseUnitfulExt = "Unitful"
DiffEqBaseMPIExt = "MPI"
DiffEqBaseZygoteExt = "Zygote"

[compat]
ArrayInterfaceCore = "0.1.26"
ChainRulesCore = "1"
DataStructures = "0.18"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnumX = "1"
FastBroadcast = "0.2"
ForwardDiff = "0.10"
FunctionWrappers = "1.0"
Expand Down Expand Up @@ -91,12 +94,12 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random","StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions"]
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions"]
9 changes: 9 additions & 0 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ using Setfield

using ForwardDiff

using EnumX

using Markdown

# Could be made optional/glue
import PreallocationTools

Expand Down Expand Up @@ -127,6 +131,8 @@ include("init.jl")
include("forwarddiff.jl")
include("chainrules.jl")

include("termination_conditions.jl")

include("norecompile.jl")
# This is only used for oop stiff solvers.
# LU-factorize `A` with `check = false`, so `lu` does not throw on a failed
# factorization.
function default_factorize(A)
    return lu(A; check = false)
end
Expand All @@ -152,9 +158,12 @@ export initialize!, finalize!

export SensitivityADPassThrough

export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition

export KeywordArgError, KeywordArgWarn, KeywordArgSilent

if !isdefined(Base, :get_extension)
include("../ext/DiffEqBaseDistributionsExt.jl")
end

end # module
5 changes: 4 additions & 1 deletion src/common_defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ end
# Total element count of nested containers: sum `recursive_length` over the children.

# Array whose elements are themselves arrays.
@inline function recursive_length(u::AbstractArray{<:AbstractArray})
    return sum(recursive_length, u)
end

# `ArrayPartition`: the parts are read from the `x` field.
@inline function recursive_length(u::RecursiveArrayTools.ArrayPartition)
    return sum(recursive_length, u.x)
end

# `VectorOfArray`: the wrapped arrays are read from the `u` field.
@inline function recursive_length(u::RecursiveArrayTools.VectorOfArray)
    return sum(recursive_length, u.u)
end
# Array of static arrays with numeric elements: every element has the same static
# size, so the total is (elements per static array) * (number of static arrays).
@inline function recursive_length(u::AbstractArray{
                                                   <:StaticArraysCore.StaticArray{S,
                                                                                  <:Number}
                                                   }) where {S}
    return prod(Size(eltype(u))) * length(u)
end

Expand Down
232 changes: 232 additions & 0 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Outcome written to `storage[:return_code]` by the `*Safe*` termination closure on
# each check.
@enumx NLSolveSafeTerminationReturnCode begin
# The objective dropped to or below the tolerance in use.
Success
# The patience check (no sufficient change over the recent objective values) fired.
PatienceTermination
# The protective check (objective grew past the protective threshold) fired.
ProtectiveTermination
# None of the checks fired; the closure returned `false`.
Failure
end

# SteadyStateDefault and NLSolveDefault are needed to be compatible with the existing
# termination conditions in NonlinearSolve and SteadyStateDiffEq
#
# Available termination modes. The inline notes summarise the check `_has_converged`
# performs for each mode.
@enumx NLSolveTerminationMode begin
# all(|du| <= abstol  or  |du| <= reltol * |u|), elementwise
SteadyStateDefault
# Same elementwise check as SteadyStateDefault, or `isapprox(u, uprev; atol, rtol)`
NLSolveDefault
# norm(du) <= abstol  or  norm(du) <= reltol * norm(du + u)
Norm
# all(|du| <= reltol * |u|), elementwise
Rel
# norm(du) <= reltol * norm(du .+ u)
RelNorm
# all(|du| <= abstol), elementwise
Abs
# norm(du) <= abstol
AbsNorm
# Members of SAFE_TERMINATION_MODES; the `*Best` ones are also in
# SAFE_BEST_TERMINATION_MODES.
RelSafe
RelSafeBest
AbsSafe
AbsSafeBest
end

"""
    NLSolveSafeTerminationOptions(protective_threshold, patience_steps,
                                  patience_objective_multiplier, min_max_factor)

Options read by the `*Safe*` termination modes of `NLSolveTerminationCondition`.

## Fields

  * `protective_threshold`: factor used by the protective break; the check compares the
    current objective against a recorded objective times this factor times `length(du)`.
  * `patience_steps`: number of steps required before the patience check runs, and the
    look-back distance used to select the recent objective values it inspects.
  * `patience_objective_multiplier`: the patience check is only considered while the
    objective is at most this multiple of the tolerance in use.
  * `min_max_factor`: the patience check fires when the maximum of the recent objective
    values is smaller than this factor times their minimum.
"""
struct NLSolveSafeTerminationOptions{T1, T2, T3}
protective_threshold::T1
patience_steps::Int
patience_objective_multiplier::T2
min_max_factor::T3
end

# Modes whose check is a single stateless call to `_has_converged`.
const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
NLSolveTerminationMode.NLSolveDefault,
NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel,
NLSolveTerminationMode.RelNorm,
NLSolveTerminationMode.Abs, NLSolveTerminationMode.AbsNorm)

# Modes that carry `NLSolveSafeTerminationOptions` and use the patience/protective checks.
const SAFE_TERMINATION_MODES = (NLSolveTerminationMode.RelSafe,
NLSolveTerminationMode.RelSafeBest,
NLSolveTerminationMode.AbsSafe,
NLSolveTerminationMode.AbsSafeBest)

# Subset of the safe modes that also records the best objective value in `storage`
# (and therefore requires a non-`nothing` storage dictionary).
const SAFE_BEST_TERMINATION_MODES = (NLSolveTerminationMode.RelSafeBest,
NLSolveTerminationMode.AbsSafeBest)

@doc doc"""
    NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
                                protective_threshold = 1e3, patience_steps::Int = 30,
                                patience_objective_multiplier = 3, min_max_factor = 1.3)

Define the termination criteria for a `NonlinearProblem` or `SteadyStateProblem`.

`abstol` and `reltol` must have the same type `T`.

## Termination Conditions

#### Termination on Absolute Tolerance

  * `NLSolveTerminationMode.Abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``
  * `NLSolveTerminationMode.AbsNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``
  * `NLSolveTerminationMode.AbsSafe`: Essentially `NLSolveTerminationMode.AbsNorm` + terminate if there has been no improvement for the last `patience_steps` steps + terminate if the solution blows up (diverges)
  * `NLSolveTerminationMode.AbsSafeBest`: Same as `NLSolveTerminationMode.AbsSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged

#### Termination on Relative Tolerance

  * `NLSolveTerminationMode.Rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``
  * `NLSolveTerminationMode.RelNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
  * `NLSolveTerminationMode.RelSafe`: Essentially `NLSolveTerminationMode.RelNorm` + terminate if there has been no improvement for the last `patience_steps` steps + terminate if the solution blows up (diverges)
  * `NLSolveTerminationMode.RelSafeBest`: Same as `NLSolveTerminationMode.RelSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged

#### Termination using both Absolute and Relative Tolerances

  * `NLSolveTerminationMode.Norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` or ``\| \frac{\partial u}{\partial t} \| \leq abstol``
  * `NLSolveTerminationMode.SteadyStateDefault`: Checks whether every value of the derivative is close to zero with respect to either the relative or the absolute tolerance. This is usable for small problems but doesn't scale well for neural networks.
  * `NLSolveTerminationMode.NLSolveDefault`: Checks whether every value of the derivative is close to zero with respect to either the relative or the absolute tolerance, or whether the current and previous states agree within the specified tolerances. This is usable for small problems but doesn't scale well for neural networks.

## General Arguments

  * `abstol`: Absolute tolerance (default `1e-8`).
  * `reltol`: Relative tolerance (default `1e-6`).

## Arguments specific to `*Safe*` modes

  * `protective_threshold`: If the objective value increased by this factor (scaled by `length(du)`) with respect to the initial objective, terminate immediately.
  * `patience_steps`: Number of steps after which the patience check becomes active, and the look-back distance for the recent objective values it inspects.
  * `patience_objective_multiplier`: The patience check is only considered while the objective is within this factor of the termination criterion.
  * `min_max_factor`: The patience check terminates when the maximum of the recent objective values is less than `min_max_factor` times their minimum, i.e. no meaningful improvement has happened.

For the basic (non-`*Safe*`) modes these four arguments are ignored and
`safe_termination_options` is `nothing`.
"""
struct NLSolveTerminationCondition{mode, T,
S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
abstol::T
reltol::T
safe_termination_options::S
end

# Compact one-line display. The safe-termination options are only printed for modes
# that actually carry them.
function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode}
    header = "NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)"
    print(io, header)
    if !(mode ∈ SAFE_TERMINATION_MODES)
        return print(io, ")")
    end
    return print(io, ", safe_termination_options = ", s.safe_termination_options, ")")
end

# Return the termination mode stored as the first type parameter of the condition.
function get_termination_mode(::NLSolveTerminationCondition{M}) where {M}
    return M
end

# Don't specify `mode` since the defaults would depend on the package
# Keyword constructor.
#
# `mode` must be an instance of `NLSolveTerminationMode.T`; anything else raises an
# `ArgumentError`. `abstol` and `reltol` share the type `T`. The four safe-termination
# keywords are packed into an `NLSolveSafeTerminationOptions` only when `mode` is one of
# `SAFE_TERMINATION_MODES`; otherwise `safe_termination_options` is `nothing`.
function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
                                     protective_threshold = 1e3, patience_steps::Int = 30,
                                     patience_objective_multiplier = 3,
                                     min_max_factor = 1.3) where {T}
    # Explicit validation instead of `@assert`: assertions are meant for internal
    # invariants and may be disabled, and they give the caller no useful message.
    if !(mode ∈ instances(NLSolveTerminationMode.T))
        throw(ArgumentError("Unknown termination mode: $mode"))
    end
    options = if mode ∈ SAFE_TERMINATION_MODES
        NLSolveSafeTerminationOptions(protective_threshold, patience_steps,
                                      patience_objective_multiplier, min_max_factor)
    else
        nothing
    end
    return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options)
end

# Build the termination-check closure for `cond`.
#
# The returned closure has two methods and returns `true` when iteration should stop:
#   * `(integrator, abstol, reltol, min_t)`: reads `get_du(integrator)`, `integrator.u`
#     and `integrator.uprev` and forwards to the method below (`min_t` is unused).
#   * `(du, u, uprev, abstol, reltol)`
#
# For the basic modes the closure is a stateless call to `_has_converged`.
#
# For the `*Safe*` modes the closure keeps state between calls (step counter and the
# history of objective values) and, when `storage` is a dictionary, reports through it:
#   * `:return_code`: an `NLSolveSafeTerminationReturnCode` for the latest check.
#   * `:best_objective_value` / `:best_objective_value_iteration`: `*SafeBest` modes only.
#
# `storage` may be `nothing` for `RelSafe`/`AbsSafe` (nothing is reported then); the
# `*SafeBest` modes require a dictionary and throw an `ArgumentError` otherwise.
function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing})
    mode = get_termination_mode(cond)
    # We need both the dispatches to support solvers that don't use the integrator
    # interface like SimpleNonlinearSolve
    if mode in BASIC_TERMINATION_MODES
        function _termination_condition_closure_basic(integrator, abstol, reltol, min_t)
            return _termination_condition_closure_basic(get_du(integrator), integrator.u,
                                                        integrator.uprev, abstol, reltol)
        end
        function _termination_condition_closure_basic(du, u, uprev, abstol, reltol)
            return _has_converged(du, u, uprev, cond, abstol, reltol)
        end
        return _termination_condition_closure_basic
    else
        tracks_best = mode ∈ SAFE_BEST_TERMINATION_MODES
        if tracks_best && storage === nothing
            throw(ArgumentError("termination mode $mode requires `storage` to be an `AbstractDict`, got `nothing`"))
        end

        # Abs* modes measure ‖du‖ against `abstol`; Rel* modes measure the relative
        # residual against `reltol`. This mirrors the grouping in `_has_converged`.
        uses_abs_objective = mode ∈ (NLSolveTerminationMode.AbsSafe,
                                     NLSolveTerminationMode.AbsSafeBest)
        options = cond.safe_termination_options

        # State shared by every call of the closure. It must live out here: if it were
        # re-created inside the closure, the patience window would never hold more than
        # one value and the protective break would compare the objective with itself.
        nstep::Int = 0
        objective_values = typeof(cond.abstol)[]
        if tracks_best
            storage[:best_objective_value] = typeof(cond.abstol)(Inf)
            storage[:best_objective_value_iteration] = 0
        end

        # Report the outcome of a check, if the caller supplied somewhere to put it.
        function _record_return_code(code)
            if storage !== nothing
                storage[:return_code] = code
            end
            return nothing
        end

        function _termination_condition_closure_safe(integrator, abstol, reltol, min_t)
            return _termination_condition_closure_safe(get_du(integrator), integrator.u,
                                                       integrator.uprev, abstol, reltol)
        end
        function _termination_condition_closure_safe(du, u, uprev, abstol, reltol)
            aType = typeof(abstol)
            protective_threshold = aType(options.protective_threshold)

            if uses_abs_objective
                objective = norm(du)
                criteria = abstol
            else
                # `eps` keeps the ratio finite when `du .+ u` is exactly zero.
                objective = norm(du) / (norm(du .+ u) + eps(aType))
                criteria = reltol
            end

            if tracks_best && objective < storage[:best_objective_value]
                storage[:best_objective_value] = objective
                storage[:best_objective_value_iteration] = nstep + 1
            end

            # Main Termination Criteria
            if objective <= criteria
                _record_return_code(NLSolveSafeTerminationReturnCode.Success)
                return true
            end

            # Terminate if there has been no improvement for the last `patience_steps`
            nstep += 1
            push!(objective_values, objective)

            if objective <= typeof(criteria)(options.patience_objective_multiplier) * criteria
                if nstep >= options.patience_steps
                    window_start = max(1, length(objective_values) - options.patience_steps)
                    lowest, highest = extrema(@view objective_values[window_start:end])
                    if highest < typeof(criteria)(options.min_max_factor) * lowest
                        _record_return_code(NLSolveSafeTerminationReturnCode.PatienceTermination)
                        return true
                    end
                end
            end

            # Protective break: the objective has grown far beyond the first recorded one.
            if objective >= objective_values[1] * protective_threshold * length(du)
                _record_return_code(NLSolveSafeTerminationReturnCode.ProtectiveTermination)
                return true
            end

            _record_return_code(NLSolveSafeTerminationReturnCode.Failure)
            return false
        end
        return _termination_condition_closure_safe
    end
end

# Convergence Criteria
# Convenience overload: take the mode from the condition's type parameter and, unless
# overridden by the caller, the tolerances from the condition itself.
@inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition,
                                abstol = cond.abstol, reltol = cond.reltol)
    return _has_converged(du, u, uprev, get_termination_mode(cond), abstol, reltol)
end

# Evaluate the convergence test selected by `mode` on the residual `du`, the current
# state `u` and the previous state `uprev`. Returns a `Bool`; throws an `ArgumentError`
# for a mode it does not recognise.
@inline @inbounds function _has_converged(du, u, uprev, mode, abstol, reltol)
    # Combined absolute/relative norm test.
    if mode == NLSolveTerminationMode.Norm
        residual = norm(du)
        return residual <= abstol || residual <= reltol * norm(du + u)
    end

    # Elementwise tests.
    mode == NLSolveTerminationMode.Rel && return all(abs.(du) .<= reltol .* abs.(u))
    mode == NLSolveTerminationMode.Abs && return all(abs.(du) .<= abstol)

    # Norm tests shared by the plain and the safe variants.
    if mode ∈ (NLSolveTerminationMode.RelNorm, NLSolveTerminationMode.RelSafe,
               NLSolveTerminationMode.RelSafeBest)
        return norm(du) <= reltol * norm(du .+ u)
    end
    if mode ∈ (NLSolveTerminationMode.AbsNorm, NLSolveTerminationMode.AbsSafe,
               NLSolveTerminationMode.AbsSafeBest)
        return norm(du) <= abstol
    end

    # Legacy defaults: each entry must satisfy the absolute or the relative bound.
    if mode == NLSolveTerminationMode.SteadyStateDefault
        return all((abs.(du) .<= abstol) .| (abs.(du) .<= reltol .* abs.(u)))
    end
    if mode == NLSolveTerminationMode.NLSolveDefault
        # Additionally accept when the state has stopped moving between iterations.
        return all((abs.(du) .<= abstol) .| (abs.(du) .<= reltol .* abs.(u))) ||
               isapprox(u, uprev; atol = abstol, rtol = reltol)
    end

    throw(ArgumentError("Unknown termination mode: $mode"))
end