Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c86a984

Browse files
committed
Automatic choice for maximum trust region radius
1 parent a161199 commit c86a984

File tree

4 files changed

+60
-48
lines changed

4 files changed

+60
-48
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,19 @@ include("utils.jl")
1919
include("bisection.jl")
2020
include("falsi.jl")
2121
include("raphson.jl")
22-
include("ad.jl")
2322
include("broyden.jl")
2423
include("klement.jl")
2524
include("trustRegion.jl")
25+
include("ad.jl")
2626

2727
import SnoopPrecompile
2828

2929
SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
3030
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
31-
for alg in (SimpleNewtonRaphson, Broyden, Klement)
31+
for alg in (SimpleNewtonRaphson, Broyden, Klement, SimpleTrustRegion)
3232
solve(prob_no_brack, alg(), abstol = T(1e-2))
3333
end
3434

35-
for alg in (SimpleTrustRegion(10.0),)
36-
solve(prob_no_brack, alg, abstol = T(1e-2))
37-
end
38-
3935
#=
4036
for alg in (SimpleNewtonRaphson,)
4137
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])

src/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ end
3030

3131
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
3232
iip,
33-
<:Dual{T, V, P}}, alg::SimpleNewtonRaphson,
33+
<:Dual{T, V, P}},
34+
alg::Union{SimpleNewtonRaphson, SimpleTrustRegion},
3435
args...; kwargs...) where {iip, T, V, P}
3536
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3637
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
@@ -39,7 +40,8 @@ end
3940
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
4041
iip,
4142
<:AbstractArray{<:Dual{T, V, P}}},
42-
alg::SimpleNewtonRaphson, args...; kwargs...) where {iip, T, V, P}
43+
alg::Union{SimpleNewtonRaphson, SimpleTrustRegion}, args...;
44+
kwargs...) where {iip, T, V, P}
4345
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
4446
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
4547
retcode = sol.retcode)

