19
19
"""
20
20
LibtaskModel{F}
21
21
22
- State wrapper to hold `Libtask.CTask` model initiated from `f`
22
+ State wrapper to hold `Libtask.CTask` model initiated from `f`.
23
23
"""
24
- struct LibtaskModel{F1,F2}
25
- f:: F1
26
- ctask:: Libtask.TapedTask{F2}
27
-
28
- LibtaskModel (f:: F1 , ctask:: Libtask.TapedTask{F2} ) where {F1,F2} = new {F1,F2} (f, ctask)
29
- end
30
-
31
- function LibtaskModel (f, args... )
32
- return LibtaskModel (
24
+ function AdvancedPS. LibtaskModel (
25
+ f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
26
+ ) # Changed the API, need to take care of the RNG properly
27
+ return AdvancedPS. LibtaskModel (
33
28
f,
34
- Libtask. TapedTask (f, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}),
29
+ Libtask. TapedTask (
30
+ f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}
31
+ ),
35
32
)
36
33
end
37
34
38
- Base. copy (model:: LibtaskModel ) = LibtaskModel (model. f, copy (model. ctask))
35
+ function Base. copy (model:: AdvancedPS.LibtaskModel )
36
+ return AdvancedPS. LibtaskModel (model. f, copy (model. ctask))
37
+ end
39
38
40
- const LibtaskTrace{R} = AdvancedPS. Trace{<: LibtaskModel ,R}
39
+ const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS. LibtaskModel ,R}
41
40
42
41
function AdvancedPS. Trace (
43
42
model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
44
43
)
45
- return AdvancedPS. Trace (LibtaskModel (model, args... ), rng)
44
+ return AdvancedPS. Trace (AdvancedPS . LibtaskModel (model, rng , args... ), rng)
46
45
end
47
46
48
47
# step to the next observe statement and
@@ -56,7 +55,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
56
55
end
57
56
58
57
# create a backward reference in task_local_storage
59
- function addreference! (task:: Task , trace:: LibtaskTrace )
58
+ function AdvancedPS . addreference! (task:: Task , trace:: LibtaskTrace )
60
59
if task. storage === nothing
61
60
task. storage = IdDict ()
62
61
end
@@ -65,9 +64,7 @@ function addreference!(task::Task, trace::LibtaskTrace)
65
64
return task
66
65
end
67
66
68
- current_trace () = current_task (). storage[:__trace ]
69
-
70
- function update_rng! (trace:: LibtaskTrace )
67
+ function AdvancedPS. update_rng! (trace:: LibtaskTrace )
71
68
rng, = trace. model. ctask. args
72
69
trace. rng = rng
73
70
return trace
76
73
# Task copying version of fork for Trace.
77
74
function AdvancedPS. fork (trace:: LibtaskTrace , isref:: Bool = false )
78
75
newtrace = copy (trace)
79
- update_rng! (newtrace)
76
+ AdvancedPS . update_rng! (newtrace)
80
77
isref && AdvancedPS. delete_retained! (newtrace. model. f)
81
78
isref && delete_seeds! (newtrace)
82
79
83
80
# add backward reference
84
- addreference! (newtrace. model. ctask. task, newtrace)
81
+ AdvancedPS . addreference! (newtrace. model. ctask. task, newtrace)
85
82
return newtrace
86
83
end
87
84
@@ -94,11 +91,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
94
91
ctask = Libtask. TapedTask (
95
92
newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (trace. model. f)}
96
93
)
97
- new_tapedmodel = LibtaskModel (newf, ctask)
94
+ new_tapedmodel = AdvancedPS . LibtaskModel (newf, ctask)
98
95
99
96
# add backward reference
100
97
newtrace = AdvancedPS. Trace (new_tapedmodel, trace. rng)
101
- addreference! (ctask. task, newtrace)
98
+ AdvancedPS . addreference! (ctask. task, newtrace)
102
99
AdvancedPS. gen_refseed! (newtrace)
103
100
return newtrace
104
101
end
@@ -135,9 +132,8 @@ function AbstractMCMC.step(
135
132
AdvancedPS. forkr (copy (state. trajectory))
136
133
else
137
134
trng = AdvancedPS. TracedRNG ()
138
- gen_model = LibtaskModel (deepcopy (model), trng)
139
- trace = AdvancedPS. Trace (LibtaskModel (deepcopy (model), trng), trng)
140
- addreference! (gen_model. ctask. task, trace) # Do we need it here ?
135
+ trace = AdvancedPS. Trace (deepcopy (model), trng)
136
+ AdvancedPS. addreference! (trace. model. ctask. task, trace) # TODO : Do we need it here ?
141
137
trace
142
138
end
143
139
end
@@ -174,9 +170,8 @@ function AbstractMCMC.sample(
174
170
175
171
traces = map (1 : (sampler. nparticles)) do i
176
172
trng = AdvancedPS. TracedRNG ()
177
- gen_model = LibtaskModel (deepcopy (model), trng)
178
- trace = AdvancedPS. Trace (LibtaskModel (deepcopy (model), trng), trng)
179
- addreference! (gen_model. ctask. task, trace) # Do we need it here ?
173
+ trace = AdvancedPS. Trace (deepcopy (model), trng)
174
+ AdvancedPS. addreference! (trace. model. ctask. task, trace) # Do we need it here ?
180
175
trace
181
176
end
182
177
@@ -202,7 +197,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle)
202
197
trng = deepcopy (particle. rng)
203
198
Random123. set_counter! (trng. rng, 0 )
204
199
trng. count = 1
205
- trace = AdvancedPS. Trace (LibtaskModel (deepcopy (particle. model. f), trng), trng)
200
+ trace = AdvancedPS. Trace (
201
+ AdvancedPS. LibtaskModel (deepcopy (particle. model. f), trng), trng
202
+ )
206
203
score = AdvancedPS. advance! (trace, true )
207
204
while ! isnothing (score)
208
205
score = AdvancedPS. advance! (trace, true )
0 commit comments