Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 30 additions & 36 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
mutable struct Instruction{F}
fun::F
input::Tuple
output
tape
end

abstract type AbstractInstruction end

mutable struct Tape
tape::Vector{Instruction}
tape::Vector{<:AbstractInstruction}
counter::Int
owner
end

Tape() = Tape(Vector{Instruction}(), nothing)
Tape(owner) = Tape(Vector{Instruction}(), owner)
mutable struct Instruction{F} <: AbstractInstruction
fun::F
input::Tuple
output
tape::Tape
end

Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing)
Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner)
MacroTools.@forward Tape.tape Base.iterate, Base.length
MacroTools.@forward Tape.tape Base.push!, Base.getindex, Base.lastindex
const NULL_TAPE = Tape()

function setowner!(tape::Tape, owner)
tape.owner = owner
end

mutable struct Box{T}
val::T
end

val(x) = x
val(x::Box) = x.val
box(x) = Box(x)
any_box(x) = Box{Any}(x)
box(x::Box) = x

gettape(x) = nothing
gettape(x::Instruction) = x.tape
Expand Down Expand Up @@ -63,11 +69,20 @@ function (instr::Instruction{F})() where F
instr.output.val = output
end

function increase_counter(t::Tape)
t.counter > length(t) && return
# instr = t[t.counter]
t.counter += 1
end

function run(tape::Tape, args...)
input = map(box, args)
tape[1].input = input
if length(args) > 0
input = map(box, args)
tape[1].input = input
end
for instruction in tape
instruction()
increase_counter(tape)
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intuitively, I'd have expected this to also return the result of the function. Or is that extracted in another place?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result is stored in Instruction.output.


Expand All @@ -77,21 +92,13 @@ function run_and_record!(tape::Tape, f, args...)
box(f(map(val, args)...))
catch e
@warn e
any_box(nothing)
Box{Any}(nothing)
end
ins = Instruction(f, args, output, tape)
push!(tape, ins)
return output
end

function dry_record!(tape::Tape, f, args...)
# We don't know the type of box.val now, so we use Box{Any}
output = any_box(nothing)
ins = Instruction(f, args, output, tape)
push!(tape, ins)
return output
end

function unbox_condition(ir)
for blk in IRTools.blocks(ir)
vars = keys(blk)
Expand Down Expand Up @@ -188,27 +195,14 @@ function (tf::TapedFunction)(args...)
tape = IRTools.evalir(ir, tf.func, args...)
tf.ir = ir
tf.tape = tape
tape.owner = tf
setowner!(tape, tf)
return result(tape)
end
# TODO: use cache
run(tf.tape, args...)
return result(tf.tape)
end

function dry_run(tf::TapedFunction)
isempty(tf.tape) || (return tf)
@assert tf.arity >= 0 "TapedFunction need a fixed arity to dry run."
args = fill(nothing, tf.arity)
ir = IRTools.@code_ir tf.func(args...)
ir = intercept(ir; recorder=:dry_record!)
tape = IRTools.evalir(ir, tf.func, args...)
tf.ir = ir
tf.tape = tape
tape.owner = tf
return tf
end

function Base.show(io::IO, tf::TapedFunction)
buf = IOBuffer()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the extra buffer? You could directly println to io.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I use an extra buffer, because that there are too many calls to println with io (which is stdout in most situations). And between these calls, the current task may be switched off, then Juia does some print in another task, and then switched on. In that case, the output will not be continuous and that makes it annoying to read.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good reason. Then put a comment there to remind future readers why this should not be refactored.

println(buf, "TapedFunction:")
Expand Down
66 changes: 37 additions & 29 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
struct TapedTaskException
exc
backtrace
end

struct TapedTask
task::Task
tf::TapedFunction
counter::Ref{Int}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, counter, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, counter, pch, cch, Any[])
t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, pch, cch, Any[])
end
end

