Skip to content

Commit 48355e1

Browse files
authored
Merge branch 'master' into perf
2 parents d1a2fea + 24b3d4b commit 48355e1

File tree

5 files changed

+24
-9
lines changed

5 files changed

+24
-9
lines changed

.github/workflows/Testing.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,3 @@ jobs:
4242
${{ runner.os }}-
4343
- uses: julia-actions/julia-buildpkg@latest
4444
- uses: julia-actions/julia-runtest@latest
45-
with:
46-
coverage: false

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.7.1"
6+
version = "0.7.3"
77

88
[deps]
9-
CodeInfoTools = "bc773b8a-8374-437a-b9f2-0e9785855863"
109
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1110
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
1211
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1312
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1413

1514
[compat]
16-
CodeInfoTools = "0.3.4"
15+
FunctionWrappers = "1.1"
1716
LRUCache = "1.3"
1817
julia = "1.7"
1918

src/Libtask.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module Libtask
22

3-
using CodeInfoTools
43
using FunctionWrappers: FunctionWrapper
54
using LRUCache
65

src/tapedfunction.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
abstract type AbstractInstruction end
44
const RawTape = Vector{AbstractInstruction}
55

6+
function _infer(f, args_type)
7+
# `code_typed` returns a vector: [Pair{Core.CodeInfo, DataType}]
8+
ir0 = code_typed(f, Tuple{args_type...}, optimize=false)[1][1]
9+
# ir1 = CodeInfoTools.code_inferred(f, args_type...)
10+
# ir1.ssavaluetypes = ir0.ssavaluetypes
11+
return ir0
12+
end
13+
614
mutable struct TapedFunction{F, TapeType}
715
func::F # maybe a function, a constructor, or a callable object
816
arity::Int
@@ -22,8 +30,7 @@ mutable struct TapedFunction{F, TapeType}
2230
tf.counter = 1
2331
return tf
2432
end
25-
26-
ir = CodeInfoTools.code_inferred(f, args_type...)
33+
ir = _infer(f, args_type)
2734
bindings, tape = translate!(RawTape(), ir)
2835

2936
tf = new{F, T}(f, length(args), ir, tape, 1, bindings, :none)

src/tapedtask.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,24 @@ Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()
162162
# copy the task
163163

164164
function Base.copy(t::TapedTask; args=())
165+
length(args) > 0 && t.tf.counter >1 &&
166+
error("can't copy started task with new arguments")
165167
tf = copy(t.tf)
166168
task_args = if length(args) > 0
169+
# this cond implies t.tf.counter == 0, i.e., the task is not started yet
167170
typeof(args) == typeof(t.args) || error("bad arguments")
168171
args
169172
else
170-
tape_copy.(t.args)
173+
if t.tf.counter > 1
174+
# the task is running, we find the
175+
# real args from the copied bindings
176+
map(1:length(t.args)) do i
177+
get(tf.bindings, Symbol("_", i + 1), t.args[i])
178+
end
179+
else
180+
# the task is not started yet, but no args is given
181+
tape_copy.(t.args)
182+
end
171183
end
172184
new_t = TapedTask(tf, task_args...)
173185
storage = t.task.storage::IdDict{Any,Any}

0 commit comments

Comments
 (0)