From 932f4e94eeb00bda9daad067d53b2bf175ff80df Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Feb 2023 13:02:27 -0500 Subject: [PATCH 1/4] Make broyden batched --- Project.toml | 4 +++- src/SimpleNonlinearSolve.jl | 1 + src/broyden.jl | 29 ++++++++++++++++++++--------- src/lbroyden.jl | 7 +++---- src/utils.jl | 28 +++++++++++++++++++++++----- 5 files changed, 50 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 6288c3e..351bb62 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.11" +version = "0.1.12" [deps] ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" @@ -9,6 +9,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c" @@ -19,6 +20,7 @@ ArrayInterfaceCore = "0.1.1" DiffEqBase = "6.114" FiniteDiff = "2" ForwardDiff = "0.10.3" +NNlib = "0.8" Reexport = "0.2, 1" SciMLBase = "1.73" SnoopPrecompile = "1" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 5b170ce..b48e33e 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -7,6 +7,7 @@ using StaticArraysCore using LinearAlgebra import ArrayInterfaceCore using DiffEqBase +using NNlib # Batched Matrix Multiplication @reexport using SciMLBase diff --git a/src/broyden.jl b/src/broyden.jl index a05caa0..7539a2b 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -8,15 +8,18 @@ and static array problems. """ struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end -function SciMLBase.__solve(prob::NonlinearProblem, - alg::Broyden, args...; abstol = nothing, - reltol = nothing, - maxiters = 1000, kwargs...) +function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, + reltol = nothing, maxiters = 1000, batch = false, kwargs...) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) + + if batch && ndims(x) != 2 + error("`batch` mode works only if `ndims(prob.u0) == 2`") + end + fₙ = f(x) T = eltype(x) - J⁻¹ = init_J(x) + J⁻¹ = init_J(x; batch) if SciMLBase.isinplace(prob) error("Broyden currently only supports out-of-place nonlinear problems") @@ -30,11 +33,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, xₙ₋₁ = x fₙ₋₁ = fₙ for _ in 1:maxiters - xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁ + xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch) fₙ = f(xₙ) - Δxₙ = xₙ - xₙ₋₁ - Δfₙ = fₙ - fₙ₋₁ - J⁻¹ += ((Δxₙ - J⁻¹ * Δfₙ) ./ (Δxₙ' * J⁻¹ * Δfₙ)) * (Δxₙ' * J⁻¹) + Δxₙ = xₙ .- xₙ₋₁ + Δfₙ = fₙ .- fₙ₋₁ + J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch) + J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./ + (_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))), + _batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch) iszero(fₙ) && return SciMLBase.build_solution(prob, alg, xₙ, fₙ; @@ -50,3 +56,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) end + +function _batch_transpose(x, batch) + !batch && return x' + return reshape(x, 1, size(x)...) +end diff --git a/src/lbroyden.jl b/src/lbroyden.jl index f983bbd..f8ed9f5 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -8,10 +8,9 @@ Base.@kwdef struct LBroyden <: AbstractSimpleNonlinearSolveAlgorithm threshold::Int = 27 end -@views function SciMLBase.__solve(prob::NonlinearProblem, - alg::LBroyden, args...; abstol = nothing, - reltol = nothing, - maxiters = 1000, kwargs...) +@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...; + abstol = nothing, reltol = nothing, maxiters = 1000, + batch = false, kwargs...) threshold = min(maxiters, alg.threshold) x = float(prob.u0) diff --git a/src/utils.jl b/src/utils.jl index f3b7de9..6f12b82 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -34,14 +34,17 @@ value(x) = x value(x::Dual) = ForwardDiff.value(x) value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -function init_J(x) - J = ArrayInterfaceCore.zeromatrix(x) - if ismutable(x) - J[diagind(J)] .= one(eltype(x)) +function init_J(x; batch = false) + x_ = batch ? x[:, 1] : x + + J = ArrayInterfaceCore.zeromatrix(x_) + if ismutable(x_) + J[diagind(J)] .= one(eltype(x_)) else J += I end - return J + + return batch ? repeat(J, 1, 1, size(x, 2)) : J end function dogleg_method(H, g, Δ) @@ -68,3 +71,18 @@ function dogleg_method(H, g, Δ) tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd return δsd + tau * δN_δsd end + +_batched_mul(x, y, batch) = x * y +function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix, batch) where {T} + !batch && return x * y + return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) +end +function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}, batch) where {T} + !batch && return x * y + return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) +end +function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}, + batch) where {T1, T2} + !batch && return x * y + return batched_mul(x, y) +end \ No newline at end of file From 5e468fef173510ed3bc0e6a07b48d06cae3493aa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Feb 2023 11:07:39 -0500 Subject: [PATCH 2/4] Make it an extension --- Project.toml | 7 ++- .../SimpleBatchedNonlinearSolveExt.jl | 9 ++++ ext/SimpleBatchedNonlinearSolveExt/broyden.jl | 48 +++++++++++++++++++ .../lbroyden.jl | 0 ext/SimpleBatchedNonlinearSolveExt/utils.jl | 25 ++++++++++ src/SimpleNonlinearSolve.jl | 1 - src/broyden.jl | 38 +++++++-------- src/utils.jl | 28 ++--------- 8 files changed, 109 insertions(+), 47 deletions(-) create mode 100644 ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl create mode 100644 ext/SimpleBatchedNonlinearSolveExt/broyden.jl create mode 100644 ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl create mode 100644 ext/SimpleBatchedNonlinearSolveExt/utils.jl diff --git a/Project.toml b/Project.toml index 351bb62..4852c1f 100644 --- a/Project.toml +++ b/Project.toml @@ -9,12 +9,17 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +[weakdeps] +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" + +[extensions] +SimpleBatchedNonlinearSolveExt = "NNlib" + [compat] ArrayInterfaceCore = "0.1.1" DiffEqBase = "6.114" diff --git a/ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl new file mode 100644 index 0000000..5a7483e --- /dev/null +++ b/ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl @@ -0,0 +1,9 @@ +module SimpleBatchedNonlinearSolveExt + +using SimpleNonlinearSolve, SciMLBase, NNlib + +include("utils.jl") +include("broyden.jl") +include("lbroyden.jl") + +end \ No newline at end of file diff --git a/ext/SimpleBatchedNonlinearSolveExt/broyden.jl b/ext/SimpleBatchedNonlinearSolveExt/broyden.jl new file mode 100644 index 0000000..2476e95 --- /dev/null +++ b/ext/SimpleBatchedNonlinearSolveExt/broyden.jl @@ -0,0 +1,48 @@ +function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; + abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + f = Base.Fix2(prob.f, prob.p) + x = float(prob.u0) + + if ndims(x) != 2 + error("`batch` mode works only if `ndims(prob.u0) == 2`") + end + + fₙ = f(x) + T = eltype(x) + J⁻¹ = _init_J_batched(x) + + if SciMLBase.isinplace(prob) + error("Broyden currently only supports out-of-place nonlinear problems") + end + + atol = abstol !== nothing ? abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) + rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + + xₙ = x + xₙ₋₁ = x + fₙ₋₁ = fₙ + for _ in 1:maxiters + xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch) + fₙ = f(xₙ) + Δxₙ = xₙ .- xₙ₋₁ + Δfₙ = fₙ .- fₙ₋₁ + J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch) + J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./ + (_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))), + _batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch) + + iszero(fₙ) && + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; + retcode = ReturnCode.Success) + + if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol) + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; + retcode = ReturnCode.Success) + end + xₙ₋₁ = xₙ + fₙ₋₁ = fₙ + end + + return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) +end diff --git a/ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl b/ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl new file mode 100644 index 0000000..e69de29 diff --git a/ext/SimpleBatchedNonlinearSolveExt/utils.jl b/ext/SimpleBatchedNonlinearSolveExt/utils.jl new file mode 100644 index 0000000..4dfd4f3 --- /dev/null +++ b/ext/SimpleBatchedNonlinearSolveExt/utils.jl @@ -0,0 +1,25 @@ +_batch_transpose(x) = reshape(x, 1, size(x)...) + +_batched_mul(x, y) = x * y + +function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T} + return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) +end + +function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T} + return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) +end + +function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2} + return batched_mul(x, y) +end + +function _init_J_batched(x::AbstractMatrix{T}) where {T} + J = ArrayInterfaceCore.zeromatrix(x[:, 1]) + if ismutable(x) + J[diagind(J)] .= one(eltype(x)) + else + J += I + end + return repeat(J, 1, 1, size(x, 2)) +end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index b48e33e..5b170ce 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -7,7 +7,6 @@ using StaticArraysCore using LinearAlgebra import ArrayInterfaceCore using DiffEqBase -using NNlib # Batched Matrix Multiplication @reexport using SciMLBase diff --git a/src/broyden.jl b/src/broyden.jl index 7539a2b..7820be8 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -1,25 +1,26 @@ """ -```julia -Broyden() -``` + Broyden() A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. """ -struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end +struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm + Broyden(batched = false) = new{batched}() +end -function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, - reltol = nothing, maxiters = 1000, batch = false, kwargs...) +function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; + abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) - if batch && ndims(x) != 2 - error("`batch` mode works only if `ndims(prob.u0) == 2`") - end + # if batch && ndims(x) != 2 + # error("`batch` mode works only if `ndims(prob.u0) == 2`") + # end fₙ = f(x) T = eltype(x) - J⁻¹ = init_J(x; batch) + # J⁻¹ = init_J(x; batch) + J⁻¹ = init_J(x) if SciMLBase.isinplace(prob) error("Broyden currently only supports out-of-place nonlinear problems") @@ -33,14 +34,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol xₙ₋₁ = x fₙ₋₁ = fₙ for _ in 1:maxiters - xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch) + xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁ fₙ = f(xₙ) - Δxₙ = xₙ .- xₙ₋₁ - Δfₙ = fₙ .- fₙ₋₁ - J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch) - J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./ - (_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))), - _batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch) + Δxₙ = xₙ - xₙ₋₁ + Δfₙ = fₙ - fₙ₋₁ + J⁻¹Δfₙ = J⁻¹ * Δfₙ + J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹) iszero(fₙ) && return SciMLBase.build_solution(prob, alg, xₙ, fₙ; @@ -56,8 +55,3 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) end - -function _batch_transpose(x, batch) - !batch && return x' - return reshape(x, 1, size(x)...) -end diff --git a/src/utils.jl b/src/utils.jl index 6f12b82..f3b7de9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -34,17 +34,14 @@ value(x) = x value(x::Dual) = ForwardDiff.value(x) value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -function init_J(x; batch = false) - x_ = batch ? x[:, 1] : x - - J = ArrayInterfaceCore.zeromatrix(x_) - if ismutable(x_) - J[diagind(J)] .= one(eltype(x_)) +function init_J(x) + J = ArrayInterfaceCore.zeromatrix(x) + if ismutable(x) + J[diagind(J)] .= one(eltype(x)) else J += I end - - return batch ? repeat(J, 1, 1, size(x, 2)) : J + return J end function dogleg_method(H, g, Δ) @@ -71,18 +68,3 @@ function dogleg_method(H, g, Δ) tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd return δsd + tau * δN_δsd end - -_batched_mul(x, y, batch) = x * y -function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix, batch) where {T} - !batch && return x * y - return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) -end -function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}, batch) where {T} - !batch && return x * y - return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) -end -function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}, - batch) where {T1, T2} - !batch && return x * y - return batched_mul(x, y) -end \ No newline at end of file From 93c827915aa20a0e0d5eafbbda13727122980ff1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Feb 2023 11:28:44 -0500 Subject: [PATCH 3/4] Add requires for backward compat --- Project.toml | 4 ++- ...n.jl => SimpleBatchedNonlinearSolveExt.jl} | 34 +++++++++++++++++++ .../SimpleBatchedNonlinearSolveExt.jl | 9 ----- .../lbroyden.jl | 0 ext/SimpleBatchedNonlinearSolveExt/utils.jl | 25 -------------- src/SimpleNonlinearSolve.jl | 12 +++++++ src/broyden.jl | 9 ++--- 7 files changed, 51 insertions(+), 42 deletions(-) rename ext/{SimpleBatchedNonlinearSolveExt/broyden.jl => SimpleBatchedNonlinearSolveExt.jl} (66%) delete mode 100644 ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl delete mode 100644 ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl delete mode 100644 ext/SimpleBatchedNonlinearSolveExt/utils.jl diff --git a/Project.toml b/Project.toml index 4852c1f..28e3410 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -34,10 +35,11 @@ julia = "1.6" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays"] +test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"] diff --git a/ext/SimpleBatchedNonlinearSolveExt/broyden.jl b/ext/SimpleBatchedNonlinearSolveExt.jl similarity index 66% rename from ext/SimpleBatchedNonlinearSolveExt/broyden.jl rename to ext/SimpleBatchedNonlinearSolveExt.jl index 2476e95..88de488 100644 --- a/ext/SimpleBatchedNonlinearSolveExt/broyden.jl +++ b/ext/SimpleBatchedNonlinearSolveExt.jl @@ -1,3 +1,34 @@ +module SimpleBatchedNonlinearSolveExt + +using SimpleNonlinearSolve, SciMLBase +isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib) + +_batch_transpose(x) = reshape(x, 1, size(x)...) + +_batched_mul(x, y) = x * y + +function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T} + return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) +end + +function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T} + return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) +end + +function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2} + return batched_mul(x, y) +end + +function _init_J_batched(x::AbstractMatrix{T}) where {T} + J = ArrayInterfaceCore.zeromatrix(x[:, 1]) + if ismutable(x) + J[diagind(J)] .= one(eltype(x)) + else + J += I + end + return repeat(J, 1, 1, size(x, 2)) +end + function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) f = Base.Fix2(prob.f, prob.p) @@ -46,3 +77,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) end + + +end \ No newline at end of file diff --git a/ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl deleted file mode 100644 index 5a7483e..0000000 --- a/ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl +++ /dev/null @@ -1,9 +0,0 @@ -module SimpleBatchedNonlinearSolveExt - -using SimpleNonlinearSolve, SciMLBase, NNlib - -include("utils.jl") -include("broyden.jl") -include("lbroyden.jl") - -end \ No newline at end of file diff --git a/ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl b/ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl deleted file mode 100644 index e69de29..0000000 diff --git a/ext/SimpleBatchedNonlinearSolveExt/utils.jl b/ext/SimpleBatchedNonlinearSolveExt/utils.jl deleted file mode 100644 index 4dfd4f3..0000000 --- a/ext/SimpleBatchedNonlinearSolveExt/utils.jl +++ /dev/null @@ -1,25 +0,0 @@ -_batch_transpose(x) = reshape(x, 1, size(x)...) - -_batched_mul(x, y) = x * y - -function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T} - return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) -end - -function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T} - return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) -end - -function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2} - return batched_mul(x, y) -end - -function _init_J_batched(x::AbstractMatrix{T}) where {T} - J = ArrayInterfaceCore.zeromatrix(x[:, 1]) - if ismutable(x) - J[diagind(J)] .= one(eltype(x)) - else - J += I - end - return repeat(J, 1, 1, size(x, 2)) -end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 5b170ce..51785e3 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -10,6 +10,18 @@ using DiffEqBase @reexport using SciMLBase +if !isdefined(Base, :get_extension) + using Requires +end + +function __init__() + @static if !isdefined(Base, :get_extension) + @require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin + include("../ext/SimpleBatchedNonlinearSolveExt.jl") + end + end +end + abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end diff --git a/src/broyden.jl b/src/broyden.jl index 7820be8..15da8a7 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -1,11 +1,11 @@ """ - Broyden() + Broyden(; batched = false) A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. """ struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm - Broyden(batched = false) = new{batched}() + Broyden(; batched = false) = new{batched}() end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; @@ -13,13 +13,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) - # if batch && ndims(x) != 2 - # error("`batch` mode works only if `ndims(prob.u0) == 2`") - # end - fₙ = f(x) T = eltype(x) - # J⁻¹ = init_J(x; batch) J⁻¹ = init_J(x) if SciMLBase.isinplace(prob) From 9f2b6b2c6a85fda890f302b8baef7da2bc5cb338 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Feb 2023 14:51:45 -0500 Subject: [PATCH 4/4] Add tests for batched broyden --- .gitignore | 3 +++ ext/SimpleBatchedNonlinearSolveExt.jl | 17 ++++++++--------- src/SimpleNonlinearSolve.jl | 4 +--- src/broyden.jl | 5 +++++ test/basictests.jl | 12 ++++++++++++ 5 files changed, 29 insertions(+), 12 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e4cfbfe --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +Manifest.toml + +wip \ No newline at end of file diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt.jl index 88de488..d599704 100644 --- a/ext/SimpleBatchedNonlinearSolveExt.jl +++ b/ext/SimpleBatchedNonlinearSolveExt.jl @@ -1,6 +1,6 @@ module SimpleBatchedNonlinearSolveExt -using SimpleNonlinearSolve, SciMLBase +using ArrayInterfaceCore, LinearAlgebra, SimpleNonlinearSolve, SciMLBase isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib) _batch_transpose(x) = reshape(x, 1, size(x)...) @@ -53,15 +53,15 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; xₙ = x xₙ₋₁ = x fₙ₋₁ = fₙ - for _ in 1:maxiters - xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch) + for i in 1:maxiters + xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁) fₙ = f(xₙ) Δxₙ = xₙ .- xₙ₋₁ Δfₙ = fₙ .- fₙ₋₁ - J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch) - J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ, batch) ./ - (_batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch))), - _batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch) + J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ) + J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./ + (_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))), + _batched_mul(_batch_transpose(Δxₙ), J⁻¹)) iszero(fₙ) && return SciMLBase.build_solution(prob, alg, xₙ, fₙ; @@ -78,5 +78,4 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) end - -end \ No newline at end of file +end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 51785e3..b82a7d0 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -16,9 +16,7 @@ end function __init__() @static if !isdefined(Base, :get_extension) - @require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin - include("../ext/SimpleBatchedNonlinearSolveExt.jl") - end + @require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin include("../ext/SimpleBatchedNonlinearSolveExt.jl") end end end diff --git a/src/broyden.jl b/src/broyden.jl index 15da8a7..d0ae233 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -3,6 +3,11 @@ A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. + +!!! note + + To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or + `import NNlib` must be present in your code. """ struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm Broyden(; batched = false) = new{batched}() diff --git a/test/basictests.jl b/test/basictests.jl index 3f386d2..12525d8 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -370,3 +370,15 @@ for options in list_of_options sol = solve(probN, alg) @test all(abs.(f(u, p)) .< 1e-10) end + +# Batched Broyden +using NNlib + +f, u0 = (u, p) -> u .* u .- p, randn(1, 3) + +p = [2.0 1.0 5.0]; +probN = NonlinearProblem{false}(f, u0, p); + +sol = solve(probN, Broyden(batched = true)) + +@test abs.(sol.u) ≈ sqrt.(p)