Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Batched Broyden #43

Merged
merged 4 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.11"
version = "0.1.12"

[deps]
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh definitely not in this library 😅 Maybe that needs to be an extension.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😓 I guessed so. Let me set it up as a Pkg Extension then

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Pkg Extension, I think we will have to set it up as a dispatch on algorithm and not a keyword argument right? Something like Broyden{false} by default and Broyden{true} for the batched version

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes it probably needs to be split like that. What is the NNlib part for? Why is the matmul batched instead of just a single matmul?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The matmul is A x B x Batch * B x C x Batch, so they are independent matmuls along the batch dimension. Also IIRC the NNlib implementation does proper CUBLAS dispatches https://github.com/FluxML/NNlibCUDA.jl/blob/5f797aec23cbb5483788697e9e911685b681bbf8/src/batchedmul.jl#L3

Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
Expand All @@ -19,6 +20,7 @@ ArrayInterfaceCore = "0.1.1"
DiffEqBase = "6.114"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
NNlib = "0.8"
Reexport = "0.2, 1"
SciMLBase = "1.73"
SnoopPrecompile = "1"
Expand Down
1 change: 1 addition & 0 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using StaticArraysCore
using LinearAlgebra
import ArrayInterfaceCore
using DiffEqBase
using NNlib # Batched Matrix Multiplication

@reexport using SciMLBase

Expand Down
29 changes: 20 additions & 9 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ and static array problems.
"""
struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::Broyden, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing,
reltol = nothing, maxiters = 1000, batch = false, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)

if batch && ndims(x) != 2
error("`batch` mode works only if `ndims(prob.u0) == 2`")
end

fₙ = f(x)
T = eltype(x)
J⁻¹ = init_J(x)
J⁻¹ = init_J(x; batch)

if SciMLBase.isinplace(prob)
error("Broyden currently only supports out-of-place nonlinear problems")
Expand All @@ -30,11 +33,14 @@ function SciMLBase.__solve(prob::NonlinearProblem,
xₙ₋₁ = x
fₙ₋₁ = fₙ
for _ in 1:maxiters
xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁
xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch)
fₙ = f(xₙ)
Δxₙ = xₙ - xₙ₋₁
Δfₙ = fₙ - fₙ₋₁
J⁻¹ += ((Δxₙ - J⁻¹ * Δfₙ) ./ (Δxₙ' * J⁻¹ * Δfₙ)) * (Δxₙ' * J⁻¹)
Δxₙ = xₙ .- xₙ₋₁
Δfₙ = fₙ .- fₙ₋₁
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch)
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./
(_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))),
_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch)

iszero(fₙ) &&
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
Expand All @@ -50,3 +56,8 @@ function SciMLBase.__solve(prob::NonlinearProblem,

return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
end

function _batch_transpose(x, batch)
!batch && return x'
return reshape(x, 1, size(x)...)
end
7 changes: 3 additions & 4 deletions src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ Base.@kwdef struct LBroyden <: AbstractSimpleNonlinearSolveAlgorithm
threshold::Int = 27
end

@views function SciMLBase.__solve(prob::NonlinearProblem,
alg::LBroyden, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
batch = false, kwargs...)
threshold = min(maxiters, alg.threshold)
x = float(prob.u0)

Expand Down
28 changes: 23 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@ value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

function init_J(x)
J = ArrayInterfaceCore.zeromatrix(x)
if ismutable(x)
J[diagind(J)] .= one(eltype(x))
function init_J(x; batch = false)
x_ = batch ? x[:, 1] : x

J = ArrayInterfaceCore.zeromatrix(x_)
if ismutable(x_)
J[diagind(J)] .= one(eltype(x_))
else
J += I
end
return J

return batch ? repeat(J, 1, 1, size(x, 2)) : J
end

function dogleg_method(H, g, Δ)
Expand All @@ -68,3 +71,18 @@ function dogleg_method(H, g, Δ)
tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd
return δsd + tau * δN_δsd
end

_batched_mul(x, y, batch) = x * y
function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix, batch) where {T}
!batch && return x * y
return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2)
end
function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}, batch) where {T}
!batch && return x * y
return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y)
end
function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3},
batch) where {T1, T2}
!batch && return x * y
return batched_mul(x, y)
end