Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Add support for inplace BatchedSimpleNewtonRaphson #76

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.18"
version = "0.1.19"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
27 changes: 21 additions & 6 deletions src/batched/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ end
function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
@assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems."
iip &&
@assert alg_autodiff(alg) "Inplace BatchedSimpleNewtonRaphson currently only supports autodiff."
u, f, reconstruct = _construct_batched_problem_structure(prob)

tc = alg.termination_condition
Expand All @@ -35,12 +36,26 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs
rtol = _get_tolerance(reltol, tc.reltol, T)
termination_condition = tc(storage)

if iip
𝓙 = similar(xₙ, length(xₙ), length(xₙ))
fₙ = similar(xₙ)
jac_cfg = ForwardDiff.JacobianConfig(f, fₙ, xₙ)
end

for i in 1:maxiters
if alg_autodiff(alg)
fₙ, 𝓙 = value_derivative(f, xₙ)
if iip
value_derivative!(𝓙, fₙ, f, xₙ, jac_cfg)
else
fₙ = f(xₙ)
𝓙 = FiniteDiff.finite_difference_jacobian(f, xₙ, diff_type(alg), eltype(xₙ), fₙ)
if alg_autodiff(alg)
fₙ, 𝓙 = value_derivative(f, xₙ)
else
fₙ = f(xₙ)
𝓙 = FiniteDiff.finite_difference_jacobian(f,
xₙ,
diff_type(alg),
eltype(xₙ),
fₙ)
end
end

iszero(fₙ) && return DiffEqBase.build_solution(prob,
Expand All @@ -66,7 +81,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs

if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
xₙ = storage.u
fₙ = f(xₙ)
@maybeinplace iip fₙ=f(xₙ)
end

return DiffEqBase.build_solution(prob,
Expand Down
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ function value_derivative(f::F, x::R) where {F, R}
end
value_derivative(f::F, x::AbstractArray) where {F} = f(x), ForwardDiff.jacobian(f, x)

"""
value_derivative!(J, y, f!, x, cfg = JacobianConfig(f!, y, x))

Inplace version of [`SimpleNonlinearSolve.value_derivative`](@ref).
"""
function value_derivative!(J::AbstractMatrix,
y::AbstractArray,
f!::F,
x::AbstractArray,
cfg::ForwardDiff.JacobianConfig = ForwardDiff.JacobianConfig(f!, y, x)) where {F}
ForwardDiff.jacobian!(J, f!, y, x, cfg)
return y, J
end

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
Expand Down
15 changes: 7 additions & 8 deletions test/inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@ using SimpleNonlinearSolve,
StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
NNlib

# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane
# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane, BatchedSimpleNewtonRaphson
function f!(du::AbstractArray{<:Number, N},
u::AbstractArray{<:Number, N},
p::AbstractVector) where {N}
u_ = reshape(u, :, size(u, N))
du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :),
ntuple(_ -> 1, N - 1)...,
size(u, N))
du .= reshape(sum(abs2, u_; dims = 1) .- u_ .- reshape(p, 1, :), size(u))
return du
end

function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector)
du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :)
du .= sum(abs2, u; dims = 1) .- u .- reshape(p, 1, :)
return du
end

function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector)
du .= sum(abs2, u) .- p
du .= sum(abs2, u) .- u .- p
return du
end

@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true),
SimpleDFSane(batched = true))
@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(; batched = true),
SimpleDFSane(; batched = true),
SimpleNewtonRaphson(; batched = true))
@testset "T: $T" for T in (Float32, Float64)
p = rand(T, 5)
@testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5))
Expand Down