Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 136 additions & 122 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,36 @@ mutable struct Instruction{F, T<:Taped} <: AbstractInstruction
tape::T
end

mutable struct BlockInstruction{T<:Taped} <: AbstractInstruction
id::Int
args::Vector
tape::T
end

mutable struct BranchInstruction{T<:Taped} <: AbstractInstruction
condition::Any
block::Int
args::Vector
tape::T
end

mutable struct ReturnInstruction{T<:Taped} <: AbstractInstruction
arg
tape::T
end

mutable struct TapedFunction{F} <: Taped
func::F # maybe a function or a callable obejct
arity::Int
ir::Union{Nothing, IRTools.IR}
tape::RawTape
counter::Int
block_map::Dict{Int, Int}
retval
owner
function TapedFunction(f::F; arity::Int=-1) where {F}
new{F}(f, arity, nothing, RawTape(), 1, nothing)
new{F}(f, arity, nothing, RawTape(), 1,
Dict{Int, Int}(), nothing, nothing)
end
end

Expand All @@ -42,39 +63,44 @@ Base.show(io::IO, box::Box) = print(io, "Box(", box.val, ")")
MacroTools.@forward TapedFunction.tape Base.iterate, Base.length
MacroTools.@forward TapedFunction.tape Base.push!, Base.getindex, Base.lastindex

result(t::RawTape) = isempty(t) ? nothing : val(t[end].output)
result(t::TapedFunction) = result(t.tape)
result(t::TapedFunction) = t.retval

function increase_counter!(t::TapedFunction)
t.counter > length(t) && return
# instr = t[t.counter]
t.counter += 1
return t
function init!(tf::TapedFunction, args)
if isempty(tf.tape)
ir = IRTools.@code_ir tf.func(args...)
tf.ir = ir
translate!(tf, ir)
end
return tf
end

function reset!(tf::TapedFunction, ir::IRTools.IR, tape::RawTape)
tf.ir = ir
tf.tape = tape
blk_map = Dict{Int, Int}()

for (i, ins) in enumerate(tf.tape)
isa(ins, BlockInstruction) || continue
blk_map[ins.id] = i
end

tf.block_map = blk_map
tf.counter = 1
return tf
end

function (tf::TapedFunction)(args...)
if isempty(tf.tape)
ir = IRTools.@code_ir tf.func(args...)
ir = intercept(ir; recorder=:track!)
tf.ir = ir
tf.tape = RawTape()
tf2 = IRTools.evalir(ir, tf, args...)
@assert tf === tf2
else
# run the raw tape
if length(args) > 0
input = map(box, args)
tf.tape[1].input = input
end
for instruction in tf.tape
instruction()
end
init!(tf, args)
# run the raw tape
if length(args) > 0
input = map(box, args)
tf.tape[1].input = input
end
tf.counter = 1
while true
ins = tf[tf.counter]
ins()
isa(ins, ReturnInstruction) && break
end
return result(tf)
end
Expand Down Expand Up @@ -111,19 +137,32 @@ function Base.show(io::IO, tp::RawTape)
end

## methods for Instruction
Base.show(io::IO, instruction::AbstractInstruction) = print(io, "A ", typeof(instruction))
Base.show(io::IO, instruction::AbstractInstruction) = println(io, "A ", typeof(instruction))

function Base.show(io::IO, instruction::Instruction)
func = instruction.func
tape = instruction.tape
println(io, "Instruction($(func)$(map(val, instruction.input)), tape=$(objectid(tape)))")
end

function Base.show(io::IO, instruction::BlockInstruction)
id = instruction.id
tape = instruction.tape
println(io, "BlockInstruction($(id)->$(map(val, instruction.args)), tape=$(objectid(tape)))")
end

function Base.show(io::IO, instruction::BranchInstruction)
tape = instruction.tape
println(io, "BranchInstruction($(val(instruction.condition)), tape=$(objectid(tape)))")
end

function (instr::Instruction{F})() where F
# catch run-time exceptions / errors.
try
output = instr.func(map(val, instr.input)...)
func = val(instr.func)
output = func(map(val, instr.input)...)
instr.output.val = output
instr.tape.counter += 1
catch e
println(e, catch_backtrace());
rethrow(e);
Expand All @@ -137,122 +176,97 @@ function (instr::Instruction{typeof(_new)})()
expr = Expr(:new, map(val, instr.input)...)
output = eval(expr)
instr.output.val = output
instr.tape.counter += 1
catch e
println(e, catch_backtrace());
rethrow(e);
end
end

## internal functions

function track!(tape::Taped, f, args...)
f = val(f) # f maybe a Boxed closure
output = try
box(f(map(val, args)...))
catch e
@warn e
Box{Any}(nothing)
end
ins = Instruction(f, args, output, tape)
push!(tape, ins)
return output
end

function track!(tape::Taped, ::typeof(_new), args...)
output = try
expr = Expr(:new, map(val, args)...)
box(eval(expr))
catch e
@warn e
Box{Any}(nothing)
end
ins = Instruction(_new, args, output, tape)
push!(tape, ins)
return output
function (instr::BlockInstruction)()
instr.tape.counter += 1
end

