Skip to content

Commit 8518ef4

Browse files
Expose API for Turing integration (#90)
* Move things around * Update Project.toml * Add test, fix comments --------- Co-authored-by: Hong Ge <[email protected]>
1 parent b157064 commit 8518ef4

File tree

5 files changed

+67
-40
lines changed

5 files changed

+67
-40
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,29 @@ end
1919
"""
2020
LibtaskModel{F}
2121
22-
State wrapper to hold `Libtask.CTask` model initiated from `f`
22+
State wrapper to hold `Libtask.CTask` model initiated from `f`.
2323
"""
24-
struct LibtaskModel{F1,F2}
25-
f::F1
26-
ctask::Libtask.TapedTask{F2}
27-
28-
LibtaskModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask)
29-
end
30-
31-
function LibtaskModel(f, args...)
32-
return LibtaskModel(
24+
function AdvancedPS.LibtaskModel(
25+
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
26+
) # Changed the API, need to take care of the RNG properly
27+
return AdvancedPS.LibtaskModel(
3328
f,
34-
Libtask.TapedTask(f, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}),
29+
Libtask.TapedTask(
30+
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}
31+
),
3532
)
3633
end
3734

38-
Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask))
35+
function Base.copy(model::AdvancedPS.LibtaskModel)
36+
return AdvancedPS.LibtaskModel(model.f, copy(model.ctask))
37+
end
3938

40-
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R}
39+
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
4140

4241
function AdvancedPS.Trace(
4342
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
4443
)
45-
return AdvancedPS.Trace(LibtaskModel(model, args...), rng)
44+
return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng)
4645
end
4746

4847
# step to the next observe statement and
@@ -56,7 +55,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
5655
end
5756

5857
# create a backward reference in task_local_storage
59-
function addreference!(task::Task, trace::LibtaskTrace)
58+
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
6059
if task.storage === nothing
6160
task.storage = IdDict()
6261
end
@@ -65,9 +64,7 @@ function addreference!(task::Task, trace::LibtaskTrace)
6564
return task
6665
end
6766

68-
current_trace() = current_task().storage[:__trace]
69-
70-
function update_rng!(trace::LibtaskTrace)
67+
function AdvancedPS.update_rng!(trace::LibtaskTrace)
7168
rng, = trace.model.ctask.args
7269
trace.rng = rng
7370
return trace
@@ -76,12 +73,12 @@ end
7673
# Task copying version of fork for Trace.
7774
function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
7875
newtrace = copy(trace)
79-
update_rng!(newtrace)
76+
AdvancedPS.update_rng!(newtrace)
8077
isref && AdvancedPS.delete_retained!(newtrace.model.f)
8178
isref && delete_seeds!(newtrace)
8279

8380
# add backward reference
84-
addreference!(newtrace.model.ctask.task, newtrace)
81+
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
8582
return newtrace
8683
end
8784

@@ -94,11 +91,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
9491
ctask = Libtask.TapedTask(
9592
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)}
9693
)
97-
new_tapedmodel = LibtaskModel(newf, ctask)
94+
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
9895

9996
# add backward reference
10097
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
101-
addreference!(ctask.task, newtrace)
98+
AdvancedPS.addreference!(ctask.task, newtrace)
10299
AdvancedPS.gen_refseed!(newtrace)
103100
return newtrace
104101
end
@@ -135,9 +132,8 @@ function AbstractMCMC.step(
135132
AdvancedPS.forkr(copy(state.trajectory))
136133
else
137134
trng = AdvancedPS.TracedRNG()
138-
gen_model = LibtaskModel(deepcopy(model), trng)
139-
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng)
140-
addreference!(gen_model.ctask.task, trace) # Do we need it here ?
135+
trace = AdvancedPS.Trace(deepcopy(model), trng)
136+
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
141137
trace
142138
end
143139
end
@@ -174,9 +170,8 @@ function AbstractMCMC.sample(
174170

175171
traces = map(1:(sampler.nparticles)) do i
176172
trng = AdvancedPS.TracedRNG()
177-
gen_model = LibtaskModel(deepcopy(model), trng)
178-
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng)
179-
addreference!(gen_model.ctask.task, trace) # Do we need it here ?
173+
trace = AdvancedPS.Trace(deepcopy(model), trng)
174+
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
180175
trace
181176
end
182177

