Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 4c9f922

Browse files
Merge pull request #45 from avik-pal/ap/termination_broyden
Add Termination Conditions to Broyden
2 parents f250799 + 5a3db2b commit 4c9f922

File tree

4 files changed

+105
-55
lines changed

4 files changed

+105
-55
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ SimpleBatchedNonlinearSolveExt = "NNlib"
2323

2424
[compat]
2525
ArrayInterface = "6, 7"
26-
DiffEqBase = "6.114"
2726
FiniteDiff = "2"
2827
ForwardDiff = "0.10.3"
2928
NNlib = "0.8"

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleBatchedNonlinearSolveExt
22

3-
using ArrayInterface, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
3+
using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
4+
45
isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
56

67
_batch_transpose(x) = reshape(x, 1, size(x)...)
@@ -31,6 +32,8 @@ end
3132

3233
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
3334
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
35+
tc = alg.termination_condition
36+
mode = DiffEqBase.get_termination_mode(tc)
3437
f = Base.Fix2(prob.f, prob.p)
3538
x = float(prob.u0)
3639

@@ -47,8 +50,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
4750
end
4851

4952
atol = abstol !== nothing ? abstol :
50-
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
51-
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
53+
(tc.abstol !== nothing ? tc.abstol :
54+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
55+
rtol = reltol !== nothing ? reltol :
56+
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))
57+
58+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
59+
error("Broyden currently doesn't support SAFE_BEST termination modes")
60+
end
61+
62+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
63+
termination_condition = tc(storage)
5264

5365
xₙ = x
5466
xₙ₋₁ = x
@@ -63,14 +75,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
6375
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
6476
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
6577

66-
iszero(fₙ) &&
67-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
68-
retcode = ReturnCode.Success)
69-
70-
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
71-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
72-
retcode = ReturnCode.Success)
78+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
79+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
7380
end
81+
7482
xₙ₋₁ = xₙ
7583
fₙ₋₁ = fₙ
7684
end

