diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl index 7660b90b..4f54d21f 100644 --- a/src/tapedfunction.jl +++ b/src/tapedfunction.jl @@ -40,12 +40,25 @@ const LOGGING = Ref(false) abstract type AbstractInstruction end const RawTape = Vector{AbstractInstruction} +separate_kwargs(args...; kwargs...) = (args, values(kwargs)) + function _infer(f, args_type) # `code_typed` returns a vector: [Pair{Core.CodeInfo, DataType}] ir0 = code_typed(f, Tuple{args_type...}, optimize=false)[1][1] return ir0 end +resolve_globalref(var) = var +resolve_globalref(var::Core.GlobalRef) = getproperty(var.mod, var.name) + +function mark_kwarg_func_as_nonprimitive(ir::Core.CodeInfo) + line = ir.code[end - 1] + Meta.isexpr(line, :call) || error("Expected a call expression") + f = resolve_globalref(line.args[1]) + @debug "Marking $f as non-primitive" + @eval is_primitive(::typeof($f)) = false +end + const Bindings = Vector{Any} mutable struct TapedFunction{F, TapeType} @@ -59,26 +72,31 @@ mutable struct TapedFunction{F, TapeType} retval_binding_slot::Int # 0 indicates the function has not returned deepcopy_types::Type # use a Union type for multiple types - function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T} + function TapedFunction{F, T}(_f::F, _args...; cache=false, deepcopy_types=Union{}, _kwargs...) where {F, T} + f, args = make_kwcall_maybe(_f, _args...; _kwargs...) args_type = _accurate_typeof.(args) - cache_key = (f, deepcopy_types, args_type...) + cache_key = (f, deepcopy_types, args_type...) if cache && haskey(TRCache, cache_key) # use cache - cached_tf = TRCache[cache_key]::TapedFunction{F, T} + cached_tf = TRCache[cache_key]::TapedFunction{typeof(f), T} tf = copy(cached_tf) tf.counter = 1 return tf end ir = _infer(f, args_type) + if iskwcall(f) + mark_kwarg_func_as_nonprimitive(ir) + end binding_values, slots, tape = translate!(RawTape(), ir) - tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types) + # TODO: Make this use `kwcall` instead. + tf = new{typeof(f), T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types) TRCache[cache_key] = tf # set cache return tf end - TapedFunction(f, args...; cache=false, deepcopy_types=Union{}) = - TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types) + TapedFunction(f, args...; cache=false, deepcopy_types=Union{}, kwargs...) = + TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types, kwargs...) function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1} new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape, @@ -91,6 +109,19 @@ end const TRCache = LRU{Tuple, TapedFunction}(maxsize=10) const CompiledTape = Vector{FunctionWrapper{Nothing, Tuple{TapedFunction}}} +# TODO: Make this work on pre-1.9 Julia. +iskwcall(f) = false +iskwcall(f::typeof(Core.kwcall)) = true +iskwcall(tf::TapedFunction) = tf.func === Core.kwcall +function make_kwcall_maybe(f, args...; kwargs...) + return if length(kwargs) > 0 + args, kwargs = separate_kwargs(args...; kwargs...) + Core.kwcall, (kwargs, f, args...) + else + f, args + end +end + function Base.convert(::Type{CompiledTape}, tape::RawTape) ctape = CompiledTape(undef, length(tape)) for idx in 1:length(tape) diff --git a/src/tapedtask.jl b/src/tapedtask.jl index c96f4720..0127948a 100644 --- a/src/tapedtask.jl +++ b/src/tapedtask.jl @@ -67,13 +67,14 @@ BASE_COPY_TYPES = Union{Array, Ref} # NOTE: evaluating model without a trace, see # https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329 -function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref by default. +function TapedTask(f, args...; deepcopy_types=nothing, kwargs...) # deepcoy Array and Ref by default. if isnothing(deepcopy_types) deepcopy = BASE_COPY_TYPES else deepcopy = Union{BASE_COPY_TYPES, deepcopy_types} end - tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy) + tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy, kwargs...) + args = last(make_kwcall_maybe(f, args...; kwargs...)) TapedTask(tf, args...) end @@ -169,7 +170,9 @@ Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() # copy the task -function Base.copy(t::TapedTask; args=()) +function Base.copy(t::TapedTask; args=(), kwargs=()) + args = last(make_kwcall_maybe(func(t), args...; kwargs...)) + length(args) > 0 && t.tf.counter >1 && error("can't copy started task with new arguments") tf = copy(t.tf)