Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 9697e06

Browse files
committed
Make it an extension
1 parent 932f4e9 commit 9697e06

File tree

8 files changed

+107
-45
lines changed

8 files changed

+107
-45
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12-
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1312
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1413
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1514
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
1615
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1716

17+
[weakdeps]
18+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
19+
20+
[extensions]
21+
SimpleBatchedNonlinearSolveExt = "NNlib"
22+
1823
[compat]
1924
ArrayInterfaceCore = "0.1.1"
2025
DiffEqBase = "6.114"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Package extension of SimpleNonlinearSolve; `Project.toml` lists NNlib as the package
# that triggers it. The batched helpers and solver methods live in the included files.
module SimpleBatchedNonlinearSolveExt

using SimpleNonlinearSolve, SciMLBase, NNlib
# `utils.jl` calls `ArrayInterfaceCore.zeromatrix`, `diagind` and `I`. These do not appear
# to be exported by any of the packages loaded above, so bring them into scope explicitly
# (both are already dependencies of the parent package).
using LinearAlgebra
import ArrayInterfaceCore

include("utils.jl")
include("broyden.jl")
include("lbroyden.jl")

end
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Batched Broyden solve, selected by dispatch on `Broyden{true}`.
#
# `prob.u0` must be a matrix; the helper `_init_J_batched` builds one inverse-Jacobian
# approximation per column of `u0`, and all products go through `_batched_mul`.
# Only out-of-place problems are supported.
#
# Keywords:
#   - `abstol`, `reltol`: tolerances for the `isapprox` test between successive iterates;
#     when `nothing`, defaults derived from `eps` of the element type are used.
#   - `maxiters`: maximum number of iterations before returning `ReturnCode.MaxIters`.
#
# Returns the result of `SciMLBase.build_solution` with `ReturnCode.Success` when the
# residual is exactly zero or successive iterates agree within tolerance, and with
# `ReturnCode.MaxIters` otherwise. Note that both checks look at the whole array, not at
# each batch column separately.
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
    abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
    f = Base.Fix2(prob.f, prob.p)
    x = float(prob.u0)

    if ndims(x) != 2
        error("`batch` mode works only if `ndims(prob.u0) == 2`")
    end

    # Reject in-place problems before `f` is ever called with the out-of-place signature.
    if SciMLBase.isinplace(prob)
        error("Broyden currently only supports out-of-place nonlinear problems")
    end

    fₙ = f(x)
    T = eltype(x)
    J⁻¹ = _init_J_batched(x)

    atol = abstol !== nothing ? abstol :
           real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
    rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)

    xₙ = x
    xₙ₋₁ = x
    fₙ₋₁ = fₙ
    for _ in 1:maxiters
        # Quasi-Newton step using the current inverse-Jacobian approximation.
        xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁)
        fₙ = f(xₙ)
        Δxₙ = xₙ .- xₙ₋₁
        Δfₙ = fₙ .- fₙ₋₁
        J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)

        # Rank-one update of J⁻¹, written with the batched helpers:
        #   J⁻¹ += ((Δx - J⁻¹Δf) / (Δxᵀ J⁻¹Δf)) * (Δxᵀ J⁻¹)
        Δxₙ_row = _batch_transpose(Δxₙ)
        J⁻¹ += _batched_mul((Δxₙ .- J⁻¹Δfₙ) ./ _batched_mul(Δxₙ_row, J⁻¹Δfₙ),
            _batched_mul(Δxₙ_row, J⁻¹))

        iszero(fₙ) &&
            return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
                retcode = ReturnCode.Success)

        if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
            return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
                retcode = ReturnCode.Success)
        end
        xₙ₋₁ = xₙ
        fₙ₋₁ = fₙ
    end

    return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
end

ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl

Whitespace-only changes.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Prepend a singleton dimension: an array of size `dims` is reshaped to `(1, dims...)`.
function _batch_transpose(x)
    return reshape(x, (1, size(x)...))
end
2+
3+
# Generic fallback: when no more specific method applies, use the ordinary product.
function _batched_mul(x, y)
    return x * y
end
4+
5+
# 3-D array times matrix: `y` is reshaped to `(size(y, 1), 1, size(y, 2))`, multiplied
# with `batched_mul`, and the singleton second dimension of the product is dropped.
function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T}
    nrows, nbatch = size(y, 1), size(y, 2)
    y3 = reshape(y, nrows, 1, nbatch)
    prod3 = batched_mul(x, y3)
    return dropdims(prod3; dims = 2)
end
8+
9+
# Matrix times 3-D array: `x` is reshaped to `(size(x, 1), 1, size(x, 2))` and then
# multiplied with `y` via `batched_mul`; the 3-D result is returned as is.
function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T}
    nrows, nbatch = size(x, 1), size(x, 2)
    x3 = reshape(x, nrows, 1, nbatch)
    return batched_mul(x3, y)
