Skip to content

Add DAE support for GPU kernels with mass matrices and initialization #361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand Down Expand Up @@ -55,6 +56,7 @@ RecursiveArrayTools = "3"
SciMLBase = "2.92"
Setfield = "1"
SimpleDiffEq = "1"
SimpleNonlinearSolve = "2"
StaticArrays = "1"
TOML = "1"
ZygoteRules = "0.2"
Expand Down
4 changes: 4 additions & 0 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ using RecursiveArrayTools
import ZygoteRules
import Base.Threads
using LinearSolve
using SimpleNonlinearSolve
import SimpleNonlinearSolve: SimpleTrustRegion
#For gpu_tsit5
using Adapt, SimpleDiffEq, StaticArrays
using Parameters, MuladdMacro
Expand Down Expand Up @@ -51,6 +53,7 @@ include("ensemblegpukernel/integrators/stiff/interpolants.jl")
include("ensemblegpukernel/integrators/nonstiff/interpolants.jl")
include("ensemblegpukernel/nlsolve/type.jl")
include("ensemblegpukernel/nlsolve/utils.jl")
include("ensemblegpukernel/nlsolve/initialization.jl")
include("ensemblegpukernel/kernels.jl")

include("ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl")
Expand All @@ -71,6 +74,7 @@ include("ensemblegpukernel/tableaus/kvaerno_tableaus.jl")
include("utils.jl")
include("algorithms.jl")
include("solve.jl")
include("dae_adapt.jl")

export EnsembleCPUArray, EnsembleGPUArray, EnsembleGPUKernel, LinSolveGPUSplitFactorize

Expand Down
26 changes: 26 additions & 0 deletions src/dae_adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Override SciMLBase adapt functions to allow DAEs for GPU kernels
import SciMLBase: adapt_structure
import Adapt

# Allow DAE adaptation for GPU kernels
# Rebuild an `ODEFunction` as a `FullSpecialize` function that keeps its Jacobian,
# mass matrix and initialization data, so DAE information is carried along when the
# function is adapted for the GPU kernels.
# NOTE(review): `to` is not used in this method, so the forwarded fields are passed
# through as-is rather than being adapted themselves — confirm this is intended.
function adapt_structure(to, f::SciMLBase.ODEFunction{iip}) where {iip}
    forwarded = (
        jac = f.jac,
        mass_matrix = f.mass_matrix,
        initialization_data = f.initialization_data,
    )
    return SciMLBase.ODEFunction{iip, SciMLBase.FullSpecialize}(f.f; forwarded...)
end

# Adapt OverrideInitData for GPU compatibility
# Rebuild an `OverrideInitData` whose initialization problem has been adapted to `to`.
# The update/map callbacks and the `is_update_oop` flag are forwarded untouched; the
# fifth positional argument (the metadata slot, per the original author's comment) is
# replaced by `nothing` for GPU compatibility.
function adapt_structure(to, f::SciMLBase.OverrideInitData)
    adapted_initprob = adapt(to, f.initializeprob)
    return SciMLBase.OverrideInitData(
        adapted_initprob,
        f.update_initializeprob!,
        f.initializeprobmap,
        f.initializeprobpmap,
        nothing,
        f.is_update_oop
    )
end
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ end
end