src/trustRegion.jl

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
"""
22
```julia
3-
SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
4-
autodiff = Val{true}(), diff_type = Val{:forward})
3+
SimpleTrustRegion(; chunk_size = Val{0}(),
4+
autodiff = Val{true}(),
5+
diff_type = Val{:forward},
6+
max_trust_radius::Real = 0.0,
7+
initial_trust_radius::Real = 0.0,
8+
step_threshold::Real = 0.1,
9+
shrink_threshold::Real = 0.25,
10+
expand_threshold::Real = 0.75,
11+
shrink_factor::Real = 0.25,
12+
expand_factor::Real = 2.0,
13+
max_shrink_times::Int = 32)
514
```
615
716
A low-overhead implementation of a
817
[trust-region](https://optimization.mccormick.northwestern.edu/index.php/Trust-region_methods)
918
solver
1019
11-
### Arguments
12-
- `max_trust_radius`: the maximum radius of the trust region. The step size in the algorithm
13-
will change dynamically. However, it will never be greater than the `max_trust_radius`.
14-
1520
### Keyword Arguments
1621
1722
- `chunk_size`: the chunk size used by the internal ForwardDiff.jl automatic differentiation
@@ -26,6 +31,8 @@ solver
2631
- `diff_type`: the type of finite differencing used if `autodiff = false`. Defaults to
2732
`Val{:forward}` for forward finite differences. For more details on the choices, see the
2833
[FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) documentation.
34+
- `max_trust_radius`: the maximum radius of the trust region. Defaults to
35+
`max(norm(f(u0)), maximum(u0) - minimum(u0))`.
2936
- `initial_trust_radius`: the initial trust region radius. Defaults to
3037
`max_trust_radius / 11`.
3138
- `step_threshold`: the threshold for taking a step. In every iteration, the threshold is
@@ -58,25 +65,28 @@ struct SimpleTrustRegion{T, CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
5865
shrink_factor::T
5966
expand_factor::T
6067
max_shrink_times::Int
61-
function SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
68+
function SimpleTrustRegion(; chunk_size = Val{0}(),
6269
autodiff = Val{true}(),
6370
diff_type = Val{:forward},
64-
initial_trust_radius::Number = max_trust_radius / 11,
65-
step_threshold::Number = 0.1,
66-
shrink_threshold::Number = 0.25,
67-
expand_threshold::Number = 0.75,
68-
shrink_factor::Number = 0.25,
69-
expand_factor::Number = 2.0,
71+
max_trust_radius::Real = 0.0,
72+
initial_trust_radius::Real = 0.0,
73+
step_threshold::Real = 0.1,
74+
shrink_threshold::Real = 0.25,
75+
expand_threshold::Real = 0.75,
76+
shrink_factor::Real = 0.25,
77+
expand_factor::Real = 2.0,
7078
max_shrink_times::Int = 32)
71-
new{typeof(initial_trust_radius), SciMLBase._unwrap_val(chunk_size),
72-
SciMLBase._unwrap_val(autodiff), SciMLBase._unwrap_val(diff_type)}(max_trust_radius,
73-
initial_trust_radius,
74-
step_threshold,
75-
shrink_threshold,
76-
expand_threshold,
77-
shrink_factor,
78-
expand_factor,
79-
max_shrink_times)
79+
new{typeof(initial_trust_radius),
80+
SciMLBase._unwrap_val(chunk_size),
81+
SciMLBase._unwrap_val(autodiff),
82+
SciMLBase._unwrap_val(diff_type)}(max_trust_radius,
83+
initial_trust_radius,
84+
step_threshold,
85+
shrink_threshold,
86+
expand_threshold,
87+
shrink_factor,
88+
expand_factor,
89+
max_shrink_times)
8090
end
8191
end
8292

@@ -114,6 +124,14 @@ function SciMLBase.__solve(prob::NonlinearProblem,
114124
∇f = FiniteDiff.finite_difference_derivative(f, x, diff_type(alg), eltype(x), F)
115125
end
116126

127+
# Set default trust region radius if not specified by user.
128+
if Δₘₐₓ == 0.0
129+
Δₘₐₓ = max(norm(F), maximum(x) - minimum(x))
130+
end
131+
if Δ == 0.0
132+
Δ = Δₘₐₓ / 11
133+
end
134+
117135
fₖ = 0.5 * norm(F)^2
118136
H = ∇f * ∇f
119137
g = ∇f * F

test/basictests.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end
5555
# SimpleTrustRegion
5656
function benchmark_scalar(f, u0)
5757
probN = NonlinearProblem{false}(f, u0)
58-
sol = (solve(probN, SimpleTrustRegion(10.0)))
58+
sol = (solve(probN, SimpleTrustRegion()))
5959
end
6060

6161
sol = benchmark_scalar(sf, csu0)
@@ -68,8 +68,7 @@ using ForwardDiff
6868
# Immutable
6969
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
7070

71-
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
72-
SimpleTrustRegion(10.0)]
71+
for alg in (SimpleNewtonRaphson(), Broyden(), Klement(), SimpleTrustRegion())
7372
g = function (p)
7473
probN = NonlinearProblem{false}(f, csu0, p)
7574
sol = solve(probN, alg, abstol = 1e-9)
@@ -84,8 +83,7 @@ end
8483

8584
# Scalar
8685
f, u0 = (u, p) -> u * u - p, 1.0
87-
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
88-
SimpleTrustRegion(10.0)]
86+
for alg in (SimpleNewtonRaphson(), Broyden(), Klement(), SimpleTrustRegion())
8987
g = function (p)
9088
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
9189
sol = solve(probN, alg)
@@ -126,8 +124,7 @@ for alg in [Bisection(), Falsi()]
126124
@test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p)
127125
end
128126

129-
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
130-
SimpleTrustRegion(10.0)]
127+
for alg in (SimpleNewtonRaphson(), Broyden(), Klement(), SimpleTrustRegion())
131128
global g, p
132129
g = function (p)
133130
probN = NonlinearProblem{false}(f, 0.5, p)
@@ -144,8 +141,8 @@ probN = NonlinearProblem(f, u0)
144141

145142
@test solve(probN, SimpleNewtonRaphson()).u[end] ≈ sqrt(2.0)
146143
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0)
147-
@test solve(probN, SimpleTrustRegion(10.0)).u[end] ≈ sqrt(2.0)
148-
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u[end] ≈ sqrt(2.0)
144+
@test solve(probN, SimpleTrustRegion()).u[end] ≈ sqrt(2.0)
145+
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] ≈ sqrt(2.0)
149146
@test solve(probN, Broyden()).u[end] ≈ sqrt(2.0)
150147
@test solve(probN, Klement()).u[end] ≈ sqrt(2.0)
151148

@@ -159,9 +156,9 @@ for u0 in [1.0, [1, 1.0]]
159156
@test solve(probN, SimpleNewtonRaphson()).u ≈ sol
160157
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u ≈ sol
161158

162-
@test solve(probN, SimpleTrustRegion(10.0)).u ≈ sol
163-
@test solve(probN, SimpleTrustRegion(10.0)).u ≈ sol
164-
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u ≈ sol
159+
@test solve(probN, SimpleTrustRegion()).u ≈ sol
160+
@test solve(probN, SimpleTrustRegion()).u ≈ sol
161+
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u ≈ sol
165162

166163
@test solve(probN, Broyden()).u ≈ sol
167164

@@ -215,17 +212,16 @@ f = (u, p) -> 0.010000000000000002 .+
215212
(0.21640425613334457 .+
216213
216.40425613334457 ./
217214
(1 .+ 0.0006250000000000001(u .^ 2.0))) .^ 2.0)) .^ 2.0) .-
218-
0.0011552453009332421u
219-
.-p
215+
0.0011552453009332421u .- p
220216
g = function (p)
221217
probN = NonlinearProblem{false}(f, u0, p)
222-
sol = solve(probN, SimpleTrustRegion(100.0))
218+
sol = solve(probN, SimpleTrustRegion())
223219
return sol.u
224220
end
225221
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
226222
u = g(p)
227223
f(u, p)
228-
@test all(f(u, p) .< 1e-10)
224+
@test all(abs.(f(u, p)) .< 1e-10)
229225

230226
# Test kwargs in `SimpleTrustRegion`
231227
max_trust_radius = [10.0, 100.0, 1000.0]
@@ -242,7 +238,7 @@ list_of_options = zip(max_trust_radius, initial_trust_radius, step_threshold,
242238
expand_factor, max_shrink_times)
243239
for options in list_of_options
244240
local probN, sol, alg
245-
alg = SimpleTrustRegion(options[1];
241+
alg = SimpleTrustRegion(max_trust_radius = options[1],
246242
initial_trust_radius = options[2],
247243
step_threshold = options[3],
248244
shrink_threshold = options[4],
@@ -253,5 +249,5 @@ for options in list_of_options
253249

254250
probN = NonlinearProblem(f, u0, p)
255251
sol = solve(probN, alg)
256-
@test all(f(u, p) .< 1e-10)
252+
@test all(abs.(f(u, p)) .< 1e-10)
257253
end

0 commit comments

Comments
 (0)