src/broyden.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""
2-
Broyden(; batched = false)
2+
Broyden(; batched = false,
3+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
4+
abstol = nothing, reltol = nothing))
35
46
A low-overhead implementation of Broyden. This method is non-allocating on scalar
57
and static array problems.
@@ -9,12 +11,22 @@ and static array problems.
911
To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
1012
`import NNlib` must be present in your code.
1113
"""
12-
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
13-
Broyden(; batched = false) = new{batched}()
14+
struct Broyden{batched, TC <: NLSolveTerminationCondition} <:
15+
AbstractSimpleNonlinearSolveAlgorithm
16+
termination_condition::TC
17+
18+
function Broyden(; batched = false,
19+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
20+
abstol = nothing,
21+
reltol = nothing))
22+
return new{batched, typeof(termination_condition)}(termination_condition)
23+
end
1424
end
1525

1626
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
1727
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
28+
tc = alg.termination_condition
29+
mode = DiffEqBase.get_termination_mode(tc)
1830
f = Base.Fix2(prob.f, prob.p)
1931
x = float(prob.u0)
2032

@@ -27,8 +39,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
2739
end
2840

2941
atol = abstol !== nothing ? abstol :
30-
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
31-
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
42+
(tc.abstol !== nothing ? tc.abstol :
43+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
44+
rtol = reltol !== nothing ? reltol :
45+
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))
46+
47+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
48+
error("Broyden currently doesn't support SAFE_BEST termination modes")
49+
end
50+
51+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
52+
termination_condition = tc(storage)
3253

3354
xₙ = x
3455
xₙ₋₁ = x
@@ -41,14 +62,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
4162
J⁻¹Δfₙ = J⁻¹ * Δfₙ
4263
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
4364

44-
iszero(fₙ) &&
45-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
46-
retcode = ReturnCode.Success)
47-
48-
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
49-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
50-
retcode = ReturnCode.Success)
65+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
66+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
5167
end
68+
5269
xₙ₋₁ = xₙ
5370
fₙ₋₁ = fₙ
5471
end

test/basictests.jl

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
11
using SimpleNonlinearSolve
22
using StaticArrays
33
using BenchmarkTools
4+
using DiffEqBase
45
using Test
56

7+
const BATCHED_BROYDEN_SOLVERS = Broyden[]
8+
const BROYDEN_SOLVERS = Broyden[]
9+
10+
for mode in instances(NLSolveTerminationMode.T)
11+
if mode
12+
(NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest,
13+
NLSolveTerminationMode.AbsSafeBest)
14+
continue
15+
end
16+
17+
termination_condition = NLSolveTerminationCondition(mode; abstol = nothing,
18+
reltol = nothing)
19+
push!(BROYDEN_SOLVERS, Broyden(; batched = false, termination_condition))
20+
push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition))
21+
end
22+
623
# SimpleNewtonRaphson
724
function benchmark_scalar(f, u0)
825
probN = NonlinearProblem{false}(f, u0)
@@ -50,16 +67,19 @@ if VERSION >= v"1.7"
5067
end
5168

5269
# Broyden
53-
function benchmark_scalar(f, u0)
70+
function benchmark_scalar(f, u0, alg)
5471
probN = NonlinearProblem{false}(f, u0)
55-
sol = (solve(probN, Broyden()))
72+
sol = (solve(probN, alg))
5673
end
5774

58-
sol = benchmark_scalar(sf, csu0)
59-
@test sol.retcode === ReturnCode.Success
60-
@test sol.u * sol.u - 2 < 1e-9
61-
if VERSION >= v"1.7"
62-
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
75+
for alg in BROYDEN_SOLVERS
76+
sol = benchmark_scalar(sf, csu0, alg)
77+
@test sol.retcode === ReturnCode.Success
78+
@test sol.u * sol.u - 2 < 1e-9
79+
# FIXME: Termination Condition Implementation is allocating. Not sure how to fix it.
80+
# if VERSION >= v"1.7"
81+
# @test (@ballocated benchmark_scalar($sf, $csu0, $termination_condition)) == 0
82+
# end
6383
end
6484

6585
# Klement
@@ -101,8 +121,8 @@ using ForwardDiff
101121
# Immutable
102122
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
103123

104-
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
105-
SimpleDFSane())
124+
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
125+
SimpleDFSane(), BROYDEN_SOLVERS...)
106126
g = function (p)
107127
probN = NonlinearProblem{false}(f, csu0, p)
108128
sol = solve(probN, alg, abstol = 1e-9)
@@ -117,8 +137,8 @@ end
117137

118138
# Scalar
119139
f, u0 = (u, p) -> u * u - p, 1.0
120-
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
121-
SimpleDFSane(), Halley())
140+
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
141+
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
122142
g = function (p)
123143
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
124144
sol = solve(probN, alg)
@@ -183,8 +203,8 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()]
183203
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
184204
end
185205

186-
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
187-
SimpleDFSane(), Halley())
206+
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
207+
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
188208
global g, p
189209
g = function (p)
190210
probN = NonlinearProblem{false}(f, 0.5, p)
@@ -199,14 +219,15 @@ end
199219
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
200220
probN = NonlinearProblem(f, u0)
201221

202-
@test solve(probN, SimpleNewtonRaphson()).u[end] sqrt(2.0)
203-
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
204-
@test solve(probN, SimpleTrustRegion()).u[end] sqrt(2.0)
205-
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] sqrt(2.0)
206-
@test solve(probN, Broyden()).u[end] sqrt(2.0)
207-
@test solve(probN, LBroyden()).u[end] sqrt(2.0)
208-
@test solve(probN, Klement()).u[end] sqrt(2.0)
209-
@test solve(probN, SimpleDFSane()).u[end] sqrt(2.0)
222+
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
223+
SimpleTrustRegion(),
224+
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(),
225+
BROYDEN_SOLVERS...)
226+
sol = solve(probN, alg)
227+
228+
@test sol.retcode == ReturnCode.Success
229+
@test sol.u[end] sqrt(2.0)
230+
end
210231

211232
# Separate Error check for Halley; will be included in above error checks for the improved Halley
212233
f, u0 = (u, p) -> u * u - 2.0, 1.0
@@ -220,18 +241,16 @@ for u0 in [1.0, [1, 1.0]]
220241
probN = NonlinearProblem(f, u0)
221242
sol = sqrt(2) * u0
222243

223-
@test solve(probN, SimpleNewtonRaphson()).u sol
224-
@test solve(probN, SimpleNewtonRaphson()).u sol
225-
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u sol
226-
227-
@test solve(probN, SimpleTrustRegion()).u sol
228-
@test solve(probN, SimpleTrustRegion()).u sol
229-
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u sol
244+
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
245+
SimpleTrustRegion(),
246+
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(),
247+
SimpleDFSane(),
248+
BROYDEN_SOLVERS...)
249+
sol2 = solve(probN, alg)
230250

231-
@test solve(probN, Broyden()).u sol
232-
@test solve(probN, LBroyden()).u sol
233-
@test solve(probN, Klement()).u sol
234-
@test solve(probN, SimpleDFSane()).u sol
251+
@test sol2.retcode == ReturnCode.Success
252+
@test sol2.u sol
253+
end
235254
end
236255

237256
# Bisection Tests
@@ -411,3 +430,10 @@ probN = NonlinearProblem{false}(f, u0, p);
411430
sol = solve(probN, Broyden(batched = true))
412431

413432
@test abs.(sol.u) sqrt.(p)
433+
434+
for alg in BATCHED_BROYDEN_SOLVERS
435+
sol = solve(probN, alg)
436+
437+
@test sol.retcode == ReturnCode.Success
438+
@test abs.(sol.u) sqrt.(p)
439+
end

0 commit comments

Comments
 (0)