Skip to content

Commit 2763d65

Browse files
committed
Add termination condition to gaussnewton and other fixes
1 parent b58355e commit 2763d65

File tree

2 files changed

+51
-20
lines changed

2 files changed

+51
-20
lines changed

src/gaussnewton.jl

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,34 @@ for large-scale and numerically-difficult nonlinear least squares problems.
3636
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
3737
construction. This will be fixed in the near future.
3838
"""
39-
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
39+
@concrete struct GaussNewton{CJ, AD, TC} <: AbstractNewtonAlgorithm{CJ, AD, TC}
4040
ad::AD
4141
linsolve
4242
precs
43+
termination_condition::TC
4344
end
4445

4546
function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
4647
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
4748
end
4849

49-
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
50-
precs = DEFAULT_PRECS, adkwargs...)
50+
function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
51+
precs = DEFAULT_PRECS,
52+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
53+
abstol = nothing,
54+
reltol = nothing), adkwargs...)
5155
ad = default_adargs_to_adtype(; adkwargs...)
52-
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
56+
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
57+
linsolve,
58+
precs,
59+
termination_condition)
5360
end
5461

5562
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
5663
f
5764
alg
5865
u
66+
u_prev
5967
fu1
6068
fu2
6169
fu_new
@@ -72,12 +80,15 @@ end
7280
internalnorm
7381
retcode::ReturnCode.T
7482
abstol
83+
reltol
7584
prob
7685
stats::NLStats
86+
tc_storage
7787
end
7888

7989
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
80-
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
90+
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
91+
internalnorm = DEFAULT_NORM,
8192
kwargs...) where {uType, iip}
8293
alg = get_concrete_algorithm(alg_, prob)
8394
@unpack f, u0, p = prob
@@ -91,27 +102,46 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
91102
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
92103
linsolve_with_JᵀJ = Val(true))
93104

94-
return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
95-
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
96-
prob, NLStats(1, 0, 0, 0, 0))
105+
tc = alg.termination_condition
106+
mode = DiffEqBase.get_termination_mode(tc)
107+
108+
atol = _get_tolerance(abstol, tc.abstol, eltype(u))
109+
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))
110+
111+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
112+
nothing
113+
114+
return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
115+
linsolve, J,
116+
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol,
117+
prob, NLStats(1, 0, 0, 0, 0), storage)
97118
end
98119

99120
function perform_step!(cache::GaussNewtonCache{true})
100-
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
121+
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
101122
jacobian!!(J, cache)
102123
__matmul!(JᵀJ, J', J)
103124
__matmul!(Jᵀf, J', fu1)
104125

126+
tc_storage = cache.tc_storage
127+
termination_condition = cache.alg.termination_condition(tc_storage)
128+
105129
# u = u - J \ fu
106130
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
107131
linu = _vec(du), p, reltol = cache.abstol)
108132
cache.linsolve = linres.cache
109133
@. u = u - du
110134
f(cache.fu_new, u, p)
111135

112-
(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
113-
cache.internalnorm(cache.fu_new) < cache.abstol) &&
136+
(termination_condition(cache.fu_new .- cache.fu1,
137+
cache.u,
138+
u_prev,
139+
cache.abstol,
140+
cache.reltol) ||
141+
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
114142
(cache.force_stop = true)
143+
144+
@. u_prev = u
115145
cache.fu1 .= cache.fu_new
116146
cache.stats.nf += 1
117147
cache.stats.njacs += 1
@@ -121,7 +151,10 @@ function perform_step!(cache::GaussNewtonCache{true})
121151
end
122152

123153
function perform_step!(cache::GaussNewtonCache{false})
124-
@unpack u, fu1, f, p, alg, linsolve = cache
154+
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
155+
156+
tc_storage = cache.tc_storage
157+
termination_condition = cache.alg.termination_condition(tc_storage)
125158

126159
cache.J = jacobian!!(cache.J, cache)
127160

@@ -138,7 +171,10 @@ function perform_step!(cache::GaussNewtonCache{false})
138171
cache.u = @. u - cache.du # `u` might not support mutation
139172
cache.fu_new = f(cache.u, p)
140173

141-
(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
174+
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
175+
(cache.force_stop = true)
176+
177+
cache.u_prev = @. cache.u
142178
cache.fu1 = cache.fu_new
143179
cache.stats.nf += 1
144180
cache.stats.njacs += 1

src/levenberg.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
9999
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
100100
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
101101
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
102-
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
102+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
103103
abstol = nothing,
104104
reltol = nothing),
105105
adkwargs...)
@@ -209,7 +209,6 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
209209
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
210210
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
211211
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage)
212-
213212
end
214213

215214
function perform_step!(cache::LevenbergMarquardtCache{true})
@@ -271,11 +270,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
271270
if (1 - β)^b_uphill * loss loss_old
272271
# Accept step.
273272
cache.u .+= δ
274-
if termination_condition(cache.fu_tmp,
275-
cache.u,
276-
u_prev,
277-
cache.abstol,
278-
cache.reltol)
273+
if loss < cache.abstol
279274
cache.force_stop = true
280275
return nothing
281276
end

0 commit comments

Comments
 (0)