Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit a516a4a

Browse files
Merge pull request #43 from avik-pal/ap/batch
Batched Broyden
2 parents e600ee5 + 9f2b6b2 commit a516a4a

File tree

7 files changed

+134
-15
lines changed

7 files changed

+134
-15
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Manifest.toml
2+
3+
wip

Project.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.11"
4+
version = "0.1.12"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
@@ -10,15 +10,23 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1415
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
1516
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1617

18+
[weakdeps]
19+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
20+
21+
[extensions]
22+
SimpleBatchedNonlinearSolveExt = "NNlib"
23+
1724
[compat]
1825
ArrayInterfaceCore = "0.1.1"
1926
DiffEqBase = "6.114"
2027
FiniteDiff = "2"
2128
ForwardDiff = "0.10.3"
29+
NNlib = "0.8"
2230
Reexport = "0.2, 1"
2331
SciMLBase = "1.73"
2432
SnoopPrecompile = "1"
@@ -27,10 +35,11 @@ julia = "1.6"
2735

2836
[extras]
2937
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
38+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3039
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3140
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3241
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3342
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3443

3544
[targets]
36-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays"]
45+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module SimpleBatchedNonlinearSolveExt
2+
3+
using ArrayInterfaceCore, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
4+
isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
5+
6+
# Reshape `x` so that it gains a leading dimension of length 1 while keeping all
# of its existing dimensions (and their order) after it; e.g. an `(n, b)` matrix
# becomes a `(1, n, b)` array. No data is reordered.
function _batch_transpose(x)
    dims = (1, size(x)...)
    return reshape(x, dims)
end
7+
8+
# Generic fallback: when no more specific method applies, the product is the
# ordinary `*` of the two arguments.
function _batched_mul(x, y)
    return x * y
end
9+
10+
function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T}
11+
return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2)
12+
end
13+
14+
function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T}
15+
return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y)
16+
end
17+
18+
function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2}
19+
return batched_mul(x, y)
20+
end
21+
22+
function _init_J_batched(x::AbstractMatrix{T}) where {T}
23+
J = ArrayInterfaceCore.zeromatrix(x[:, 1])
24+
if ismutable(x)
25+
J[diagind(J)] .= one(eltype(x))
26+
else
27+
J += I
28+
end
29+
return repeat(J, 1, 1, size(x, 2))
30+
end
31+
32+
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
33+
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
34+
f = Base.Fix2(prob.f, prob.p)
35+
x = float(prob.u0)
36+
37+
if ndims(x) != 2
38+
error("`batch` mode works only if `ndims(prob.u0) == 2`")
39+
end
40+
41+
fₙ = f(x)
42+
T = eltype(x)
43+
J⁻¹ = _init_J_batched(x)
44+
45+
if SciMLBase.isinplace(prob)
46+
error("Broyden currently only supports out-of-place nonlinear problems")
47+
end
48+
49+
atol = abstol !== nothing ? abstol :
50+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
51+
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
52+
53+
xₙ = x
54+
xₙ₋₁ = x
55+
fₙ₋₁ = fₙ
56+
for i in 1:maxiters
57+
xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁)
58+
fₙ = f(xₙ)
59+
Δxₙ = xₙ .- xₙ₋₁
60+
Δfₙ = fₙ .- fₙ₋₁
61+
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)
62+
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./
63+
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
64+
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
65+
66+
iszero(fₙ) &&
67+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
68+
retcode = ReturnCode.Success)
69+
70+
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
71+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
72+
retcode = ReturnCode.Success)
73+
end
74+
xₙ₋₁ = xₙ
75+
fₙ₋₁ = fₙ
76+
end
77+
78+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
79+
end
80+
81+
end

src/SimpleNonlinearSolve.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ using DiffEqBase
1010

1111
@reexport using SciMLBase
1212

13+
if !isdefined(Base, :get_extension)
14+
using Requires
15+
end
16+
17+
function __init__()
18+
@static if !isdefined(Base, :get_extension)
19+
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin include("../ext/SimpleBatchedNonlinearSolveExt.jl") end
20+
end
21+
end
22+
1323
abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
1424
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
1525
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end

src/broyden.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
"""
2-
```julia
3-
Broyden()
4-
```
2+
Broyden(; batched = false)
53
64
A low-overhead implementation of Broyden. This method is non-allocating on scalar
75
and static array problems.
6+
7+
!!! note
8+
9+
To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
10+
`import NNlib` must be present in your code.
811
"""
9-
struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end
12+
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
13+
Broyden(; batched = false) = new{batched}()
14+
end
1015

11-
function SciMLBase.__solve(prob::NonlinearProblem,
12-
alg::Broyden, args...; abstol = nothing,
13-
reltol = nothing,
14-
maxiters = 1000, kwargs...)
16+
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
17+
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
1518
f = Base.Fix2(prob.f, prob.p)
1619
x = float(prob.u0)
20+
1721
fₙ = f(x)
1822
T = eltype(x)
1923
J⁻¹ = init_J(x)
@@ -34,7 +38,8 @@ function SciMLBase.__solve(prob::NonlinearProblem,
3438
fₙ = f(xₙ)
3539
Δxₙ = xₙ - xₙ₋₁
3640
Δfₙ = fₙ - fₙ₋₁
37-
J⁻¹ += ((Δxₙ - J⁻¹ * Δfₙ) ./ (Δxₙ' * J⁻¹ * Δfₙ)) * (Δxₙ' * J⁻¹)
41+
J⁻¹Δfₙ = J⁻¹ * Δfₙ
42+
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
3843

3944
iszero(fₙ) &&
4045
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;

src/lbroyden.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ Base.@kwdef struct LBroyden <: AbstractSimpleNonlinearSolveAlgorithm
88
threshold::Int = 27
99
end
1010

11-
@views function SciMLBase.__solve(prob::NonlinearProblem,
12-
alg::LBroyden, args...; abstol = nothing,
13-
reltol = nothing,
14-
maxiters = 1000, kwargs...)
11+
@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...;
12+
abstol = nothing, reltol = nothing, maxiters = 1000,
13+
batch = false, kwargs...)
1514
threshold = min(maxiters, alg.threshold)
1615
x = float(prob.u0)
1716

test/basictests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,15 @@ for options in list_of_options
370370
sol = solve(probN, alg)
371371
@test all(abs.(f(u, p)) .< 1e-10)
372372
end
373+
374+
# Batched Broyden
375+
using NNlib
376+
377+
f, u0 = (u, p) -> u .* u .- p, randn(1, 3)
378+
379+
p = [2.0 1.0 5.0];
380+
probN = NonlinearProblem{false}(f, u0, p);
381+
382+
sol = solve(probN, Broyden(batched = true))
383+
384+
@test abs.(sol.u) ≈ sqrt.(p)

0 commit comments

Comments
 (0)