diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e4cfbfe
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+Manifest.toml
+
+wip
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index 6288c3e..28e3410 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SimpleNonlinearSolve"
 uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
 authors = ["SciML"]
-version = "0.1.11"
+version = "0.1.12"
 
 [deps]
 ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
@@ -10,15 +10,23 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
 StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
 
+[weakdeps]
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+
+[extensions]
+SimpleBatchedNonlinearSolveExt = "NNlib"
+
 [compat]
 ArrayInterfaceCore = "0.1.1"
 DiffEqBase = "6.114"
 FiniteDiff = "2"
 ForwardDiff = "0.10.3"
+NNlib = "0.8"
 Reexport = "0.2, 1"
 SciMLBase = "1.73"
 SnoopPrecompile = "1"
@@ -27,10 +35,11 @@ julia = "1.6"
 
 [extras]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays"]
+test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]
diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt.jl
new file mode 100644
index 0000000..d599704
--- /dev/null
+++ b/ext/SimpleBatchedNonlinearSolveExt.jl
@@ -0,0 +1,95 @@
+module SimpleBatchedNonlinearSolveExt
+
+using ArrayInterfaceCore, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
+isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
+
+# Add a leading singleton dimension: an `n × b` matrix becomes a `1 × n × b` array, so
+# each column is turned into a row-vector slice.
+_batch_transpose(x) = reshape(x, 1, size(x)...)
+
+# Slice-wise product along the trailing dimension via `NNlib.batched_mul`. A matrix
+# argument is reshaped so that each of its columns forms one slice; any other
+# combination of arguments falls back to an ordinary `*`.
+_batched_mul(x, y) = x * y
+
+function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T}
+    return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2)
+end
+
+function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T}
+    return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y)
+end
+
+function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2}
+    return batched_mul(x, y)
+end
+
+# Initial inverse-Jacobian estimate: an identity matrix sized from the first column of
+# `x`, repeated along the third dimension once per column of `x`.
+function _init_J_batched(x::AbstractMatrix{T}) where {T}
+    J = ArrayInterfaceCore.zeromatrix(x[:, 1])
+    if ismutable(x)
+        J[diagind(J)] .= one(eltype(x))
+    else
+        J += I
+    end
+    return repeat(J, 1, 1, size(x, 2))
+end
+
+# Batched Broyden. Each column of `prob.u0` gets its own inverse-Jacobian estimate (one
+# slice of `J⁻¹`), while `prob.f` is called on the whole matrix. The stopping tests
+# (`iszero`, `isapprox`) are applied to the whole matrix rather than per column.
+function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
+                           abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
+    f = Base.Fix2(prob.f, prob.p)
+    x = float(prob.u0)
+
+    if ndims(x) != 2
+        error("`batch` mode works only if `ndims(prob.u0) == 2`")
+    end
+
+    # Reject in-place problems before the first call to `f`: `f` is always called
+    # out-of-place below, and that call could throw before this message is reached.
+    if SciMLBase.isinplace(prob)
+        error("Broyden currently only supports out-of-place nonlinear problems")
+    end
+
+    fₙ = f(x)
+    T = eltype(x)
+    J⁻¹ = _init_J_batched(x)
+
+    atol = abstol !== nothing ? abstol :
+           real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
+    rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
+
+    xₙ = x
+    xₙ₋₁ = x
+    fₙ₋₁ = fₙ
+    for i in 1:maxiters
+        xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁)
+        fₙ = f(xₙ)
+        Δxₙ = xₙ .- xₙ₋₁
+        Δfₙ = fₙ .- fₙ₋₁
+        J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)
+        # NOTE(review): the `T(1e-5)` shift presumably keeps the denominator non-zero for
+        # columns that have stopped moving (`Δxₙ == 0`) -- confirm the intended magnitude.
+        J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./
+                             (_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
+                            _batched_mul(_batch_transpose(Δxₙ), J⁻¹))
+
+        iszero(fₙ) &&
+            return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
+                                            retcode = ReturnCode.Success)
+
+        if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
+            return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
+                                            retcode = ReturnCode.Success)
+        end
+        xₙ₋₁ = xₙ
+        fₙ₋₁ = fₙ
+    end
+
+    return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
+end
+
+end
diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl
index 5b170ce..b82a7d0 100644
--- a/src/SimpleNonlinearSolve.jl
+++ b/src/SimpleNonlinearSolve.jl
@@ -10,6 +10,18 @@ using DiffEqBase
 
 @reexport using SciMLBase
 
