Skip to content

Commit 319c74a

Browse files
authored
Switch to new model evaluation API (#40)
* Enable new model evaluation API. * Apply suggestions from code review * Apply suggestions from code review * Update container.jl * Update Project.toml * Minor bugfix. * Update Project.toml
1 parent 488822c commit 319c74a

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.3"
4+
version = "0.3.1"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -13,6 +13,6 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1313
[compat]
1414
AbstractMCMC = "2, 3"
1515
Distributions = "0.23, 0.24, 0.25"
16-
Libtask = "0.6"
16+
Libtask = "0.6.2"
1717
StatsFuns = "0.9"
1818
julia = "1.3"

src/container.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ end
66
const Particle = Trace
77

88
function Trace(f)
9-
ctask = Libtask.CTask(f)
9+
if hasfield(typeof(f), :evaluator) # Test whether f is a Turing.TracedModel
10+
# println(f.evaluator)
11+
ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...)
12+
else # f is a Function, or AdavncedPS.Model
13+
ctask = Libtask.CTask(f)
14+
end
1015

1116
# add backward reference
1217
newtrace = Trace(f, ctask)
@@ -42,7 +47,12 @@ end
4247
# Create new task and copy randomness
4348
function forkr(trace::Trace)
4449
newf = reset_model(trace.f)
45-
ctask = Libtask.CTask(trace.ctask)
50+
# ctask = Libtask.CTask(trace.ctask)
51+
if hasfield(typeof(newf), :evaluator) # Test whether f is a Turing.TracedModel
52+
ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...)
53+
else # f is a Function, or AdavncedPS.Model
54+
ctask = Libtask.CTask(newf)
55+
end
4656

4757
# add backward reference
4858
newtrace = Trace(newf, ctask)

0 commit comments

Comments
 (0)