function unbox_condition(ir)
for blk in IRTools.blocks(ir)
vars = keys(blk)
brs = IRTools.branches(blk)
for (i, br) in enumerate(brs)
IRTools.isconditional(br) || continue
cond = br.condition
new_cond = IRTools.push!(
blk,
IRTools.xcall(@__MODULE__, :val, cond))
brs[i] = IRTools.Branch(br; condition=new_cond)
function (instr::BranchInstruction)()
if instr.condition === nothing || !val(instr.condition) # unless
target = instr.block
target_idx = instr.tape.block_map[target]
blk_ins = instr.tape[target_idx]
@assert isa(blk_ins, BlockInstruction)
@assert length(instr.args) == length(blk_ins.args)
for i in 1:length(instr.args)
blk_ins.args[i].val = val(instr.args[i])
end
instr.tape.counter = target_idx
else
instr.tape.counter += 1
end
end

box_args() = nothing
box_args(x) = x
box_args(args...) = args

function _replace_args(args, pairs::Dict)
map(args) do x
haskey(pairs, x) ? pairs[x] : x
end
function (instr::ReturnInstruction)()
instr.tape.retval = val(instr.arg)
end

function intercept(ir; recorder=:track!)
ir == nothing && return
# we use tf instead of the original function as the first argument
# get the TapedFunction
tape = pushfirst!(ir, IRTools.xcall(Base, :identity, IRTools.arguments(ir)[1]))

# box the args
first_blk = IRTools.blocks(ir)[1]
args = IRTools.arguments(first_blk)
arity = length(args) - 1
arg_pairs= Dict()

args_var = args[1]
if arity == 0
args_var = IRTools.insertafter!(ir, tape, IRTools.xcall(@__MODULE__, :box_args))
elseif arity == 1
args_var = IRTools.insertafter!(ir, tape, IRTools.xcall(@__MODULE__, :box_args, args[2]))
arg_pairs = Dict(args[2] => args_var)
else # arity > 1
args_var = IRTools.insertafter!(ir, tape, IRTools.xcall(@__MODULE__, :box_args, args[2:end]...))
args_new, last_pos = [], args_var

iter_state = []
## internal functions

for i in 1:arity
last_pos = IRTools.insertafter!(ir, last_pos, IRTools.xcall(Base, :indexed_iterate, args_var, i, iter_state...))
args_iter = last_pos
last_pos = IRTools.insertafter!(ir, last_pos, IRTools.xcall(Core, :getfield, args_iter, 1))
push!(args_new, last_pos)
if i != arity
last_pos = IRTools.insertafter!(ir, last_pos, IRTools.xcall(Core, :getfield, args_iter, 2))
iter_state = [last_pos]
end
arg_boxer(var, boxes) = var
arg_boxer(var::GlobalRef, boxes) = eval(var)
arg_boxer(var::QuoteNode, boxes) = eval(var)
arg_boxer(var::IRTools.Variable, boxes) = get!(boxes, var, Box{Any}(nothing))
function args_initializer(ins::BlockInstruction)
return (args...) -> begin
@assert length(args) + 1 == length(ins.args)
ins.args[1].val = ins.tape.func
for i in 1:length(args)
ins.args[i + 1].val = val(args[i]) # fill the boxes
end
arg_pairs = Dict(zip(args[2:end], args_new))
end
end

# here we assumed the ir only has a return statement at its last block,
# and we make sure the return value is from a function call (to `identity`)
last_blk = IRTools.blocks(ir)[end]
retv = IRTools.returnvalue(last_blk)
IRTools.return!(last_blk, IRTools.xcall(Base, :identity, retv))
function translate!(taped::Taped, ir::IRTools.IR)
tape = taped.tape
boxes = Dict{IRTools.Variable, Box{Any}}()
_box = (x) -> arg_boxer(x, boxes)
for (blk_id, blk) in enumerate(IRTools.blocks(ir))
# blocks
blk_args = IRTools.arguments(blk)
push!(tape, BlockInstruction(blk_id, map(_box, blk_args), taped))
# `+ 1` because we will have an extra ins at the beginning
taped.block_map[blk_id] = length(tape) + 1
# normal instructions
for (x, st) in blk
if Meta.isexpr(st.expr, :call)
args = map(_box, st.expr.args)
# args[1] is the function
f = args[1]
ins = Instruction(f, args[2:end] |> Tuple,
_box(x), taped)
push!(tape, ins)
elseif Meta.isexpr(st.expr, :new)
args = map(_box, st.expr.args)
ins = Instruction(_new, args |> Tuple, _box(x), taped)
push!(tape, ins)
else
@warn "Unknown IR code: " st
ins = Instruction(identity, (eval(st.expr),), _box(x), taped)
push!(tape, ins)
end
end

for (x, st) in ir
x == tape && continue
if Meta.isexpr(st.expr, :call)
new_args = (x == args_var) ? st.expr.args : _replace_args(st.expr.args, arg_pairs)
ir[x] = IRTools.xcall(@__MODULE__, recorder, tape, new_args...)
elseif Meta.isexpr(st.expr, :new)
args = st.expr.args
ir[x] = IRTools.xcall(@__MODULE__, recorder, tape, _new, args...)
else
@warn "Unknown IR code: " st
# branches (including `return`)
for br in IRTools.branches(blk)
if br.condition === nothing && br.block == 0
ins = ReturnInstruction(_box(br.args[1]), taped)
push!(tape, ins)
else
ins = BranchInstruction(
_box(br.condition), br.block, map(_box, br.args), taped)
push!(tape, ins)
end
end
end
# the real return value will be in the last instruction on the tape
IRTools.return!(ir, tape)
unbox_condition(ir)
return ir
init_ins = Instruction(args_initializer(tape[1]), (),
Box{Any}(nothing), taped)
insert!(tape, 1, init_ins)
end
Loading