Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

add ImmutableNonlinearProblem #153

Merged
merged 5 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ext/SimpleNonlinearSolveChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore: ChainRulesCore, NoTangent
using DiffEqBase: DiffEqBase
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem

# The expectation here is that no-one is using this directly inside a GPU kernel. We can
# eventually lift this requirement using a custom adjoint
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions ext/SimpleNonlinearSolveReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ module SimpleNonlinearSolveReverseDiffExt
using ArrayInterface: ArrayInterface
using DiffEqBase: DiffEqBase
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
import SimpleNonlinearSolve: __internal_solve_up

for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions ext/SimpleNonlinearSolveTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module SimpleNonlinearSolveTrackerExt

using DiffEqBase: DiffEqBase
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
using SimpleNonlinearSolve: SimpleNonlinearSolve
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
using Tracker: Tracker, TrackedArray

for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0::TrackedArray,
Expand Down
21 changes: 17 additions & 4 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess
norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
AbstractNonlinearFunction, StandardNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
build_solution, isinplace, _unwrap_val, warn_paramtype
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size

Expand All @@ -35,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

@inline __is_extension_loaded(::Val) = false

include("immutable_nonlinear_problem.jl")
include("utils.jl")
include("linesearch.jl")

Expand Down Expand Up @@ -70,6 +71,18 @@ end
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
prob = convert(ImmutableNonlinearProblem, prob)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
Expand All @@ -79,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
return SciMLBase.__solve(prob, alg, args...; kwargs...)
Expand Down
37 changes: 24 additions & 13 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval function SciMLBase.solve(
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
prob = convert(ImmutableNonlinearProblem, prob)
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand All @@ -31,7 +42,7 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
end

function __nlsolve_ad(
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, alg, args...; kwargs...)
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
Expand Down
67 changes: 67 additions & 0 deletions src/immutable_nonlinear_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
problem_type::PT
kwargs::K
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
p = NullParameters(),
problem_type = StandardNonlinearProblem();
kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
new{typeof(u0), iip, typeof(p), typeof(f),
typeof(kwargs), typeof(problem_type)}(f,
u0,
p,
problem_type,
kwargs)
end

"""
Define a steady state problem using the given function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
end
end

"""
Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
"""
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
end


function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
prob.u0,
prob.p,
prob.problem_type;
prob.kwargs...)
end

function DiffEqBase.get_concrete_problem(prob::ImmutableNonlinearProblem, isadapt; kwargs...)
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
p = DiffEqBase.get_concrete_p(prob, kwargs)
DiffEqBase.remake(prob; u0 = u0, p = p)
end
2 changes: 1 addition & 1 deletion src/nlsolve/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
termination_condition = nothing, kwargs...) where {M}
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
"""
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(;
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; termination_condition = nothing, kwargs...)
if prob.u0 isa SArray
if termination_condition === nothing ||
Expand All @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
end

@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down Expand Up @@ -114,7 +114,7 @@ end
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
# finicky, so we'll implement it separately from the generic version
# Ignore termination_condition. Don't pass things into internal functions
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, maxiters = 1000, kwargs...)
x = prob.u0
fx = _get_fx(prob, x)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ scalar and static array problems.
nlsolve_update_rule = Val(false)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`")
return _get_fx(prob.f, x, prob.p)
end
@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline function _get_fx(f::NonlinearFunction, x, p)
if isinplace(f)
if f.resid_prototype !== nothing
Expand All @@ -145,7 +145,7 @@ end
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
# is meant for low overhead solvers, users can opt into the other termination modes but the
# default is to use the least overhead version.
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs)))
end
Expand All @@ -155,14 +155,14 @@ function init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2)))
end

function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
T = promote_type(eltype(du), eltype(u))
abstol = __get_tolerance(u, abstol, T)
reltol = __get_tolerance(u, reltol, T)
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
internalnorm = ifelse(
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
prob isa ImmutableNonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
else
tc
Expand Down
2 changes: 1 addition & 1 deletion test/core/adjoint_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)

@test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff
@test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff
end
1 change: 1 addition & 0 deletions test/gpu/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end
end

prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0])
prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob)

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(),
Expand Down
Loading