Skip to content

Commit 16a3019

Browse files
Merge pull request #878 from avik-pal/ap/sstermination
Add termination conditions for NonlinearProblem and SSProblem
2 parents 8fed344 + 2eedd63 commit 16a3019

File tree

4 files changed

+257
-10
lines changed

4 files changed

+257
-10
lines changed

Project.toml

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.118.0"
4+
version = "6.118.1"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
12+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1213
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1314
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1415
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1516
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
19+
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1820
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
1921
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
2022
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
@@ -34,31 +36,32 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3436
[weakdeps]
3537
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
3638
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
37-
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
3839
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
40+
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
3941
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
4042
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4143
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4244
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4345
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4446

4547
[extensions]
46-
DiffEqBaseZygoteExt = "Zygote"
47-
DiffEqBaseReverseDiffExt = "ReverseDiff"
48-
DiffEqBaseTrackerExt = "Tracker"
4948
DiffEqBaseDistributionsExt = "Distributions"
49+
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
50+
DiffEqBaseMPIExt = "MPI"
5051
DiffEqBaseMeasurementsExt = "Measurements"
5152
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
52-
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
53+
DiffEqBaseReverseDiffExt = "ReverseDiff"
54+
DiffEqBaseTrackerExt = "Tracker"
5355
DiffEqBaseUnitfulExt = "Unitful"
54-
DiffEqBaseMPIExt = "MPI"
56+
DiffEqBaseZygoteExt = "Zygote"
5557

5658
[compat]
5759
ArrayInterfaceCore = "0.1.26"
5860
ChainRulesCore = "1"
5961
DataStructures = "0.18"
6062
Distributions = "0.25"
6163
DocStringExtensions = "0.9"
64+
EnumX = "1"
6265
FastBroadcast = "0.2"
6366
ForwardDiff = "0.10"
6467
FunctionWrappers = "1.0"
@@ -91,12 +94,12 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
9194
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9295
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
9396
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
94-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9597
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
98+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9699
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
97100
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
98101
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
99102
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
100103

101104
[targets]
102-
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random","StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions"]
105+
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions"]

src/DiffEqBase.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ using Setfield
3838

3939
using ForwardDiff
4040

41+
using EnumX
42+
43+
using Markdown
44+
4145
# Could be made optional/glue
4246
import PreallocationTools
4347

@@ -127,6 +131,8 @@ include("init.jl")
127131
include("forwarddiff.jl")
128132
include("chainrules.jl")
129133

134+
include("termination_conditions.jl")
135+
130136
include("norecompile.jl")
131137
# This is only used for oop stiff solvers
132138
# Default factorization: an LU decomposition that does not throw on a singular matrix
# (`check = false`), leaving it to the caller to inspect the result.
function default_factorize(A)
    return lu(A; check = false)
end
@@ -152,9 +158,12 @@ export initialize!, finalize!
152158

153159
export SensitivityADPassThrough
154160

161+
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition
162+
155163
export KeywordArgError, KeywordArgWarn, KeywordArgSilent
156164

157165
if !isdefined(Base, :get_extension)
158166
include("../ext/DiffEqBaseDistributionsExt.jl")
159167
end
168+
160169
end # module

src/common_defaults.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ end
1515
@inline recursive_length(u::AbstractArray{<:AbstractArray}) = sum(recursive_length, u)
1616
@inline recursive_length(u::RecursiveArrayTools.ArrayPartition) = sum(recursive_length, u.x)
1717
@inline recursive_length(u::RecursiveArrayTools.VectorOfArray) = sum(recursive_length, u.u)
18-
@inline function recursive_length(u::AbstractArray{<:StaticArraysCore.StaticArray{S, <:Number}}) where {S}
18+
# Total number of scalar entries in an array whose elements are static arrays of numbers:
# (entries per static element) * (number of elements).
@inline function recursive_length(
        u::AbstractArray{<:StaticArraysCore.StaticArray{S, <:Number}}) where {S}
    return prod(Size(eltype(u))) * length(u)
end
2124

src/termination_conditions.jl

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Outcome written to `storage[:return_code]` by the `*Safe*` termination modes.
@enumx NLSolveSafeTerminationReturnCode begin
    Success                # the main tolerance criterion was met
    PatienceTermination    # recent objective values stopped improving
    ProtectiveTermination  # objective grew past the protective threshold
    Failure                # no termination criterion was met on this check