# interp_points = 0 or equivalently nothing
@inline function DiffEqBase.determine_event_occurrence(
@inline function DiffEqBase.determine_event_occurance(
integrator::DiffEqBase.AbstractODEIntegrator{
AlgType,
IIP,
Expand Down
30 changes: 23 additions & 7 deletions src/ensemblegpukernel/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@

saveat = _saveat === nothing ? saveat : _saveat

integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops,
callback, save_everystep, saveat)
# Check if initialization is needed for DAEs
u0, p_init,
init_success = if SciMLBase.has_initialization_data(prob.f)
# Perform initialization using SimpleNonlinearSolve compatible algorithm
gpu_initialization_solve(prob, SimpleTrustRegion(), 1e-6, 1e-6)
else
prob.u0, prob.p, true
end

u0 = prob.u0
# Use initialized values
integ = init(alg, prob.f, false, u0, prob.tspan[1], dt, p_init, tstops,
callback, save_everystep, saveat)
tspan = prob.tspan

integ.cur_t = 0
Expand Down Expand Up @@ -68,16 +76,24 @@ end

saveat = _saveat === nothing ? saveat : _saveat

u0 = prob.u0
# Check if initialization is needed for DAEs
u0, p_init,
init_success = if SciMLBase.has_initialization_data(prob.f)
# Perform initialization using SimpleNonlinearSolve compatible algorithm
gpu_initialization_solve(prob, SimpleTrustRegion(), abstol, reltol)
else
prob.u0, prob.p, true
end

tspan = prob.tspan
f = prob.f
p = prob.p
p = p_init

t = tspan[1]
tf = prob.tspan[2]

integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt,
prob.p,
integ = init(alg, prob.f, false, u0, prob.tspan[1], prob.tspan[2], dt,
p,
abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback,
saveat)

Expand Down
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
```julia
vectorized_solve(probs, prob::Union{ODEProblem, SDEProblem}alg;
vectorized_solve(probs, prob::Union{ODEProblem, SDEProblem}, alg;
dt, saveat = nothing,
save_everystep = true,
debug = false, callback = CallbackSet(nothing), tstops = nothing)
Expand Down
60 changes: 60 additions & 0 deletions src/ensemblegpukernel/nlsolve/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Run the initialization problem attached to `prob.f` (if any) and return a tuple
# `(u, p, ok)`.
#
# - When `prob.f` has no initialization data, or that data has no `initializeprob`,
#   the original `prob.u0` and `prob.p` are returned with `ok == true`.
# - When the nonlinear solve reports a successful return code, `u` and `p` are taken
#   from `initializeprobmap` / `initializeprobpmap` (falling back to the original value
#   for whichever map is `nothing`) and `ok == true`.
# - When the solve is unsuccessful, or anything inside the solve path throws, the
#   original `prob.u0` and `prob.p` are returned with `ok == false`.
#
# `nlsolve_alg === nothing` selects `SimpleTrustRegion()`. `abstol` and `reltol` are
# forwarded to `solve` as keyword arguments.
# NOTE(review): the bare `try`/`catch` hides every exception and may not be usable
# inside device code on real GPU backends — confirm before relying on it there.
@inline function gpu_initialization_solve(prob, nlsolve_alg, abstol, reltol)
    rhs = prob.f
    u_orig = prob.u0
    p_orig = prob.p

    # Guard clauses: nothing to initialize.
    SciMLBase.has_initialization_data(rhs) || return u_orig, p_orig, true
    rhs.initialization_data === nothing && return u_orig, p_orig, true

    data = rhs.initialization_data
    data.initializeprob === nothing && return u_orig, p_orig, true

    try
        solver = nlsolve_alg === nothing ? SimpleTrustRegion() : nlsolve_alg
        nlprob = data.initializeprob

        # Let the initialization data refresh the problem from the current (u, p).
        if data.update_initializeprob! !== nothing
            if data.is_update_oop === Val(true)
                # Out-of-place update: the updater returns the new problem.
                nlprob = data.update_initializeprob!(nlprob, (u = u_orig, p = p_orig))
            else
                # In-place update: the return value is ignored.
                data.update_initializeprob!(nlprob, (u = u_orig, p = p_orig))
            end
        end

        nlsol = solve(nlprob, solver; abstol, reltol)

        # Unsuccessful solve: hand back the untouched inputs and flag the failure.
        SciMLBase.successful_retcode(nlsol) || return u_orig, p_orig, false

        u_new = data.initializeprobmap === nothing ? u_orig :
                data.initializeprobmap(nlsol)
        p_new = data.initializeprobpmap === nothing ? p_orig :
                data.initializeprobpmap((u = u_orig, p = p_orig), nlsol)

        return u_new, p_new, true
    catch
        # Best-effort fallback: any error is reported only through the `false` flag.
        return u_orig, p_orig, false
    end
end
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/nlsolve/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
else
finite_diff_jac(u -> f(u, p, t), f.jac_prototype, u)
end
W(u, p, t) = -LinearAlgebra.I + γ * dt * J(u, p, t)
W(u, p, t) = -f.mass_matrix + γ * dt * J(u, p, t)
J, W
end

Expand Down
6 changes: 4 additions & 2 deletions src/ensemblegpukernel/perform_step/gpu_rodas4_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
dtgamma = dt * γ

# Starting
W = J - I * inv(dtgamma)
mass_matrix = f.mass_matrix
W = J - mass_matrix * inv(dtgamma)
du = f(uprev, p, t)

# Step 1
Expand Down Expand Up @@ -182,7 +183,8 @@ end
dtgamma = dt * γ

# Starting
W = J - I * inv(dtgamma)
mass_matrix = f.mass_matrix
W = mass_matrix / dtgamma - J
du = f(uprev, p, t)

# Step 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
dtgamma = dt * γ

# Starting
W = J - I * inv(dtgamma)
mass_matrix = f.mass_matrix
W = J - mass_matrix * inv(dtgamma)
du = f(uprev, p, t)

# Step 1
Expand Down Expand Up @@ -229,7 +230,8 @@ end
dtgamma = dt * γ

# Starting
W = J - I * inv(dtgamma)
mass_matrix = f.mass_matrix
W = mass_matrix / dtgamma - J
du = f(uprev, p, t)

# Step 1
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Expand Down
62 changes: 62 additions & 0 deletions test/gpu_kernel_de/stiff_ode/gpu_ode_modelingtoolkit_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Test script: builds a ModelingToolkit cartesian-pendulum DAE, wraps it in an
# EnsembleProblem, and solves it through EnsembleGPUKernel with GPURosenbrock23 and
# GPURodas4 using a fixed time step.
using ModelingToolkit, DiffEqGPU, OrdinaryDiffEq, LinearAlgebra, Test
using ModelingToolkit: t_nounits as t, D_nounits as D
using KernelAbstractions: CPU

# ModelingToolkit problems are too complex for GPU array adaptation,
# so we use the CPU backend for DAE testing
const backend = CPU()

# Define the cartesian pendulum DAE system
@parameters g = 9.81 L = 1.0
@variables x(t) y(t) [state_priority = 10] λ(t)

# The cartesian pendulum DAE system:
# m*ddot(x) = (x/L)*λ (simplified with m=1)
# m*ddot(y) = (y/L)*λ - mg (simplified with m=1)
# x^2 + y^2 = L^2 (constraint equation)
eqs = [D(D(x)) ~ λ * x / L
D(D(y)) ~ λ * y / L - g
x^2 + y^2 ~ L^2]

@named pendulum = ODESystem(eqs, t, [x, y, λ], [g, L])

# Perform structural simplification with index reduction
pendulum_sys = structural_simplify(dae_index_lowering(pendulum))

# Initial conditions: pendulum starts horizontally, at (x, y) = (1, 0)
u0 = [x => 1.0, y => 0.0, λ => -g] # λ initial guess for tension

# Time span
tspan = (0.0f0, 1.0f0)

# Create the ODE problem (the empty Float32 vector is passed as the parameter argument,
# so g and L use the defaults declared above)
prob = ODEProblem(pendulum_sys, u0, tspan, Float32[])

# Verify DAE properties: the problem must carry initialization data and a mass matrix
@test SciMLBase.has_initialization_data(prob.f) == true
@test prob.f.mass_matrix !== nothing

# Create ensemble problem for GPU testing
monteprob = EnsembleProblem(prob, safetycopy = false)

# Test with GPURosenbrock23 (fixed step, two trajectories)
sol = solve(monteprob, GPURosenbrock23(), EnsembleGPUKernel(backend),
trajectories = 2,
dt = 0.01f0,
adaptive = false)

@test length(sol.u) == 2

# Check constraint satisfaction: x^2 + y^2 ≈ L^2 (with the default L = 1)
# NOTE(review): assumes the first two entries of the state vector are x and y —
# confirm the state ordering produced by structural_simplify.
final_state = sol.u[1][end]
x_final, y_final = final_state[1], final_state[2]
constraint_error = abs(x_final^2 + y_final^2 - 1.0f0)
@test constraint_error < 0.1f0 # Reasonable tolerance for fixed timestep

# Test with GPURodas4 (only the number of returned trajectories is checked here)
sol2 = solve(monteprob, GPURodas4(), EnsembleGPUKernel(backend),
trajectories = 2,
dt = 0.01f0,
adaptive = false)

@test length(sol2.u) == 2
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using SafeTestsets, Test
@time @safetestset "GPU Kernelized Stiff ODE Mass Matrix" begin
include("gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl")
end
@time @safetestset "GPU Kernelized ModelingToolkit DAE" begin
include("gpu_kernel_de/stiff_ode/gpu_ode_modelingtoolkit_dae.jl")
end
@time @testset "GPU Kernelized Non Stiff ODE Regression" begin
include("gpu_kernel_de/gpu_ode_regression.jl")
end
Expand Down
Loading
Loading