function TapedTask(tf::TapedFunction, args...)
tf.owner != nothing && error("TapedFunction is owned to another task.")
# dry_run(tf)
isempty(tf.tape) && tf(args...)
counter = Ref{Int}(1)
produce_ch = Channel()
consume_ch = Channel{Int}()
task = @task try
step_in(tf, counter, args)
step_in(tf.tape, args)
catch e
put!(produce_ch, TapedTaskException(e))
# @error "TapedTask Error: " exception=(e, catch_backtrace())
bt = catch_backtrace()
put!(produce_ch, TapedTaskException(e, bt))
# @error "TapedTask Error: " exception=(e, bt)
rethrow()
finally
@static if VERSION >= v"1.4"
Expand All @@ -40,7 +39,7 @@ function TapedTask(tf::TapedFunction, args...)
close(produce_ch)
close(consume_ch)
end
t = TapedTask(task, tf, counter, produce_ch, consume_ch)
t = TapedTask(task, tf, produce_ch, consume_ch)
task.storage === nothing && (task.storage = IdDict())
task.storage[:tapedtask] = t
tf.owner = t
Expand All @@ -53,25 +52,28 @@ TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
func(t::TapedTask) = t.tf.func

function step_in(tf::TapedFunction, counter::Ref{Int}, args)
len = length(tf.tape)
if(counter[] <= 1 && length(args) > 0)

function step_in(t::Tape, args)
len = length(t)
if(t.counter <= 1 && length(args) > 0)
input = map(box, args)
tf.tape[1].input = input
t[1].input = input
end
while counter[] <= len
tf.tape[counter[]]()
while t.counter <= len
t[t.counter]()
# produce and wait after an instruction is done
ttask = tf.owner
ttask = t.owner.owner
if length(ttask.produced_val) > 0
val = pop!(ttask.produced_val)
put!(ttask.produce_ch, val)
take!(ttask.consume_ch) # wait for next consumer
end
counter[] += 1
increase_counter(t)
end
end

next_step(t::TapedTask) = increase_counter(t.tf.tape)

#=
# ** Approach (A) to implement `produce`:
# Make`produce` a standalone instturction. This approach does NOT
Expand Down Expand Up @@ -186,18 +188,21 @@ function copy_box(old_box::Box{T}, roster::Dict{UInt64, Any}) where T
end
copy_box(o, roster::Dict{UInt64, Any}) = o

function Base.copy(t::Tape)
function Base.copy(x::Instruction, on_tape::Tape, roster::Dict{UInt64, Any})
input = map(x.input) do ob
copy_box(ob, roster)
end
output = copy_box(x.output, roster)
Instruction(x.fun, input, output, on_tape)
end

function Base.copy(t::Tape, roster::Dict{UInt64, Any})
old_data = t.tape
new_data = Vector{Instruction}()
new_tape = Tape(new_data, t.owner)
new_data = Vector{AbstractInstruction}()
new_tape = Tape(new_data, t.counter, t.owner)

roster = Dict{UInt64, Any}()
for x in old_data
input = map(x.input) do ob
copy_box(ob, roster)
end
output = copy_box(x.output, roster)
new_ins = Instruction(x.fun, input, output, new_tape)
new_ins = copy(x, new_tape, roster)
push!(new_data, new_ins)
end

Expand All @@ -207,8 +212,9 @@ end
function Base.copy(tf::TapedFunction)
new_tf = TapedFunction(tf.func; arity=tf.arity)
new_tf.ir = tf.ir
new_tape = copy(tf.tape)
new_tape.owner = new_tf
roster = Dict{UInt64, Any}()
new_tape = copy(tf.tape, roster)
setowner!(new_tape, new_tf)
new_tf.tape = new_tape
return new_tf
end
Expand All @@ -217,6 +223,8 @@ function Base.copy(t::TapedTask)
# t.counter[] <= 1 && error("Can't copy a TapedTask which is not running.")
tf = copy(t.tf)
new_t = TapedTask(tf)
new_t.counter[] = t.counter[] + 1
new_t.task.storage = copy(t.task.storage)
new_t.task.storage[:tapedtask] = new_t
next_step(new_t)
return new_t
end