end
7+
8+
# `SteadyStateDefault` and `NLSolveDefault` are needed to stay compatible with the existing
# termination conditions in NonlinearSolve and SteadyStateDiffEq.
@enumx NLSolveTerminationMode begin
    # Compatibility modes
    SteadyStateDefault
    NLSolveDefault
    # Basic modes
    Norm
    Rel
    RelNorm
    Abs
    AbsNorm
    # Safe modes (tolerance check plus patience / protective termination)
    RelSafe
    RelSafeBest
    AbsSafe
    AbsSafeBest
end
23+
24+
"""
    NLSolveSafeTerminationOptions(protective_threshold, patience_steps,
                                  patience_objective_multiplier, min_max_factor)

Settings consumed by the `*Safe*` termination modes.

# Fields

  - `protective_threshold`: growth factor of the objective used by the protective break.
  - `patience_steps`: number of steps considered by the patience check.
  - `patience_objective_multiplier`: multiple of the tolerance below which the patience
    check becomes active.
  - `min_max_factor`: bound on the max/min ratio of recent objective values used by the
    patience check.
"""
struct NLSolveSafeTerminationOptions{T1, T2, T3}
    protective_threshold::T1
    patience_steps::Int
    patience_objective_multiplier::T2
    min_max_factor::T3
end
30+
31+
# Modes decided by a single stateless tolerance check.
const BASIC_TERMINATION_MODES = (
    NLSolveTerminationMode.SteadyStateDefault,
    NLSolveTerminationMode.NLSolveDefault,
    NLSolveTerminationMode.Norm,
    NLSolveTerminationMode.Rel,
    NLSolveTerminationMode.RelNorm,
    NLSolveTerminationMode.Abs,
    NLSolveTerminationMode.AbsNorm,
)

# Modes that carry `NLSolveSafeTerminationOptions`.
const SAFE_TERMINATION_MODES = (
    NLSolveTerminationMode.RelSafe,
    NLSolveTerminationMode.RelSafeBest,
    NLSolveTerminationMode.AbsSafe,
    NLSolveTerminationMode.AbsSafeBest,
)

# Safe modes that additionally track the best objective value in a storage dictionary.
const SAFE_BEST_TERMINATION_MODES = (
    NLSolveTerminationMode.RelSafeBest,
    NLSolveTerminationMode.AbsSafeBest,
)
44+
45+
@doc doc"""
    NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
                                protective_threshold = 1e3, patience_steps::Int = 30,
                                patience_objective_multiplier = 3, min_max_factor = 1.3)

Define the termination criteria for a `NonlinearProblem` or `SteadyStateProblem`.

## Termination Conditions

#### Termination on Absolute Tolerance

  * `NLSolveTerminationMode.Abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``
  * `NLSolveTerminationMode.AbsNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``
  * `NLSolveTerminationMode.AbsSafe`: Essentially `NLSolveTerminationMode.AbsNorm` + terminate if there has been no improvement for the last `patience_steps` steps + terminate if the solution blows up (diverges)
  * `NLSolveTerminationMode.AbsSafeBest`: Same as `NLSolveTerminationMode.AbsSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged

#### Termination on Relative Tolerance

  * `NLSolveTerminationMode.Rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``
  * `NLSolveTerminationMode.RelNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
  * `NLSolveTerminationMode.RelSafe`: Essentially `NLSolveTerminationMode.RelNorm` + terminate if there has been no improvement for the last `patience_steps` steps + terminate if the solution blows up (diverges)
  * `NLSolveTerminationMode.RelSafeBest`: Same as `NLSolveTerminationMode.RelSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged

#### Termination using both Absolute and Relative Tolerances

  * `NLSolveTerminationMode.Norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` or ``\| \frac{\partial u}{\partial t} \| \leq abstol``
  * `NLSolveTerminationMode.SteadyStateDefault`: Checks whether all values of the derivative are close to zero with respect to both the relative and the absolute tolerance. This is usable for small problems but doesn't scale well for neural networks.
  * `NLSolveTerminationMode.NLSolveDefault`: Checks whether all values of the derivative are close to zero with respect to both the relative and the absolute tolerance, or whether the current and the previous state agree within the specified tolerances. This is usable for small problems but doesn't scale well for neural networks.

## General Arguments

  * `abstol`: Absolute tolerance. Must have the same type as `reltol`.
  * `reltol`: Relative tolerance. Must have the same type as `abstol`.

## Arguments specific to `*Safe*` modes

  * `protective_threshold`: If the objective value has increased by this factor (scaled by the length of the derivative) with respect to the initial objective, terminate immediately.
  * `patience_steps`: Number of steps over which the objective is monitored for lack of improvement.
  * `patience_objective_multiplier`: The patience check only applies while the objective is within this factor of the tolerance.
  * `min_max_factor`: The objective is considered not to have improved if the maximum of the recent objective values is smaller than `min_max_factor` times their minimum; in that case terminate.

"""
struct NLSolveTerminationCondition{mode, T,
                                   S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
    abstol::T
    reltol::T
    safe_termination_options::S
end
91+
92+
# Compact one-line display. `safe_termination_options` is printed only for the `*Safe*`
# modes, since it is `nothing` for every other mode.
function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode}
    print(io,
          "NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)")
    if mode in SAFE_TERMINATION_MODES
        print(io, ", safe_termination_options = ", s.safe_termination_options, ")")
    else
        print(io, ")")
    end
    return nothing