end
12+
13+
# Both operands are already 3-D: forward directly to `batched_mul`.
function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2}
    prod3 = batched_mul(x, y)
    return prod3
end
16+
17+
# Build the initial batched matrix used by the batched Broyden solver.
#
# A single matrix is created from the first column of `x` with
# `ArrayInterfaceCore.zeromatrix` (presumably a square zero matrix sized by that column —
# confirm against ArrayInterfaceCore), turned into an identity, and then repeated along a
# third dimension `size(x, 2)` times.
function _init_J_batched(x::AbstractMatrix{T}) where {T}
    nbatch = size(x, 2)
    J0 = ArrayInterfaceCore.zeromatrix(x[:, 1])
    if !ismutable(x)
        # Immutable input: build a new matrix instead of writing in place.
        J0 += I
    else
        J0[diagind(J0)] .= one(T)
    end
    return repeat(J0, 1, 1, nbatch)
end

src/SimpleNonlinearSolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using StaticArraysCore
77
using LinearAlgebra
88
import ArrayInterfaceCore
99
using DiffEqBase
10-
using NNlib # Batched Matrix Multiplication
1110

1211
@reexport using SciMLBase
1312

src/broyden.jl

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
"""
2-
```julia
3-
Broyden()
4-
```
2+
Broyden()
53
64
A low-overhead implementation of Broyden. This method is non-allocating on scalar
75
and static array problems.
86
"""
9-
struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end
7+
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
8+
Broyden(batched = false) = new{batched}()
9+
end
1010

11-
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing,
12-
reltol = nothing, maxiters = 1000, batch = false, kwargs...)
11+
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
12+
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
1313
f = Base.Fix2(prob.f, prob.p)
1414
x = float(prob.u0)
1515

16-
if batch && ndims(x) != 2
17-
error("`batch` mode works only if `ndims(prob.u0) == 2`")
18-
end
16+
# if batch && ndims(x) != 2
17+
# error("`batch` mode works only if `ndims(prob.u0) == 2`")
18+
# end
1919

2020
fₙ = f(x)
2121
T = eltype(x)
22-
J⁻¹ = init_J(x; batch)
22+
# J⁻¹ = init_J(x; batch)
23+
J⁻¹ = init_J(x)
2324

2425
if SciMLBase.isinplace(prob)
2526
error("Broyden currently only supports out-of-place nonlinear problems")
@@ -33,14 +34,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol
3334
xₙ₋₁ = x
3435
fₙ₋₁ = fₙ
3536
for _ in 1:maxiters
36-
xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch)
37+
xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁
3738
fₙ = f(xₙ)
3839
Δxₙ = xₙ .- xₙ₋₁
3940
Δfₙ = fₙ .- fₙ₋₁
40-
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch)
41-
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./
42-
(_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))),
43-
_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch)
41+
J⁻¹Δfₙ = J⁻¹ * Δfₙ
42+
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
4443

4544
iszero(fₙ) &&
4645
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
@@ -56,8 +55,3 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol
5655

5756
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
5857
end
59-
60-
function _batch_transpose(x, batch)
61-
!batch && return x'
62-
return reshape(x, 1, size(x)...)
63-
end

src/utils.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,14 @@ value(x) = x
3434
value(x::Dual) = ForwardDiff.value(x)
3535
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
3636

37-
function init_J(x; batch = false)
38-
x_ = batch ? x[:, 1] : x
39-
40-
J = ArrayInterfaceCore.zeromatrix(x_)
41-
if ismutable(x_)
42-
J[diagind(J)] .= one(eltype(x_))
37+
function init_J(x)
38+
J = ArrayInterfaceCore.zeromatrix(x)
39+
if ismutable(x)
40+
J[diagind(J)] .= one(eltype(x))
4341
else
4442
J += I
4543
end
46-
47-
return batch ? repeat(J, 1, 1, size(x, 2)) : J
44+
return J
4845
end
4946

5047
function dogleg_method(H, g, Δ)
@@ -71,18 +68,3 @@ function dogleg_method(H, g, Δ)
7168
tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd
7269
return δsd + tau * δN_δsd
7370
end
74-
75-
_batched_mul(x, y, batch) = x * y
76-
function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix, batch) where {T}
77-
!batch && return x * y
78-
return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2)
79-
end
80-
function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}, batch) where {T}
81-
!batch && return x * y
82-
return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y)
83-
end
84-
function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3},
85-
batch) where {T1, T2}
86-
!batch && return x * y
87-
return batched_mul(x, y)
88-
end

0 commit comments

Comments
 (0)