Skip to content

Commit 689abb2

Browse files
rikhuijzerdevmotionyebai
authored
Add TapedTask type annotations to storage[:tapedtask] (#121)
* Improve type inference for ttask inside producer This reduces the number of allocations for ```julia @time callback !== nothing && callback() ``` inside `(tf::TapedFunction)(args...; callback=nothing)` from 1 allocation: 48 bytes to 0 allocations. * Add a few more type annotation * Remove comment * Set version to 0.6.10 * Update src/tapedtask.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 01c2727 commit 689abb2

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
66
version = "0.7.0"
77

8+
89
[deps]
910
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
1011
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"

src/tapedtask.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@ struct TapedTask{F}
2020
end
2121
end
2222

23+
function producer()
24+
ttask = current_task().storage[:tapedtask]::TapedTask
25+
if length(ttask.produced_val) > 0
26+
val = pop!(ttask.produced_val)
27+
put!(ttask.produce_ch, val)
28+
take!(ttask.consume_ch) # wait for next consumer
29+
end
30+
return nothing
31+
end
32+
2333
function wrap_task(tf, produce_ch, consume_ch, args...)
2434
try
25-
producer = () -> begin
26-
ttask = current_task().storage[:tapedtask]
27-
if length(ttask.produced_val) > 0
28-
val = pop!(ttask.produced_val)
29-
put!(ttask.produce_ch, val)
30-
take!(ttask.consume_ch) # wait for next consumer
31-
end
32-
end
3335
tf(args...; callback=producer)
3436
catch e
3537
bt = catch_backtrace()
@@ -104,13 +106,13 @@ end
104106
ct.storage === nothing && return false
105107
haskey(ct.storage, :tapedtask) || return false
106108
# check if we are recording a tape
107-
isempty(ct.storage[:tapedtask].tf.tape) && return false
108-
return true
109+
ttask = ct.storage[:tapedtask]::TapedTask
110+
return !isempty(ttask.tf.tape)
109111
end
110112

111113
function produce(val)
112114
is_in_tapedtask() || return nothing
113-
ttask = current_task().storage[:tapedtask]
115+
ttask = current_task().storage[:tapedtask]::TapedTask
114116
length(ttask.produced_val) > 1 &&
115117
error("There is a produced value which is not consumed.")
116118
push!(ttask.produced_val, val)
@@ -161,7 +163,8 @@ Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()
161163
function Base.copy(t::TapedTask)
162164
tf = copy(t.tf)
163165
new_t = TapedTask(tf)
164-
new_t.task.storage = copy(t.task.storage)
166+
storage = t.task.storage::IdDict{Any,Any}
167+
new_t.task.storage = copy(storage)
165168
new_t.task.storage[:tapedtask] = new_t
166169
return new_t
167170
end

0 commit comments

Comments
 (0)