end
101+
102+
# Return the `mode` type parameter carried by a termination condition.
function get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode}
    return mode
end
103+
104+
# Don't specify a default `mode` since the sensible default depends on the calling package.
#
# Builds a termination condition for `mode` (an `NLSolveTerminationMode` instance).
# `abstol` and `reltol` must share one type `T`. The remaining keywords are only stored
# (as `NLSolveSafeTerminationOptions`) for the `*Safe*` modes; otherwise the options field
# is `nothing`.
#
# Throws an `ArgumentError` when `mode` is not an `NLSolveTerminationMode` instance.
function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
                                     protective_threshold = 1e3, patience_steps::Int = 30,
                                     patience_objective_multiplier = 3,
                                     min_max_factor = 1.3) where {T}
    # Explicit validation: `@assert` may be compiled out and must not guard user input.
    if !(mode in instances(NLSolveTerminationMode.T))
        throw(ArgumentError("Unknown termination mode: $mode"))
    end
    options = if mode in SAFE_TERMINATION_MODES
        NLSolveSafeTerminationOptions(protective_threshold, patience_steps,
                                      patience_objective_multiplier, min_max_factor)
    else
        nothing
    end
    return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options)
end
118+
119+
# Record `code` under `:return_code` when a storage dictionary was supplied; no-op otherwise.
_set_return_code!(::Nothing, code) = nothing
function _set_return_code!(storage::AbstractDict, code)
    storage[:return_code] = code
    return nothing
end

