@@ -36,26 +36,34 @@ for large-scale and numerically-difficult nonlinear least squares problems.
36
36
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
37
37
construction. This will be fixed in the near future.
38
38
"""
39
- @concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
39
+ @concrete struct GaussNewton{CJ, AD, TC } <: AbstractNewtonAlgorithm{CJ, AD, TC }
40
40
ad:: AD
41
41
linsolve
42
42
precs
43
+ termination_condition:: TC
43
44
end
44
45
45
46
function set_ad (alg:: GaussNewton{CJ} , ad) where {CJ}
46
47
return GaussNewton {CJ} (ad, alg. linsolve, alg. precs)
47
48
end
48
49
49
- function GaussNewton (; concrete_jac = nothing , linsolve = CholeskyFactorization (),
50
- precs = DEFAULT_PRECS, adkwargs... )
50
+ function GaussNewton (; concrete_jac = nothing , linsolve = NormalCholeskyFactorization (),
51
+ precs = DEFAULT_PRECS,
52
+ termination_condition = NLSolveTerminationCondition (NLSolveTerminationMode. AbsNorm;
53
+ abstol = nothing ,
54
+ reltol = nothing ), adkwargs... )
51
55
ad = default_adargs_to_adtype (; adkwargs... )
52
- return GaussNewton {_unwrap_val(concrete_jac)} (ad, linsolve, precs)
56
+ return GaussNewton {_unwrap_val(concrete_jac)} (ad,
57
+ linsolve,
58
+ precs,
59
+ termination_condition)
53
60
end
54
61
55
62
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
56
63
f
57
64
alg
58
65
u
66
+ u_prev
59
67
fu1
60
68
fu2
61
69
fu_new
72
80
internalnorm
73
81
retcode:: ReturnCode.T
74
82
abstol
83
+ reltol
75
84
prob
76
85
stats:: NLStats
86
+ tc_storage
77
87
end
78
88
79
89
function SciMLBase. __init (prob:: NonlinearLeastSquaresProblem{uType, iip} , alg_:: GaussNewton ,
80
- args... ; alias_u0 = false , maxiters = 1000 , abstol = 1e-6 , internalnorm = DEFAULT_NORM,
90
+ args... ; alias_u0 = false , maxiters = 1000 , abstol = nothing , reltol = nothing ,
91
+ internalnorm = DEFAULT_NORM,
81
92
kwargs... ) where {uType, iip}
82
93
alg = get_concrete_algorithm (alg_, prob)
83
94
@unpack f, u0, p = prob
@@ -91,27 +102,46 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
91
102
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches (alg, f, u, p, Val (iip);
92
103
linsolve_with_JᵀJ = Val (true ))
93
104
94
- return GaussNewtonCache {iip} (f, alg, u, fu1, fu2, zero (fu1), du, p, uf, linsolve, J,
95
- JᵀJ, Jᵀf, jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol,
96
- prob, NLStats (1 , 0 , 0 , 0 , 0 ))
105
+ tc = alg. termination_condition
106
+ mode = DiffEqBase. get_termination_mode (tc)
107
+
108
+ atol = _get_tolerance (abstol, tc. abstol, eltype (u))
109
+ rtol = _get_tolerance (reltol, tc. reltol, eltype (u))
110
+
111
+ storage = mode ∈ DiffEqBase. SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult () :
112
+ nothing
113
+
114
+ return GaussNewtonCache {iip} (f, alg, u, copy (u), fu1, fu2, zero (fu1), du, p, uf,
115
+ linsolve, J,
116
+ JᵀJ, Jᵀf, jac_cache, false , maxiters, internalnorm, ReturnCode. Default, atol, rtol,
117
+ prob, NLStats (1 , 0 , 0 , 0 , 0 ), storage)
97
118
end
98
119
99
120
function perform_step! (cache:: GaussNewtonCache{true} )
100
- @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
121
+ @unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
101
122
jacobian!! (J, cache)
102
123
__matmul! (JᵀJ, J' , J)
103
124
__matmul! (Jᵀf, J' , fu1)
104
125
126
+ tc_storage = cache. tc_storage
127
+ termination_condition = cache. alg. termination_condition (tc_storage)
128
+
105
129
# u = u - J \ fu
106
130
linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (JᵀJ), b = _vec (Jᵀf),
107
131
linu = _vec (du), p, reltol = cache. abstol)
108
132
cache. linsolve = linres. cache
109
133
@. u = u - du
110
134
f (cache. fu_new, u, p)
111
135
112
- (cache. internalnorm (cache. fu_new .- cache. fu1) < cache. abstol ||
113
- cache. internalnorm (cache. fu_new) < cache. abstol) &&
136
+ (termination_condition (cache. fu_new .- cache. fu1,
137
+ cache. u,
138
+ u_prev,
139
+ cache. abstol,
140
+ cache. reltol) ||
141
+ termination_condition (cache. fu_new, cache. u, u_prev, cache. abstol, cache. reltol)) &&
114
142
(cache. force_stop = true )
143
+
144
+ @. u_prev = u
115
145
cache. fu1 .= cache. fu_new
116
146
cache. stats. nf += 1
117
147
cache. stats. njacs += 1
@@ -121,7 +151,10 @@ function perform_step!(cache::GaussNewtonCache{true})
121
151
end
122
152
123
153
function perform_step! (cache:: GaussNewtonCache{false} )
124
- @unpack u, fu1, f, p, alg, linsolve = cache
154
+ @unpack u, u_prev, fu1, f, p, alg, linsolve = cache
155
+
156
+ tc_storage = cache. tc_storage
157
+ termination_condition = cache. alg. termination_condition (tc_storage)
125
158
126
159
cache. J = jacobian!! (cache. J, cache)
127
160
@@ -138,7 +171,10 @@ function perform_step!(cache::GaussNewtonCache{false})
138
171
cache. u = @. u - cache. du # `u` might not support mutation
139
172
cache. fu_new = f (cache. u, p)
140
173
141
- (cache. internalnorm (cache. fu_new) < cache. abstol) && (cache. force_stop = true )
174
+ termination_condition (cache. fu_new, cache. u, u_prev, cache. abstol, cache. reltol) &&
175
+ (cache. force_stop = true )
176
+
177
+ cache. u_prev = @. cache. u
142
178
cache. fu1 = cache. fu_new
143
179
cache. stats. nf += 1
144
180
cache. stats. njacs += 1
0 commit comments