@@ -202,7 +197,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle)
202197
trng = deepcopy(particle.rng)
203198
Random123.set_counter!(trng.rng, 0)
204199
trng.count = 1
205-
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(particle.model.f), trng), trng)
200+
trace = AdvancedPS.Trace(
201+
AdvancedPS.LibtaskModel(deepcopy(particle.model.f), trng), trng
202+
)
206203
score = AdvancedPS.advance!(trace, true)
207204
while !isnothing(score)
208205
score = AdvancedPS.advance!(trace, true)

src/model.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Trace{F,R}
2+
Trace{F,R}
33
"""
44
mutable struct Trace{F,R}
55
model::F
@@ -10,24 +10,37 @@ const Particle = Trace
1010
const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R}
1111
const GenericTrace{R} = Trace{<:AbstractGenericModel,R}
1212

13-
# reset log probability
1413
reset_logprob!(::AdvancedPS.Particle) = nothing
15-
1614
reset_model(f) = deepcopy(f)
1715
delete_retained!(f) = nothing
1816

19-
Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))
17+
"""
18+
copy(trace::Trace)
2019
21-
# This is required to make it visible from outside extensions
22-
function observe end
23-
function replay end
20+
Copy a trace. The `TracedRNG` is deep-copied. The inner model is shallow-copied.
21+
"""
22+
Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))
2423

2524
"""
26-
gen_refseed!(particle::Particle)
25+
gen_refseed!(particle::Particle)
2726
28-
Generate a new seed for the reference particle
27+
Generate a new seed for the reference particle.
2928
"""
3029
function gen_refseed!(particle::Particle)
3130
seed = split(state(particle.rng.rng), 1)
3231
return safe_set_refseed!(particle.rng, seed[1])
3332
end
33+
34+
# A few internal functions used in the Libtask extension. Since it is not possible to access objects defined
35+
# in an extension, we just define dummy in the main module and implement them in the extension.
36+
function observe end
37+
function replay end
38+
function addreference! end
39+
40+
current_trace() = current_task().storage[:__trace]
41+
42+
# We need this one to be visible outside of the extension for dispatching (Turing.jl).
43+
struct LibtaskModel{F,T}
44+
f::F
45+
ctask::T
46+
end

src/rng.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,5 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
117117
Increase the model step counter by `n`
118118
"""
119119
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n
120+
121+
function update_rng! end

test/container.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130

131131
# Test task copy version of trace
132132
trng = AdvancedPS.TracedRNG()
133-
tr = AdvancedPS.Trace(Model(Ref(0)), trng, trng)
133+
tr = AdvancedPS.Trace(Model(Ref(0)), trng)
134134

135135
consume(tr.model.ctask)
136136
consume(tr.model.ctask)
@@ -143,6 +143,21 @@
143143
@test consume(a.model.ctask) == 4
144144
end
145145

146+
@testset "current trace" begin
147+
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end
148+
149+
function (model::TaskIdModel)(rng::Random.AbstractRNG)
150+
# Just print the task it's running in
151+
id = objectid(AdvancedPS.current_trace())
152+
return Libtask.produce(id)
153+
end
154+
155+
trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
156+
AdvancedPS.addreference!(trace.model.ctask.task, trace)
157+
158+
@test AdvancedPS.advance!(trace, false) === objectid(trace)
159+
end
160+
146161
@testset "seed container" begin
147162
seed = 1
148163
n = 3

0 commit comments

Comments
 (0)