diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index 01cc25fc..baa6ceb1 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -5,7 +5,6 @@ using Distributions using Plots using AbstractMCMC using Random123 -using Libtask """ plot_update_rate(update_rate, N) @@ -91,25 +90,19 @@ plot(x; label="x", xlabel="t") plot(y; label="y", xlabel="t") # Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition: -mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel - X::Array +mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel + X::Vector{Float64} θ::Parameters - NonLinearTimeSeries(θ::Parameters) = new(zeros(Float64, θ.T), θ) + NonLinearTimeSeries(θ::Parameters) = new(Float64[], θ) end -function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG) - x₀ = rand(rng, f₀(model.θ)) - model.X[1] = x₀ - score = logpdf(g(model.θ, x₀, 1), y[1]) - Libtask.produce(score) - - for t in 2:(model.θ.T) - state = rand(rng, f(model.θ, model.X[t - 1], t - 1)) - model.X[t] = state - score = logpdf(g(model.θ, state, t), y[t]) - Libtask.produce(score) - end +# The dynamics of the model is defined through the `AbstractStateSpaceModel` interface: +AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ) +AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step) +function AdvancedPS.observation(model::NonLinearTimeSeries, state, step) + return logpdf(g(model.θ, state, step), y[step]) end +AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ # Here we use the particle gibbs kernel without adaptive resampling. model = NonLinearTimeSeries(θ₀) @@ -117,13 +110,15 @@ pg = AdvancedPS.PG(Nₚ, 1.0) chains = sample(rng, model, pg, Nₛ; progress=false); #md nothing #hide -particles = hcat([chain.trajectory.X for chain in chains]...) # Concat all sampled states +particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states mean_trajectory = mean(particles; dims=2); #md nothing #hide # We can now plot all the generated traces. # Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below. -scatter(particles; label=false, opacity=1.01, color=:black, xlabel="t", ylabel="state") +scatter( + particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state" +) plot!(x; color=:darkorange, label="Original Trajectory") plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) @@ -133,29 +128,15 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ) # Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS. -# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way: -mutable struct NonLinearSSM <: AdvancedPS.AbstractStateSpaceModel - X::Vector{Float64} - θ::Parameters - NonLinearSSM(θ::Parameters) = new(Float64[], θ) -end - -AdvancedPS.initialization(model::NonLinearSSM) = f₀(model.θ) -AdvancedPS.transition(model::NonLinearSSM, state, step) = f(model.θ, state, step) -function AdvancedPS.observation(model::NonLinearSSM, state, step) - return logpdf(g(model.θ, state, step), y[step]) -end -AdvancedPS.isdone(::NonLinearSSM, step) = step > Tₘ - -# We can now sample from the model using the PGAS sampler and collect the trajectories. pgas = AdvancedPS.PGAS(Nₚ) -model = NonLinearSSM(θ₀) chains = sample(rng, model, pgas, Nₛ; progress=false); particles = hcat([chain.trajectory.model.X for chain in chains]...); mean_trajectory = mean(particles; dims=2); # The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods. -scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state") +scatter( + particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state" +) plot!(x; color=:darkorange, label="Original Trajectory") plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index 7745a60d..2f60ef72 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -146,7 +146,7 @@ function AbstractMCMC.step( # Perform a particle sweep. reference = isref ? particles.vals[nparticles] : nothing - logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, reference) + logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler, reference) # Pick a particle to be retained. newtrajectory = rand(rng, particles) @@ -184,7 +184,7 @@ function AbstractMCMC.sample( particles = AdvancedPS.ParticleContainer(traces, AdvancedPS.TracedRNG(), rng) # Perform particle sweep. - logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler) + logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler) replayed = map(particle -> AdvancedPS.replay(particle).model.f, particles.vals) diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index ad53c024..8ce7c2b7 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -8,6 +8,8 @@ using Random123: Random123 abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end +abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end + """ Abstract type for an abstract model formulated in the state space form """ abstract type AbstractStateSpaceModel <: AbstractParticleModel end @@ -17,8 +19,8 @@ include("resampling.jl") include("rng.jl") include("model.jl") include("container.jl") -include("pgas.jl") include("smc.jl") +include("pgas.jl") if !isdefined(Base, :get_extension) using Requires diff --git a/src/container.jl b/src/container.jl index 6510fe67..70e3a424 100644 --- a/src/container.jl +++ b/src/container.jl @@ -61,7 +61,11 @@ end Update reference trajectory. Defaults to `nothing` """ -update_ref!(particle::Trace, pc::ParticleContainer) = nothing +function update_ref!( + particle::Trace, pc::ParticleContainer, sampler::AbstractParticleSampler +) + return nothing +end """ reset_logweights!(pc::ParticleContainer) @@ -167,6 +171,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere function resample_propagate!( ::Random.AbstractRNG, pc::ParticleContainer, + sampler::AbstractParticleSampler, randcat=DEFAULT_RESAMPLER, ref::Union{Particle,Nothing}=nothing; weights=getweights(pc), @@ -214,7 +219,7 @@ function resample_propagate!( if ref !== nothing # Insert the retained particle. This is based on the replaying trick for efficiency # reasons. If we implement PG using task copying, we need to store Nx * T particles! - update_ref!(ref, pc) + update_ref!(ref, pc, sampler) @inbounds children[n] = ref end @@ -228,6 +233,7 @@ end function resample_propagate!( rng::Random.AbstractRNG, pc::ParticleContainer, + sampler::AbstractParticleSampler, resampler::ResampleWithESSThreshold, ref::Union{Particle,Nothing}=nothing; weights=getweights(pc), @@ -236,7 +242,7 @@ function resample_propagate!( ess = inv(sum(abs2, weights)) if ess ≤ resampler.threshold * length(pc) - resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights) + resample_propagate!(rng, pc, sampler, resampler.resampler, ref; weights=weights) else update_keys!(pc, ref) end @@ -311,11 +317,12 @@ function sweep!( rng::Random.AbstractRNG, pc::ParticleContainer, resampler, + sampler::AbstractMCMC.AbstractSampler, ref::Union{Particle,Nothing}=nothing, ) # Initial step: # Resample and propagate particles. - resample_propagate!(rng, pc, resampler, ref) + resample_propagate!(rng, pc, sampler, resampler, ref) # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic # weights. @@ -336,7 +343,7 @@ function sweep!( # For observations ``y₂, …, yₜ``: while !isdone # Resample and propagate particles. - resample_propagate!(rng, pc, resampler, ref) + resample_propagate!(rng, pc, sampler, resampler, ref) # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic # weights. diff --git a/src/pgas.jl b/src/pgas.jl index 2c00aaa9..8dfb449e 100644 --- a/src/pgas.jl +++ b/src/pgas.jl @@ -133,7 +133,7 @@ function forkr(particle::SSMTrace) return newtrace end -function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}) +function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS) current_step(ref) <= 2 && return nothing # At the beginning of step + 1 since we start at 1 isdone(ref.model, current_step(ref)) && return nothing diff --git a/src/smc.jl b/src/smc.jl index aabd0860..c45b127a 100644 --- a/src/smc.jl +++ b/src/smc.jl @@ -1,4 +1,4 @@ -struct SMC{R} <: AbstractMCMC.AbstractSampler +struct SMC{R} <: AbstractParticleSampler nparticles::Int resampler::R end @@ -46,12 +46,12 @@ function AbstractMCMC.sample( particles = ParticleContainer(traces, TracedRNG(), rng) # Perform particle sweep. - logevidence = sweep!(rng, particles, sampler.resampler) + logevidence = sweep!(rng, particles, sampler.resampler, sampler) return SMCSample(collect(particles), getweights(particles), logevidence) end -struct PG{R} <: AbstractMCMC.AbstractSampler +struct PG{R} <: AbstractParticleSampler """Number of particles.""" nparticles::Int """Resampling algorithm.""" @@ -84,7 +84,7 @@ struct PGSample{T,L} logevidence::L end -struct PGAS{R} <: AbstractMCMC.AbstractSampler +struct PGAS{R} <: AbstractParticleSampler """Number of particles.""" nparticles::Int """Resampling algorithm.""" @@ -96,7 +96,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0)) function AbstractMCMC.step( rng::Random.AbstractRNG, model::AbstractStateSpaceModel, - sampler::PGAS, + sampler::Union{PGAS,PG}, state::Union{PGState,Nothing}=nothing; kwargs..., ) @@ -116,7 +116,7 @@ function AbstractMCMC.step( # Perform a particle sweep. reference = isref ? particles.vals[nparticles] : nothing - logevidence = sweep!(rng, particles, sampler.resampler, reference) + logevidence = sweep!(rng, particles, sampler.resampler, sampler, reference) # Pick a particle to be retained. newtrajectory = rand(particles.rng, particles) diff --git a/test/container.jl b/test/container.jl index 1a57a73c..bada9b9c 100644 --- a/test/container.jl +++ b/test/container.jl @@ -73,8 +73,9 @@ ref = AdvancedPS.forkr(selected) pc_ref.vals[end] = ref + sampler = AdvancedPS.PG(length(logps)) AdvancedPS.resample_propagate!( - Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, ref + Random.GLOBAL_RNG, pc_ref, sampler, AdvancedPS.resample_systematic, ref ) @test pc_ref.logWs == zeros(3) @test AdvancedPS.getweights(pc_ref) == fill(1 / 3, 3) @@ -84,7 +85,7 @@ @test pc_ref.vals[end] === particles_ref[end] # Resample and propagate particles. - AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc) + AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc, sampler) @test pc.logWs == zeros(3) @test AdvancedPS.getweights(pc) == fill(1 / 3, 3) @test all(AdvancedPS.getweight(pc, i) == 1 / 3 for i in 1:3) diff --git a/test/pgas.jl b/test/pgas.jl index 2ceb3d33..8c7b3160 100644 --- a/test/pgas.jl +++ b/test/pgas.jl @@ -46,6 +46,7 @@ AdvancedPS.Trace(BaseModel(Params(0.9, 0.31, 1)), AdvancedPS.TracedRNG()) for _ in 1:3 ] + sampler = AdvancedPS.PGAS(3) resampler = AdvancedPS.ResampleWithESSThreshold(1.0) part = particles[3] @@ -58,11 +59,11 @@ pc = AdvancedPS.ParticleContainer(particles, AdvancedPS.TracedRNG(), base_rng) AdvancedPS.reweight!(pc, ref) - AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref) + AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref) AdvancedPS.reweight!(pc, ref) pc.logWs = [-Inf, 0, -Inf] # Force ancestor update to second particle - AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref) + AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref) AdvancedPS.reweight!(pc, ref) @test all(pc.vals[2].model.X[1:2] .≈ ref.model.X[1:2])