diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl index 083cfccc..1284c78d 100644 --- a/src/tapedfunction.jl +++ b/src/tapedfunction.jl @@ -13,6 +13,11 @@ mutable struct Instruction{F} <: AbstractInstruction tape::Tape end +mutable struct TapeInstruction <: AbstractInstruction + subtape::Tape + tape::Tape +end + Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing) Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner) MacroTools.@forward Tape.tape Base.iterate, Base.length @@ -21,6 +26,9 @@ const NULL_TAPE = Tape() function setowner!(tape::Tape, owner) tape.owner = owner + for ins in tape + isa(ins, TapeInstruction) && setowner!(ins.subtape, owner) + end return tape end @@ -52,6 +60,12 @@ function Base.show(io::IO, instruction::Instruction) println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))") end +function Base.show(io::IO, ti::TapeInstruction) + subtape = ti.subtape + tape = ti.tape + println(io, "TapeInstruction($(subtape)), tape=$(objectid(tape)))") +end + function Base.show(io::IO, tp::Tape) buf = IOBuffer() print(buf, "$(length(tp))-element Tape") @@ -70,10 +84,19 @@ function (instr::Instruction{F})() where F instr.output.val = output end +function (instr::TapeInstruction)() + run(instr.subtape) +end + function increase_counter!(t::Tape) t.counter > length(t) && return - # instr = t[t.counter] - t.counter += 1 + instr = t[t.counter] + if isa(instr, TapeInstruction) + increase_counter!(instr.subtape) + else + # must be a produce instruction? + t.counter += 1 + end return t end @@ -84,21 +107,62 @@ function run(tape::Tape, args...) end for instruction in tape instruction() - increase_counter!(tape) + tape.counter += 1 end end +# if we should trace into a function +# TODO: +# overload (instr::Instruction{F})() to specify +# which function to trace into +function trace_into end +trace_into(x) = false + function run_and_record!(tape::Tape, f, args...) f = val(f) # f maybe a Boxed closure - output = try - box(f(map(val, args)...)) - catch e - @warn e - Box{Any}(nothing) + should_trace = trace_into(f) + if !should_trace + output = try + box(f(map(val, args)...)) + catch e + @warn e + Box{Any}(nothing) + end + ins = Instruction(f, args, output, tape) + push!(tape, ins) + return output + else + real_args = map(val, args) + ir = IRTools.@code_ir f(real_args...) + ir = intercept(ir; recorder=:run_and_record!) + # 1. we should distinguish fixed args and varargs here + arg_len = ir |> IRTools.arguments |> length + arg_len -= 2 # 1 for f, 1 for vargs + p_args = real_args[1:arg_len] + v_args = arg_len < 1 ? real_args : real_args[arg_len+1:end] + # detect if there is an vararg for f + if length(v_args) == 1 + m = which(f, typeof(real_args)) + no_vararg = @static if VERSION >= v"1.7" + (m.sig <: Tuple) && hasproperty(m.sig, :types) && + !isa(m.sig.types[end], Core.TypeofVararg) + else + (m.sig <: Tuple) && hasproperty(m.sig, :types) && + !(m.sig.types[end] <: Vararg{Any}) + end + if no_vararg + v_args = v_args[1] + end + end + subtape = IRTools.evalir(ir, f, p_args..., v_args) + # 2. we should recover the args after getting the tape + # to keep the chain complete + subtape[1].input = args + ins = TapeInstruction(subtape, tape) + output = subtape[end].output + push!(tape, ins) + return output end - ins = Instruction(f, args, output, tape) - push!(tape, ins) - return output end function unbox_condition(ir) diff --git a/src/tapedtask.jl b/src/tapedtask.jl index e6da976d..dfcf883e 100644 --- a/src/tapedtask.jl +++ b/src/tapedtask.jl @@ -52,7 +52,6 @@ TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...) TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...) func(t::TapedTask) = t.tf.func - function step_in(t::Tape, args) len = length(t) if(t.counter <= 1 && length(args) > 0) @@ -60,7 +59,12 @@ function step_in(t::Tape, args) t[1].input = input end while t.counter <= len - t[t.counter]() + ins = t[t.counter] + if isa(ins, TapeInstruction) + step_in(ins.subtape, ()) + else + ins() + end # produce and wait after an instruction is done ttask = t.owner.owner if length(ttask.produced_val) > 0 @@ -68,10 +72,11 @@ function step_in(t::Tape, args) put!(ttask.produce_ch, val) take!(ttask.consume_ch) # wait for next consumer end - increase_counter!(t) + t.counter += 1 end end + function next_step!(t::TapedTask) increase_counter!(t.tf.tape) return t @@ -99,7 +104,6 @@ function (instr::Instruction{typeof(produce)})() end =# - # ** Approach (B) to implement `produce`: # This way has its caveat: # `produce` may deeply hide in an instruction, but not be an instruction @@ -107,6 +111,8 @@ end # the instruction after the one which contains this `produce` call. If the # call to `produce` is not the last expression in the instuction, that # instruction will not be whole executed in the copied task. +# With the abilty to trace into nested function call, we can minimize the +# limitation of this caveat. @inline function is_in_tapedtask() ct = current_task() ct.storage === nothing && return false @@ -119,6 +125,8 @@ end function produce(val) is_in_tapedtask() || return nothing ttask = current_task().storage[:tapedtask] + # put!(ttask.produce_ch, val) + # take!(ttask.consume_ch) # wait for next consumer length(ttask.produced_val) > 1 && error("There is a produced value which is not consumed.") push!(ttask.produced_val, val) @@ -199,6 +207,11 @@ function Base.copy(x::Instruction, on_tape::Tape, roster::Dict{UInt64, Any}) Instruction(x.fun, input, output, on_tape) end +function Base.copy(x::TapeInstruction, on_tape::Tape, roster::Dict{UInt64, Any}) + subtape = copy(x.subtape, roster) + TapeInstruction(subtape, on_tape) +end + function Base.copy(t::Tape, roster::Dict{UInt64, Any}) old_data = t.tape new_data = Vector{AbstractInstruction}() @@ -223,7 +236,6 @@ function Base.copy(tf::TapedFunction) end function Base.copy(t::TapedTask) - # t.counter[] <= 1 && error("Can't copy a TapedTask which is not running.") tf = copy(t.tf) new_t = TapedTask(tf) new_t.task.storage = copy(t.task.storage) diff --git a/test/runtests.jl b/test/runtests.jl index 2749827a..39c8c08e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Libtask using Test include("ctask.jl") +include("special-instuctions.jl") include("tarray.jl") include("tref.jl") diff --git a/test/special-instuctions.jl b/test/special-instuctions.jl new file mode 100644 index 00000000..fec2cfff --- /dev/null +++ b/test/special-instuctions.jl @@ -0,0 +1,28 @@ +@testset "Special Instructions" begin + + @testset "TapeInstruction" begin + i1(x) = i2(x) + i2(x) = produce(x) + + Libtask.trace_into(::typeof(i1)) = true + Libtask.trace_into(::typeof(i2)) = true + + function f() + t = 0 + while t < 4 + i1(t) + t = 1 + t + end + end + + ctask = CTask(f) + @test consume(ctask) == 0 + @test consume(ctask) == 1 + a = copy(ctask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ctask) == 2 + @test consume(ctask) == 3 + end + +end