diff --git a/Project.toml b/Project.toml index e4c556f..ce68fc7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.18" +version = "0.1.19" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl index 323c07e..a141819 100644 --- a/src/batched/raphson.jl +++ b/src/batched/raphson.jl @@ -20,7 +20,8 @@ end function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) iip = SciMLBase.isinplace(prob) - @assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems." + iip && + @assert alg_autodiff(alg) "Inplace BatchedSimpleNewtonRaphson currently only supports autodiff." u, f, reconstruct = _construct_batched_problem_structure(prob) tc = alg.termination_condition @@ -35,12 +36,26 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs rtol = _get_tolerance(reltol, tc.reltol, T) termination_condition = tc(storage) + if iip + 𝓙 = similar(xₙ, length(xₙ), length(xₙ)) + fₙ = similar(xₙ) + jac_cfg = ForwardDiff.JacobianConfig(f, fₙ, xₙ) + end + for i in 1:maxiters - if alg_autodiff(alg) - fₙ, 𝓙 = value_derivative(f, xₙ) + if iip + value_derivative!(𝓙, fₙ, f, xₙ, jac_cfg) else - fₙ = f(xₙ) - 𝓙 = FiniteDiff.finite_difference_jacobian(f, xₙ, diff_type(alg), eltype(xₙ), fₙ) + if alg_autodiff(alg) + fₙ, 𝓙 = value_derivative(f, xₙ) + else + fₙ = f(xₙ) + 𝓙 = FiniteDiff.finite_difference_jacobian(f, + xₙ, + diff_type(alg), + eltype(xₙ), + fₙ) + end end iszero(fₙ) && return DiffEqBase.build_solution(prob, @@ -66,7 +81,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES xₙ = storage.u - fₙ = f(xₙ) + @maybeinplace iip fₙ=f(xₙ) end return DiffEqBase.build_solution(prob, diff --git a/src/utils.jl b/src/utils.jl index 890aa24..12462a0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -30,6 +30,20 @@ function value_derivative(f::F, x::R) where {F, R} end value_derivative(f::F, x::AbstractArray) where {F} = f(x), ForwardDiff.jacobian(f, x) +""" + value_derivative!(J, y, f!, x, cfg = JacobianConfig(f!, y, x)) + +Inplace version of [`SimpleNonlinearSolve.value_derivative`](@ref). +""" +function value_derivative!(J::AbstractMatrix, + y::AbstractArray, + f!::F, + x::AbstractArray, + cfg::ForwardDiff.JacobianConfig = ForwardDiff.JacobianConfig(f!, y, x)) where {F} + ForwardDiff.jacobian!(J, f!, y, x, cfg) + return y, J +end + value(x) = x value(x::Dual) = ForwardDiff.value(x) value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) diff --git a/test/inplace.jl b/test/inplace.jl index 886e820..d488210 100644 --- a/test/inplace.jl +++ b/test/inplace.jl @@ -2,29 +2,28 @@ using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, NNlib -# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane +# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane, BatchedSimpleNewtonRaphson function f!(du::AbstractArray{<:Number, N}, u::AbstractArray{<:Number, N}, p::AbstractVector) where {N} u_ = reshape(u, :, size(u, N)) - du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :), - ntuple(_ -> 1, N - 1)..., - size(u, N)) + du .= reshape(sum(abs2, u_; dims = 1) .- u_ .- reshape(p, 1, :), size(u)) return du end function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector) - du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :) + du .= sum(abs2, u; dims = 1) .- u .- reshape(p, 1, :) return du end function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector) - du .= sum(abs2, u) .- p + du .= sum(abs2, u) .- u .- p return du end -@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true), - SimpleDFSane(batched = true)) +@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(; batched = true), + SimpleDFSane(; batched = true), + SimpleNewtonRaphson(; batched = true)) @testset "T: $T" for T in (Float32, Float64) p = rand(T, 5) @testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5))