# Build the termination callback for `cond`.
#
# The returned function has two methods:
#   * `(integrator, abstol, reltol, min_t)` for solvers using the integrator interface
#   * `(du, u, uprev, abstol, reltol)` for solvers that don't (e.g. SimpleNonlinearSolve)
# and returns `true` when the solve should stop.
#
# For the basic modes the callback is a stateless tolerance check. For the `*Safe*` modes
# it keeps state across calls (step counter and objective history) and, when `storage` is
# a dictionary, writes `storage[:return_code]` on every call. The `*SafeBest` modes require
# a dictionary and additionally maintain `:best_objective_value` and
# `:best_objective_value_iteration` in it.
function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing})
    mode = get_termination_mode(cond)
    # We need both the dispatches to support solvers that don't use the integrator
    # interface like SimpleNonlinearSolve
    if mode in BASIC_TERMINATION_MODES
        function _termination_condition_closure_basic(integrator, abstol, reltol, min_t)
            return _termination_condition_closure_basic(get_du(integrator), integrator.u,
                                                        integrator.uprev, abstol, reltol)
        end
        function _termination_condition_closure_basic(du, u, uprev, abstol, reltol)
            return _has_converged(du, u, uprev, cond, abstol, reltol)
        end
        return _termination_condition_closure_basic
    end

    options = cond.safe_termination_options
    track_best = mode in SAFE_BEST_TERMINATION_MODES
    # The Abs* safe modes test `norm(du)` against `abstol`; the Rel* safe modes test the
    # relative norm against `reltol`.
    use_abstol = mode in (NLSolveTerminationMode.AbsSafe,
                          NLSolveTerminationMode.AbsSafeBest)

    if track_best
        storage === nothing &&
            throw(ArgumentError("Termination mode $mode requires a storage dictionary"))
        # Initialised once, when the callback is built, so the best value survives
        # across successive calls.
        storage[:best_objective_value] = typeof(cond.abstol)(Inf)
        storage[:best_objective_value_iteration] = 0
    end

    # State shared by successive calls of the callback. It must live outside the closure
    # body: the patience and protective checks need the full history.
    nstep = Ref(0)
    objective_values = typeof(cond.abstol)[]

    function _termination_condition_closure_safe(integrator, abstol, reltol, min_t)
        return _termination_condition_closure_safe(get_du(integrator), integrator.u,
                                                   integrator.uprev, abstol, reltol)
    end
    function _termination_condition_closure_safe(du, u, uprev, abstol, reltol)
        aType = typeof(abstol)

        if use_abstol
            objective = norm(du)
            criteria = abstol
        else
            # `eps` guards against division by zero when `du .+ u` vanishes.
            objective = norm(du) / (norm(du .+ u) + eps(aType))
            criteria = reltol
        end

        if track_best && objective < storage[:best_objective_value]
            storage[:best_objective_value] = objective
            storage[:best_objective_value_iteration] = nstep[] + 1
        end

        # Main Termination Criteria
        if objective <= criteria
            _set_return_code!(storage, NLSolveSafeTerminationReturnCode.Success)
            return true
        end

        nstep[] += 1
        push!(objective_values, objective)

        # Terminate if there has been no improvement for the last `patience_steps`:
        # only considered once the objective is within a multiple of the tolerance.
        if objective <= typeof(criteria)(options.patience_objective_multiplier) * criteria &&
           nstep[] >= options.patience_steps
            first_idx = max(1, length(objective_values) - options.patience_steps)
            lo, hi = extrema(@view objective_values[first_idx:end])
            if hi < typeof(criteria)(options.min_max_factor) * lo
                _set_return_code!(storage,
                                  NLSolveSafeTerminationReturnCode.PatienceTermination)
                return true
            end
        end

        # Protective break: the objective blew up relative to the first recorded value.
        protective_threshold = aType(options.protective_threshold)
        if objective >= first(objective_values) * protective_threshold * length(du)
            _set_return_code!(storage,
                              NLSolveSafeTerminationReturnCode.ProtectiveTermination)
            return true
        end

        _set_return_code!(storage, NLSolveSafeTerminationReturnCode.Failure)
        return false
    end
    return _termination_condition_closure_safe
end
202+
203+
# Convergence Criterions

# Convenience method: take the mode from the type parameter of `cond` and default the
# tolerances to the ones stored in `cond`.
@inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition{mode},
                                abstol = cond.abstol, reltol = cond.reltol) where {mode}
    return _has_converged(du, u, uprev, mode, abstol, reltol)
end

# Stateless tolerance check for `mode`. `du` is the residual/derivative, `u` the current
# state and `uprev` the previous state (only used by `NLSolveDefault`). The `*Safe*` modes
# fall back to the corresponding norm-based check.
#
# Throws an `ArgumentError` for a `mode` that is not handled below.
@inline @inbounds function _has_converged(du, u, uprev, mode, abstol, reltol)
    if mode == NLSolveTerminationMode.Norm
        du_norm = norm(du)
        return du_norm <= abstol || du_norm <= reltol * norm(du + u)
    elseif mode == NLSolveTerminationMode.Rel
        return all(abs.(du) .<= reltol .* abs.(u))
    elseif mode in (NLSolveTerminationMode.RelNorm, NLSolveTerminationMode.RelSafe,
                    NLSolveTerminationMode.RelSafeBest)
        return norm(du) <= reltol * norm(du .+ u)
    elseif mode == NLSolveTerminationMode.Abs
        return all(abs.(du) .<= abstol)
    elseif mode in (NLSolveTerminationMode.AbsNorm, NLSolveTerminationMode.AbsSafe,
                    NLSolveTerminationMode.AbsSafeBest)
        return norm(du) <= abstol
    elseif mode == NLSolveTerminationMode.SteadyStateDefault
        return all((abs.(du) .<= abstol) .| (abs.(du) .<= reltol .* abs.(u)))
    elseif mode == NLSolveTerminationMode.NLSolveDefault
        atol, rtol = abstol, reltol
        return all((abs.(du) .<= abstol) .| (abs.(du) .<= reltol .* abs.(u))) ||
               isapprox(u, uprev; atol, rtol)
    else
        throw(ArgumentError("Unknown termination mode: $mode"))
    end
end

0 commit comments

Comments
 (0)