|
| 1 | +@enumx NLSolveSafeTerminationReturnCode begin |
| 2 | + Success |
| 3 | + PatienceTermination |
| 4 | + ProtectiveTermination |
| 5 | + Failure |
| 6 | +end |
| 7 | + |
| 8 | +# SteadyStateDefault and NLSolveDefault are needed to be compatible with the existing |
| 9 | +# termination conditions in NonlinearSolve and SteadyStateDiffEq |
| 10 | +@enumx NLSolveTerminationMode begin |
| 11 | + SteadyStateDefault |
| 12 | + NLSolveDefault |
| 13 | + Norm |
| 14 | + Rel |
| 15 | + RelNorm |
| 16 | + Abs |
| 17 | + AbsNorm |
| 18 | + RelSafe |
| 19 | + RelSafeBest |
| 20 | + AbsSafe |
| 21 | + AbsSafeBest |
| 22 | +end |
| 23 | + |
| 24 | +struct NLSolveSafeTerminationOptions{T1, T2, T3} |
| 25 | + protective_threshold::T1 |
| 26 | + patience_steps::Int |
| 27 | + patience_objective_multiplier::T2 |
| 28 | + min_max_factor::T3 |
| 29 | +end |
| 30 | + |
| 31 | +const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault, |
| 32 | + NLSolveTerminationMode.NLSolveDefault, |
| 33 | + NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel, |
| 34 | + NLSolveTerminationMode.RelNorm, |
| 35 | + NLSolveTerminationMode.Abs, NLSolveTerminationMode.AbsNorm) |
| 36 | + |
| 37 | +const SAFE_TERMINATION_MODES = (NLSolveTerminationMode.RelSafe, |
| 38 | + NLSolveTerminationMode.RelSafeBest, |
| 39 | + NLSolveTerminationMode.AbsSafe, |
| 40 | + NLSolveTerminationMode.AbsSafeBest) |
| 41 | + |
| 42 | +const SAFE_BEST_TERMINATION_MODES = (NLSolveTerminationMode.RelSafeBest, |
| 43 | + NLSolveTerminationMode.AbsSafeBest) |
| 44 | + |
| 45 | +@doc doc""" |
| 46 | + NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, |
| 47 | + protective_threshold = 1e3, patience_steps::Int = 30, |
| 48 | + patience_objective_multiplier = 3, min_max_factor = 1.3) |
| 49 | +
|
| 50 | +Define the termination criteria for the NonlinearProblem or SteadyStateProblem. |
| 51 | +
|
| 52 | +## Termination Conditions |
| 53 | +
|
| 54 | +#### Termination on Absolute Tolerance |
| 55 | +
|
| 56 | + * `NLSolveTerminationMode.Abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)`` |
| 57 | + * `NLSolveTerminationMode.AbsNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol`` |
| 58 | + * `NLSolveTerminationMode.AbsSafe`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) |
| 59 | + * `NLSolveTerminationMode.AbsSafeBest`: Same as `NLSolveTerminationMode.AbsSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged |
| 60 | +
|
| 61 | +#### Termination on Relative Tolerance |
| 62 | +
|
| 63 | + * `NLSolveTerminationMode.Rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)`` |
| 64 | + * `NLSolveTerminationMode.RelNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` |
| 65 | + * `NLSolveTerminationMode.RelSafe`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) |
| 66 | + * `NLSolveTerminationMode.RelSafeBest`: Same as `NLSolveTerminationMode.RelSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged |
| 67 | +
|
| 68 | +#### Termination using both Absolute and Relative Tolerances |
| 69 | +
|
| 70 | + * `NLSolveTerminationMode.Norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` or ``\| \frac{\partial u}{\partial t} \| \leq abstol`` |
| 71 | + * `NLSolveTerminationMode.SteadyStateDefault`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems but doesn't scale well for neural networks. |
| 72 | + * `NLSolveTerminationMode.NLSolveDefault`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. Or check that the value of the current and previous state is within the specified tolerances. This is usable for small problems but doesn't scale well for neural networks. |
| 73 | +
|
| 74 | +## General Arguments |
| 75 | +
|
| 76 | + * `abstol`: Absolute Tolerance |
| 77 | + * `reltol`: Relative Tolerance |
| 78 | +
|
| 79 | +## Arguments specific to `*Safe*` modes |
| 80 | +
|
| 81 | + * `protective_threshold`: If the objective value increased by this factor wrt initial objective terminate immediately. |
| 82 | + * `patience_steps`: If objective is within `patience_objective_multiplier` factor of the criteria and no improvement within `min_max_factor` has happened then terminate. |
| 83 | +
|
| 84 | +""" |
| 85 | +struct NLSolveTerminationCondition{mode, T, |
| 86 | + S <: Union{<:NLSolveSafeTerminationOptions, Nothing}} |
| 87 | + abstol::T |
| 88 | + reltol::T |
| 89 | + safe_termination_options::S |
| 90 | +end |
| 91 | + |
| 92 | +function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode} |
| 93 | + print(io, |
| 94 | + "NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)") |
| 95 | + if mode ∈ SAFE_TERMINATION_MODES |
| 96 | + print(io, ", safe_termination_options = ", s.safe_termination_options, ")") |
| 97 | + else |
| 98 | + print(io, ")") |
| 99 | + end |
| 100 | +end |
| 101 | + |
| 102 | +get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode} = mode |
| 103 | + |
| 104 | +# Don't specify `mode` since the defaults would depend on the package |
| 105 | +function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, |
| 106 | + protective_threshold = 1e3, patience_steps::Int = 30, |
| 107 | + patience_objective_multiplier = 3, |
| 108 | + min_max_factor = 1.3) where {T} |
| 109 | + @assert mode ∈ instances(NLSolveTerminationMode.T) |
| 110 | + options = if mode ∈ SAFE_TERMINATION_MODES |
| 111 | + NLSolveSafeTerminationOptions(protective_threshold, patience_steps, |
| 112 | + patience_objective_multiplier, min_max_factor) |
| 113 | + else |
| 114 | + nothing |
| 115 | + end |
| 116 | + return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options) |
| 117 | +end |
| 118 | + |
| 119 | +function (cond::NLSolveTerminationCondition)(storage::Union{<:AbstractDict, Nothing}) |
| 120 | + mode = get_termination_mode(cond) |
| 121 | + # We need both the dispatches to support solvers that don't use the integrator |
| 122 | + # interface like SimpleNonlinearSolve |
| 123 | + if mode in BASIC_TERMINATION_MODES |
| 124 | + function _termination_condition_closure_basic(integrator, abstol, reltol, min_t) |
| 125 | + return _termination_condition_closure_basic(get_du(integrator), integrator.u, |
| 126 | + integrator.uprev, abstol, reltol) |
| 127 | + end |
| 128 | + function _termination_condition_closure_basic(du, u, uprev, abstol, reltol) |
| 129 | + return _has_converged(du, u, uprev, cond, abstol, reltol) |
| 130 | + end |
| 131 | + return _termination_condition_closure_basic |
| 132 | + else |
| 133 | + mode ∈ SAFE_BEST_TERMINATION_MODES && @assert storage !== nothing |
| 134 | + nstep::Int = 0 |
| 135 | + |
| 136 | + function _termination_condition_closure_safe(integrator, abstol, reltol, min_t) |
| 137 | + return _termination_condition_closure_safe(get_du(integrator), integrator.u, |
| 138 | + integrator.uprev, abstol, reltol) |
| 139 | + end |
| 140 | + @inbounds function _termination_condition_closure_safe(du, u, uprev, abstol, reltol) |
| 141 | + aType = typeof(abstol) |
| 142 | + protective_threshold = aType(cond.safe_termination_options.protective_threshold) |
| 143 | + objective_values = aType[] |
| 144 | + patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier |
| 145 | + |
| 146 | + if mode ∈ SAFE_BEST_TERMINATION_MODES |
| 147 | + storage[:best_objective_value] = aType(Inf) |
| 148 | + storage[:best_objective_value_iteration] = 0 |
| 149 | + end |
| 150 | + |
| 151 | + if mode ∈ SAFE_BEST_TERMINATION_MODES |
| 152 | + objective = norm(du) |
| 153 | + criteria = abstol |
| 154 | + else |
| 155 | + objective = norm(du) / (norm(du .+ u) + eps(aType)) |
| 156 | + criteria = reltol |
| 157 | + end |
| 158 | + |
| 159 | + if mode ∈ SAFE_BEST_TERMINATION_MODES |
| 160 | + if objective < storage[:best_objective_value] |
| 161 | + storage[:best_objective_value] = objective |
| 162 | + storage[:best_objective_value_iteration] = nstep + 1 |
| 163 | + end |
| 164 | + end |
| 165 | + |
| 166 | + # Main Termination Criteria |
| 167 | + if objective <= criteria |
| 168 | + storage[:return_code] = NLSolveSafeTerminationReturnCode.Success |
| 169 | + return true |
| 170 | + end |
| 171 | + |
| 172 | + # Terminate if there has been no improvement for the last `patience_steps` |
| 173 | + nstep += 1 |
| 174 | + push!(objective_values, objective) |
| 175 | + |
| 176 | + if objective <= typeof(criteria)(patience_objective_multiplier) * criteria |
| 177 | + if nstep >= cond.safe_termination_options.patience_steps |
| 178 | + last_k_values = objective_values[max(1, |
| 179 | + length(objective_values) - |
| 180 | + cond.safe_termination_options.patience_steps):end] |
| 181 | + if maximum(last_k_values) < |
| 182 | + typeof(criteria)(cond.safe_termination_options.min_max_factor) * |
| 183 | + minimum(last_k_values) |
| 184 | + storage[:return_code] = NLSolveSafeTerminationReturnCode.PatienceTermination |
| 185 | + return true |
| 186 | + end |
| 187 | + end |
| 188 | + end |
| 189 | + |
| 190 | + # Protective break |
| 191 | + if objective >= objective_values[1] * protective_threshold * length(du) |
| 192 | + storage[:return_code] = NLSolveSafeTerminationReturnCode.ProtectiveTermination |
| 193 | + return true |
| 194 | + end |
| 195 | + |
| 196 | + storage[:return_code] = NLSolveSafeTerminationReturnCode.Failure |
| 197 | + return false |
| 198 | + end |
| 199 | + return _termination_condition_closure_safe |
| 200 | + end |
| 201 | +end |
| 202 | + |
| 203 | +# Convergence Criterions |
| 204 | +@inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition{mode}, |
| 205 | + abstol = cond.abstol, reltol = cond.reltol) where {mode} |
| 206 | + return _has_converged(du, u, uprev, mode, abstol, reltol) |
| 207 | +end |
| 208 | + |
| 209 | +@inline @inbounds function _has_converged(du, u, uprev, mode, abstol, reltol) |
| 210 | + if mode == NLSolveTerminationMode.Norm |
| 211 | + du_norm = norm(du) |
| 212 | + return du_norm <= abstol || du_norm <= reltol * norm(du + u) |
| 213 | + elseif mode == NLSolveTerminationMode.Rel |
| 214 | + return all(abs.(du) .<= reltol .* abs.(u)) |
| 215 | + elseif mode ∈ (NLSolveTerminationMode.RelNorm, NLSolveTerminationMode.RelSafe, |
| 216 | + NLSolveTerminationMode.RelSafeBest) |
| 217 | + return norm(du) <= reltol * norm(du .+ u) |
| 218 | + elseif mode == NLSolveTerminationMode.Abs |
| 219 | + return all(abs.(du) .<= abstol) |
| 220 | + elseif mode ∈ (NLSolveTerminationMode.AbsNorm, NLSolveTerminationMode.AbsSafe, |
| 221 | + NLSolveTerminationMode.AbsSafeBest) |
| 222 | + return norm(du) <= abstol |
| 223 | + elseif mode == NLSolveTerminationMode.SteadyStateDefault |
| 224 | + return all((abs.(du) .<= abstol) .| (abs.(du) .<= reltol .* abs.(u))) |
| 225 | + elseif mode == NLSolveTerminationMode.NLSolveDefault |
| 226 | + atol, rtol = abstol, reltol |
| 227 | + return all((abs.(du) .<= abstol) .| (abs.(du) .<= reltol .* abs.(u))) || |
| 228 | + isapprox(u, uprev; atol, rtol) |
| 229 | + else |
| 230 | + throw(ArgumentError("Unknown termination mode: $mode")) |
| 231 | + end |
| 232 | +end |
0 commit comments