Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 8d952fe

Browse files
committed
Add tests for batched broyden
1 parent 93c8279 commit 8d952fe

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Manifest.toml
2+
3+
wip

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module SimpleBatchedNonlinearSolveExt
22

3-
using SimpleNonlinearSolve, SciMLBase
3+
using ArrayInterfaceCore, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
44
isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
55

66
_batch_transpose(x) = reshape(x, 1, size(x)...)
@@ -53,15 +53,15 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
5353
xₙ = x
5454
xₙ₋₁ = x
5555
fₙ₋₁ = fₙ
56-
for _ in 1:maxiters
57-
xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch)
56+
for i in 1:maxiters
57+
xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁)
5858
fₙ = f(xₙ)
5959
Δxₙ = xₙ .- xₙ₋₁
6060
Δfₙ = fₙ .- fₙ₋₁
61-
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch)
62-
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./
63-
(_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))),
64-
_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch)
61+
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)
62+
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./
63+
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
64+
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
6565

6666
iszero(fₙ) &&
6767
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
@@ -78,5 +78,4 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
7878
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
7979
end
8080

81-
8281
end

src/broyden.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
44
A low-overhead implementation of Broyden. This method is non-allocating on scalar
55
and static array problems.
6+
7+
!!! note
8+
9+
To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
10+
`import NNlib` must be present in your code.
611
"""
712
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
813
Broyden(; batched = false) = new{batched}()

test/basictests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,14 @@ for options in list_of_options
370370
sol = solve(probN, alg)
371371
@test all(abs.(f(u, p)) .< 1e-10)
372372
end
373+
374+
# Batched Broyden
375+
using NNlib
376+
377+
f, u0 = (u, p) -> u .* u .- p, randn(1, 3)
378+
379+
p = [2.0 1.0 5.0]; probN = NonlinearProblem{false}(f, u0, p)
380+
381+
sol = solve(probN, Broyden(batched = true))
382+
383+
@test abs.(sol.u) sqrt.(p)

0 commit comments

Comments
 (0)