Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Improve Code Standards #147

Merged
merged 4 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
style = "sciml"
format_markdown = true
annotate_untyped_fields_with_any = false
format_docstrings = true
format_docstrings = true
join_lines_based_on_source = false
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149"
DiffResults = "1.1"
ExplicitImports = "1.5.0"
FastClosures = "0.3.2"
FiniteDiff = "2.22"
ForwardDiff = "0.10.36"
Expand Down Expand Up @@ -73,6 +74,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff", "ReverseDiff", "Tracker"]
test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "Reexport", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "Tracker", "Zygote"]
17 changes: 10 additions & 7 deletions ext/SimpleNonlinearSolveChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
using ChainRulesCore: ChainRulesCore, NoTangent
using DiffEqBase: DiffEqBase
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve

# The expectation here is that no-one is using this directly inside a GPU kernel. We can
# eventually lift this requirement using a custom adjoint
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...;
kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
function ∇__internal_solve_up(Δ)
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(),
∂args...)
return (
f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
end
return out, ∇__internal_solve_up
end
Expand Down
11 changes: 6 additions & 5 deletions ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
module SimpleNonlinearSolvePolyesterForwardDiffExt

using SimpleNonlinearSolve, PolyesterForwardDiff
using PolyesterForwardDiff: PolyesterForwardDiff
using SimpleNonlinearSolve: SimpleNonlinearSolve

@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true

@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x,
chunksize) where {F}
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(
f!::F, y, J, x, chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize)
return J
end

@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
chunksize) where {F}
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(
f::F, J, x, chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
return J
end
Expand Down
99 changes: 53 additions & 46 deletions ext/SimpleNonlinearSolveReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,67 @@
module SimpleNonlinearSolveReverseDiffExt

using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve
import ReverseDiff: TrackedArray, TrackedReal
using ArrayInterface: ArrayInterface
using DiffEqBase: DiffEqBase
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve
import SimpleNonlinearSolve: __internal_solve_up

function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end

function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end

function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
p, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray,
u0_changed, p, p_changed, alg, args...; kwargs...)
return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
end

function __internal_solve_up(prob::NonlinearProblem, sensealg,
u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
function __internal_solve_up(
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal}, u0_changed,
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
return __internal_solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end

function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed,
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
function __internal_solve_up(prob::$(pType), sensealg, u0, u0_changed,
p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
true, alg, args...; kwargs...)
end

function __internal_solve_up(prob::NonlinearProblem, sensealg,
u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end
function __internal_solve_up(
prob::$(pType), sensealg, u0::AbstractArray{<:TrackedReal},
u0_changed, p, p_changed, alg, args...; kwargs...)
return __internal_solve_up(
prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p),
true, alg, args...; kwargs...)
end

ReverseDiff.@grad function __internal_solve_up(
prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...)
function ∇__internal_solve_up(_args...)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
ReverseDiff.@grad function __internal_solve_up(
prob::$(pType), sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
ReverseDiffOriginator(), alg, args...; kwargs...)
function ∇__internal_solve_up(_args...)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
end
return Array(out), ∇__internal_solve_up
end
end
return Array(out), ∇__internal_solve_up
end

end
2 changes: 1 addition & 1 deletion ext/SimpleNonlinearSolveStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SimpleNonlinearSolveStaticArraysExt

using SimpleNonlinearSolve
using SimpleNonlinearSolve: SimpleNonlinearSolve

@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true

Expand Down
79 changes: 43 additions & 36 deletions ext/SimpleNonlinearSolveTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
module SimpleNonlinearSolveTrackerExt

using DiffEqBase, SciMLBase, SimpleNonlinearSolve, Tracker

function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
sensealg, u0::TrackedArray, u0_changed, p, p_changed, alg, args...; kwargs...)
return Tracker.track(
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(
prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(prob::NonlinearProblem,
sensealg, u0, u0_changed, p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(
SimpleNonlinearSolve.__internal_solve_up, prob, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
end

Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(_prob::NonlinearProblem,
sensealg, u0_, u0_changed, p_, p_changed, alg, args...; kwargs...)
u0, p = Tracker.data(u0_), Tracker.data(p_)
prob = remake(_prob; u0, p)
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
SciMLBase.TrackerOriginator(), alg, args...; kwargs...)

function ∇__internal_solve_up(Δ)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
using DiffEqBase: DiffEqBase
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
using SimpleNonlinearSolve: SimpleNonlinearSolve
using Tracker: Tracker, TrackedArray

for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0::TrackedArray,
u0_changed, p, p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
end

function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
return Tracker.track(SimpleNonlinearSolve.__internal_solve_up, prob, sensealg,
u0, u0_changed, p, p_changed, alg, args...; kwargs...)
end

Tracker.@grad function SimpleNonlinearSolve.__internal_solve_up(
_prob::$(pType), sensealg, u0_, u0_changed,
p_, p_changed, alg, args...; kwargs...)
u0, p = Tracker.data(u0_), Tracker.data(p_)
prob = remake(_prob; u0, p)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)

function ∇__internal_solve_up(Δ)
∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
end

return out, ∇__internal_solve_up
end
end

return out, ∇__internal_solve_up
end

end
3 changes: 2 additions & 1 deletion ext/SimpleNonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SimpleNonlinearSolveZygoteExt

import SimpleNonlinearSolve, Zygote
using SimpleNonlinearSolve: SimpleNonlinearSolve
using Zygote: Zygote

SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true

Expand Down
Loading