diff --git a/src/jumps.jl b/src/jumps.jl index d80f3d92..0c70a73d 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -499,6 +499,7 @@ function JumpSet(vj, cj, rj, maj::MassActionJump{S, T, U, V}) where {S <: Number end JumpSet(jump::ConstantRateJump) = JumpSet((), (jump,), nothing, nothing) +JumpSet(jumps::AbstractVector{ConstantRateJump}) = JumpSet((), jumps, nothing, nothing) JumpSet(jump::VariableRateJump) = JumpSet((jump,), (), nothing, nothing) JumpSet(jump::RegularJump) = JumpSet((), (), jump, nothing) JumpSet(jump::AbstractMassActionJump) = JumpSet((), (), nothing, jump) diff --git a/src/problem.jl b/src/problem.jl index 192c7b8a..3ad21375 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -198,12 +198,15 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS ## Spatial jumps handling if spatial_system !== nothing && hopping_constants !== nothing - (num_crjs(jumps) == num_vrjs(jumps) == 0) || - error("Spatial aggregators only support MassActionJumps currently.") + (num_vrjs(jumps) == 0) || + error("Spatial aggregators currently only support MassActionJumps and ConstantRateJumps.") if is_spatial(aggregator) kwargs = merge((; hopping_constants, spatial_system), kwargs) else + if num_crjs(jumps) != 0 + error("Use a spatial SSA, e.g. DirectCRDirect in order to use ConstantRateJumps.") + end prob, maj = flatten(maj, prob, spatial_system, hopping_constants; kwargs...) end end diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index bb44b144..baf081f2 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -39,6 +39,9 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat # a dependency graph is needed if dep_graph === nothing + if length(rx_rates.cr_jumps) != 0 + error("Provide a dependency graph to use DirectCRDirect with constant rate jumps.") + end dg = make_dependency_graph(num_specs, rx_rates.ma_jumps) else dg = dep_graph @@ -54,6 +57,9 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat end if jumptovars_map === nothing + if length(rx_rates.cr_jumps) != 0 + error("Provide a jump-to-species dependency graph to use DirectCRDirect with constant rate jumps.") + end jtov_map = jump_to_vars_map(rx_rates.ma_jumps) else jtov_map = jumptovars_map @@ -94,7 +100,7 @@ function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder next_jump_time = typemax(typeof(end_time)) - rx_rates = RxRates(num_sites(spatial_system), majumps) + rx_rates = RxRates(num_sites(spatial_system), majumps, constant_jumps) hop_rates = HopRates(hopping_constants, spatial_system) site_rates = zeros(typeof(end_time), num_sites(spatial_system)) @@ -199,4 +205,4 @@ end number of constant rate jumps """ -num_constant_rate_jumps(aggregator::DirectCRDirectJumpAggregation) = 0 +num_constant_rate_jumps(aggregator::DirectCRDirectJumpAggregation) = length(aggregator.rx_rates.cr_jumps) diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 3cfe7eed..2ae29805 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -32,6 +32,9 @@ function NSMJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop # a dependency graph is needed if dep_graph === nothing + if length(rx_rates.cr_jumps) != 0 + error("Provide a dependency graph to use NSM with constant rate jumps.") + end dg = make_dependency_graph(num_specs, rx_rates.ma_jumps) else dg = dep_graph @@ -47,6 +50,9 @@ function NSMJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop end if jumptovars_map === nothing + if length(rx_rates.cr_jumps) != 0 + error("Provide a jump-to-species dependency graph to use NSM with constant rate jumps.") + end jtov_map = jump_to_vars_map(rx_rates.ma_jumps) else jtov_map = jumptovars_map @@ -83,7 +89,7 @@ function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jum next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder next_jump_time = typemax(typeof(end_time)) - rx_rates = RxRates(num_sites(spatial_system), majumps) + rx_rates = RxRates(num_sites(spatial_system), majumps, constant_jumps) hop_rates = HopRates(hopping_constants, spatial_system) NSMJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, @@ -187,4 +193,4 @@ end number of constant rate jumps """ -num_constant_rate_jumps(aggregator::NSMJumpAggregation) = 0 +num_constant_rate_jumps(aggregator::NSMJumpAggregation) = length(aggregator.rx_rates.cr_jumps) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 6d688719..f86dffcc 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -1,9 +1,10 @@ """ -A file with structs and functions for sampling reactions and updating reaction rates in spatial SSAs +A file with structs and functions for sampling reactions and updating reaction rates in spatial SSAs. +Massaction jumps go first in the indexing, then constant rate jumps. """ ### spatial rx rates ### -struct RxRates{F, M} +struct RxRates{F, M, C} "rx_rates[i,j] is rate of reaction i at site j" rates::Matrix{F} @@ -12,20 +13,25 @@ struct RxRates{F, M} "AbstractMassActionJump" ma_jumps::M + + "indexable collection of ConstantRateJump" + cr_jumps::C end """ - RxRates(num_sites::Int, ma_jumps::M) where {M} + RxRates(num_sites::Int, ma_jumps::M, cr_jumps::C) where {M, C} initializes RxRates with zero rates """ -function RxRates(num_sites::Int, ma_jumps::M) where {M} - numrxjumps = get_num_majumps(ma_jumps) +function RxRates(num_sites::Int, ma_jumps::M, cr_jumps::C) where {M, C} + numrxjumps = get_num_majumps(ma_jumps) + length(cr_jumps) rates = zeros(Float64, numrxjumps, num_sites) - RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps) + RxRates{Float64, M, C}(rates, vec(sum(rates, dims = 1)), ma_jumps, cr_jumps) end +RxRates(num_sites::Int, ma_jumps::M) where {M<:AbstractMassActionJump} = RxRates(num_sites, ma_jumps, ConstantRateJump[]) +RxRates(num_sites::Int, cr_jumps::C) where {C} = RxRates(num_sites, SpatialMassActionJump(), cr_jumps) -num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) +num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) + length(rx_rates.cr_jumps) """ reset!(rx_rates::RxRates) @@ -48,7 +54,7 @@ function total_site_rx_rate(rx_rates::RxRates, site) end """ - update_rx_rates!(rx_rates, rxs, u, site) + update_rx_rates!(rx_rates, rxs, integrator, site) update rates of all reactions in rxs at site """ @@ -56,8 +62,13 @@ function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, site) ma_jumps = rx_rates.ma_jumps @inbounds for rx in rxs - rate = eval_massaction_rate(u, rx, ma_jumps, site) - set_rx_rate_at_site!(rx_rates, site, rx, rate) + if is_massaction(rx_rates, rx) + rate = eval_massaction_rate(u, rx, ma_jumps, site) + set_rx_rate_at_site!(rx_rates, site, rx, rate) + else + cr_jump = rx_rates.cr_jumps[rx - get_num_majumps(ma_jumps)] + set_rx_rate_at_site!(rx_rates, site, rx, cr_jump.rate(u, integrator.p, integrator.t, site)) + end end end @@ -77,6 +88,16 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end +function execute_rx_at_site!(integrator, rx_rates::RxRates, rx, site) + if is_massaction(rx_rates, rx) + @inbounds executerx!((@view integrator.u[:, site]), rx, + rx_rates.ma_jumps) + else + cr_jump = rx_rates.cr_jumps[rx - get_num_majumps(rx_rates.ma_jumps)] + cr_jump.affect!(integrator, site) + end +end + # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] @@ -90,5 +111,10 @@ function Base.show(io::IO, ::MIME"text/plain", rx_rates::RxRates) println(io, "RxRates with $num_rxs reactions and $num_sites sites") end +"Return true if jump is a massaction jump." +function is_massaction(rx_rates::RxRates, rx) + rx <= get_num_majumps(rx_rates.ma_jumps) +end + eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: SpatialMassActionJump} = evalrxrate(u, rx, ma_jumps, site) eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: MassActionJump} = evalrxrate((@view u[:, site]), rx, ma_jumps) diff --git a/src/spatial/spatial_massaction_jump.jl b/src/spatial/spatial_massaction_jump.jl index 34504f1b..2a5dd437 100644 --- a/src/spatial/spatial_massaction_jump.jl +++ b/src/spatial/spatial_massaction_jump.jl @@ -89,6 +89,12 @@ function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates scale_rates = scale_rates, useiszero = useiszero, nocopy = nocopy) end +function SpatialMassActionJump() + empty_majump = MassActionJump(Vector{Float64}(), + Vector{Vector{Pair{Int, Int}}}(), + Vector{Vector{Pair{Int, Int}}}()) + SpatialMassActionJump(empty_majump) +end ############################################## function get_num_majumps(smaj::SpatialMassActionJump{Nothing, Nothing, S, U, V}) where diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index ddb9db41..d38f797b 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -9,7 +9,7 @@ struct SpatialJump{J} "source location" src::J - "index of jump as a hop or reaction" + "index of jump as a hop or reaction, hops first, then massaction reactions, then constant rate reactions" jidx::Int "destination location, equal to src for within-site reactions" @@ -69,8 +69,7 @@ function update_state!(p, integrator) execute_hop!(integrator, jump.src, jump.dst, jump.jidx) else rx_index = reaction_id_from_jump(p, jump) - @inbounds executerx!((@view integrator.u[:, jump.src]), rx_index, - p.rx_rates.ma_jumps) + execute_rx_at_site!(integrator, p.rx_rates, rx_index, jump.src) end # save jump that was just exectued p.prev_jump = jump diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index c44358e8..a753f752 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -26,6 +26,23 @@ netstoch = [[1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1]] rates = [0.1 / mesh_size, 1.0] majumps = MassActionJump(rates, reactstoch, netstoch) +# equivalent constant rate jumps +rate1(u,p,t,site) = u[1,site]*u[2,site] / 2 +rate2(u,p,t,site) = u[3,site] +affect1!(integrator,site) = begin + integrator.u[1, site] -= 1 + integrator.u[2, site] -= 1 + integrator.u[3, site] += 1 +end +affect2!(integrator,site) = begin + integrator.u[1, site] += 1 + integrator.u[2, site] += 1 + integrator.u[3, site] -= 1 +end +crjumps = JumpSet([ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!)]) +dep_graph = [[1,2],[1,2]] +jumptovars_map = [[1,2,3],[1,2,3]] + # spatial system setup hopping_rate = diffusivity * (linear_size / domain_size)^2 @@ -56,6 +73,11 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng = rng)) +# setup constant rate jump problems +push!(jump_problems, JumpProblem(prob, NSM(), crjumps, hopping_constants = hopping_constants, + spatial_system = CartesianGrid(dims), save_positions = (false, false), dep_graph = dep_graph, jumptovars_map = jumptovars_map, rng = rng)) +push!(jump_problems, JumpProblem(prob, DirectCRDirect(), crjumps, hopping_constants = hopping_constants, + spatial_system = CartesianGrid(dims), save_positions = (false, false), dep_graph = dep_graph, jumptovars_map = jumptovars_map, rng = rng)) # setup flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, diff --git a/test/spatial/reaction_rates.jl b/test/spatial/reaction_rates.jl index 09c25e3c..a95d04e0 100644 --- a/test/spatial/reaction_rates.jl +++ b/test/spatial/reaction_rates.jl @@ -24,22 +24,31 @@ num_species = 3 reactstoch = [[1 => 1, 2 => 1], [3 => 1]] netstoch = [[1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1]] rates = [0.1, 1.0] -num_rxs = length(rates) ma_jumps = MassActionJump(rates, reactstoch, netstoch) spatial_ma_jumps = SpatialMassActionJump(rates, reactstoch, netstoch) +rate_fn = (u, p, t, site) -> 1.0 +affect_fn!(integrator) = nothing # a dummy reaction, does nothing +cr_jumps = [ConstantRateJump(rate_fn, affect_fn!)] +num_rxs = 3 u = ones(Int, num_species, num_nodes) integrator = DummyIntegrator(u,nothing,nothing) rng = StableRNG(12345) +# Test constructors +@test JP.RxRates(num_nodes, ma_jumps).ma_jumps == ma_jumps +@test JP.RxRates(num_nodes, spatial_ma_jumps).ma_jumps == spatial_ma_jumps +@test JP.RxRates(num_nodes, cr_jumps).cr_jumps == cr_jumps + # Tests for RxRates -rx_rates_list = [JP.RxRates(num_nodes, ma_jumps), JP.RxRates(num_nodes, spatial_ma_jumps)] +rx_rates_list = [JP.RxRates(num_nodes, ma_jumps, cr_jumps), JP.RxRates(num_nodes, spatial_ma_jumps, cr_jumps)] for rx_rates in rx_rates_list - @test JP.num_rxs(rx_rates) == length(rates) + @test JP.num_rxs(rx_rates) == num_rxs show(io, "text/plain", rx_rates) for site in 1:num_nodes JP.update_rx_rates!(rx_rates, 1:num_rxs, integrator, site) - @test JP.total_site_rx_rate(rx_rates, site) == 1.1 - rx_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:num_rxs] + @test JP.total_site_rx_rate(rx_rates, site) == 2.1 + majump_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:2] + rx_props = [majump_props..., 1.0] rx_probs = rx_props / sum(rx_props) d = Dict{Int, Int}() for i in 1:num_samples