+if !isdefined(Base, :get_extension)
+    using Requires
+end
+
+# On Julia versions without package extensions (`Base.get_extension` undefined), load
+# the NNlib-dependent batched solver through Requires.jl once NNlib is loaded.
+function __init__()
+    @static if !isdefined(Base, :get_extension)
+        @require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin include("../ext/SimpleBatchedNonlinearSolveExt.jl") end
+    end
+end
+
 abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
 abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
 abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
diff --git a/src/broyden.jl b/src/broyden.jl
index a05caa0..d0ae233 100644
--- a/src/broyden.jl
+++ b/src/broyden.jl
@@ -1,19 +1,28 @@
 """
-```julia
-Broyden()
-```
+    Broyden(; batched = false)
 
 A low-overhead implementation of Broyden. This method is non-allocating on scalar
 and static array problems.
+
+### Keyword Arguments
+
+  - `batched`: if `true`, `prob.u0` must be a matrix; its columns are solved together as
+    a batch, each with its own inverse-Jacobian estimate. Defaults to `false`.
+
+!!! note
+
+    To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
+    `import NNlib` must be present in your code.
 """
-struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end
+struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
+    Broyden(; batched = false) = new{batched}()
+end
 
-function SciMLBase.__solve(prob::NonlinearProblem,
-                           alg::Broyden, args...; abstol = nothing,
-                           reltol = nothing,
-                           maxiters = 1000, kwargs...)
+function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
+                           abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
     f = Base.Fix2(prob.f, prob.p)
     x = float(prob.u0)
+
     fₙ = f(x)
     T = eltype(x)
     J⁻¹ = init_J(x)
@@ -34,7 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem,
         fₙ = f(xₙ)
         Δxₙ = xₙ - xₙ₋₁
         Δfₙ = fₙ - fₙ₋₁
-        J⁻¹ += ((Δxₙ - J⁻¹ * Δfₙ) ./ (Δxₙ' * J⁻¹ * Δfₙ)) * (Δxₙ' * J⁻¹)
+        J⁻¹Δfₙ = J⁻¹ * Δfₙ
+        J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
 
         iszero(fₙ) &&
             return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
diff --git a/src/lbroyden.jl b/src/lbroyden.jl
index f983bbd..f8ed9f5 100644
--- a/src/lbroyden.jl
+++ b/src/lbroyden.jl
@@ -8,10 +8,9 @@ Base.@kwdef struct LBroyden <: AbstractSimpleNonlinearSolveAlgorithm
     threshold::Int = 27
 end
 
-@views function SciMLBase.__solve(prob::NonlinearProblem,
-                                  alg::LBroyden, args...; abstol = nothing,
-                                  reltol = nothing,
-                                  maxiters = 1000, kwargs...)
+@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...;
+                                  abstol = nothing, reltol = nothing, maxiters = 1000,
+                                  batch = false, kwargs...)
     threshold = min(maxiters, alg.threshold)
     x = float(prob.u0)
 
diff --git a/test/basictests.jl b/test/basictests.jl
index 3f386d2..12525d8 100644
--- a/test/basictests.jl
+++ b/test/basictests.jl
@@ -370,3 +370,15 @@ for options in list_of_options
     sol = solve(probN, alg)
     @test all(abs.(f(u, p)) .< 1e-10)
 end
+
+# Batched Broyden
+using NNlib
+
+f, u0 = (u, p) -> u .* u .- p, randn(1, 3)
+
+p = [2.0 1.0 5.0];
+probN = NonlinearProblem{false}(f, u0, p);
+
+sol = solve(probN, Broyden(batched = true))
+
+@test abs.(sol.u) ≈ sqrt.(p)