Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/tutorials/control-flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ In addition to conditional evaluations, [`@trace`](@ref) also supports capturing
loops. This is possible in the form of both for and while loops.
This enables one to write algorithm that would not be possible otherwise such as
performing computations until convergence or running a computation for an certain
number of iterations which is only known during runtime.
number of iterations which is only known during runtime.

Here is an example of a function which computes the cumsum in non-optimized manner
using a for loop:
Expand Down
64 changes: 45 additions & 19 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Returns true if this function is executed in a Reactant compilation context, oth

# Code generation
"""
@trace <expr>
@trace [key = val,...] <expr>

Converts certain expressions like control flow into a Reactant friendly form. Importantly,
if no traced value is found inside the expression, then there is no overhead.
Expand All @@ -53,7 +53,8 @@ if no traced value is found inside the expression, then there is no overhead.
- `if` conditions (with `elseif` and other niceties) (`@trace if ...`)
- `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning
of the macro needs to be before the assignment and not before the `if`)
- `for` statements with a single induction variable iterating over a syntactic `StepRange` of integers.
- `for` statements with a single induction variable iterating over integers with known `step`
- `while` statements

## Special Considerations

Expand Down Expand Up @@ -129,20 +130,42 @@ function fn(x)
return y, nothing
end
```

### Configuration

The behavior of loops can be configured with the following configuration options:

- `track_numbers::Union{Bool,Datatype}` - whether Julia numbers should be automatically promoted to traced numbers upon entering the loop.
- `checkpointing::Bool` - whether or not to enable checkpointing when performing reverse mode differentiation (default: `false`).
- `mincut::Bool` - whether or not to enable the mincut algorithm when performing reverse mode differentiation (default: `false`).
"""
macro trace(args...)
track_numbers = true
expr = first(args)
if length(args) > 1 && Meta.isexpr(args[1], :(=))
tn_expr = args[1]
tn_expr.args[1] == :track_numbers ||
error("@trace supports setting track_numbers, but got $(tn_expr)")
checkpointing = false
mincut = false

track_numbers = tn_expr.args[2]
expr = only(args[2:end])
else
expr = only(args)
expr = first(args)
while length(args) > 1
if Meta.isexpr(args[1], :(=))
tn_expr = args[1]
key, val = tn_expr.args
key ∈ (:track_numbers, :checkpointing, :mincut) ||
error("@trace supports setting track_numbers, checkpointing or mincut, but got $(tn_expr)")

if key === :track_numbers
track_numbers = val
elseif key === :checkpointing
checkpointing = val
elseif key === :mincut
mincut = val
end
args = args[2:end]
else
break
end
end
expr = only(args)

track_numbers = track_numbers ? Number : Union{}
expr = macroexpand(__module__, expr)

Expand All @@ -159,14 +182,14 @@ macro trace(args...)
return esc(trace_call(__module__, call))
end
Meta.isexpr(expr, :if) && return esc(trace_if(expr; track_numbers))
Meta.isexpr(expr, :for) && return (esc(trace_for(expr; track_numbers)))
Meta.isexpr(expr, :while) && return (esc(trace_while(expr; track_numbers)))
Meta.isexpr(expr, :for) && return (esc(trace_for(expr; track_numbers, checkpointing, mincut)))
Meta.isexpr(expr, :while) && return (esc(trace_while(expr; track_numbers, checkpointing, mincut)))
return error(
"Only `if-elseif-else` blocks, `for` and `while` loops are currently supported by `@trace`",
)
end

function trace_while(expr; track_numbers, first_arg=nothing)
function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothing)
Meta.isexpr(expr, :while, 2) || error("expected while expr")
cond, body = expr.args

Expand Down Expand Up @@ -233,21 +256,23 @@ function trace_while(expr; track_numbers, first_arg=nothing)
$(args_sym);
track_numbers=($(track_numbers)),
verify_arg_names=($(verify_arg_names_sym)),
mincut=($(mincut)),
checkpointing=($(checkpointing)),
)
end
end

return quote
if $(within_compile)() &&
$(any)($(is_traced), $(Expr(:tuple, cond_val.(all_syms.args)...)))
$(any)($(is_traced), $(Expr(:tuple, cond_val.(all_syms.args)...)))
$(reactant_code_block)
else
$(expr)
end
end
end

function trace_for(expr; track_numbers)
function trace_for(expr; track_numbers, checkpointing, mincut)
Meta.isexpr(expr, :for, 2) || error("expected for expr")
assign, body = expr.args

Expand Down Expand Up @@ -325,6 +350,7 @@ function trace_for(expr; track_numbers)
);
track_numbers,
first_arg=counter,
checkpointing, mincut,
))
end
end
Expand Down Expand Up @@ -374,7 +400,7 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
@assert expr.args[2].head == :block "currently we only support blocks"
expr.args[2] = Expr(:block, expr.args[2].args...)
true_last_line = expr.args[2].args[end]
remaining_lines = expr.args[2].args[1:(end - 1)]
remaining_lines = expr.args[2].args[1:(end-1)]
else
true_last_line = expr.args[2]
remaining_lines = []
Expand Down Expand Up @@ -417,7 +443,7 @@ function trace_if(expr; store_last_line=nothing, depth=0, track_numbers)
if else_block isa Expr
@assert else_block.head == :block "currently we only support blocks"
false_last_line = else_block.args[end]
remaining_lines = else_block.args[1:(end - 1)]
remaining_lines = else_block.args[1:(end-1)]
else
false_last_line = else_block
remaining_lines = []
Expand Down Expand Up @@ -571,7 +597,7 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
if startswith(string(x.args[1]), string(prepend))
return Expr(
:kw,
Symbol(string(x.args[1])[(length(string(prepend)) + 1):end]),
Symbol(string(x.args[1])[(length(string(prepend))+1):end]),
x.args[2],
)
end
Expand Down
12 changes: 10 additions & 2 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ function ReactantCore.traced_call(f::Function, args...)
end

function ReactantCore.traced_while(
cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing
cond_fn::CFn,
body_fn::BFn,
args;
track_numbers=Number,
verify_arg_names=nothing,
checkpointing=false,
mincut=false,
) where {CFn,BFn}
return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names)
return Ops.while_loop(
cond_fn, body_fn, args...; track_numbers, verify_arg_names, checkpointing, mincut
)
end
18 changes: 16 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1849,7 +1849,13 @@ end
end

@noinline function while_loop(
cond_fn::CFn, body_fn::BFn, args...; track_numbers, verify_arg_names=nothing
cond_fn::CFn,
body_fn::BFn,
args...;
track_numbers,
verify_arg_names=nothing,
checkpointing=false,
mincut=false,
) where {CFn,BFn}
# TODO: detect and prevent mutation within the condition

Expand Down Expand Up @@ -1915,7 +1921,15 @@ end
cond=cond_reg,
body=body_reg,
)
MLIR.IR.attr!(while_op, "enzymexla.disable_min_cut", MLIR.IR.UnitAttribute())

if !mincut
MLIR.IR.attr!(while_op, "enzymexla.disable_min_cut", MLIR.IR.UnitAttribute())
end

if checkpointing
MLIR.IR.attr!(while_op, "enzymexla.enable_checkpointing", MLIR.IR.Attribute(true))
end

return map(enumerate(linear_args)) do (i, arg)
Reactant.TracedUtils.set_mlir_data!(arg, MLIR.IR.result(while_op, i))
end
Expand Down
24 changes: 14 additions & 10 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,17 @@ function while_convergence(x, y)
return diff
end

@testset "while: convergence" begin
x = [1.0, 10.0, 20.0]
y = [0.0, -2.0, -3.0]
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(while_convergence(x_ra, y_ra)) ≈ while_convergence(x, y)
end

function for_no_track_numbers(x, n)
@trace track_numbers = false for i in n:16
@trace mincut = false checkpointing = true track_numbers = false for i in n:16
x = x .+ 1
end
return x
Expand All @@ -694,16 +703,11 @@ end
for_no_track_numbers_ra = @compile optimize = "enzyme-batch" for_no_track_numbers(
x_ra, n_ra
)
for_no_track_numbers_ra(x_ra, n_ra) == for_no_track_numbers(x, n)
end

@testset "while: convergence" begin
x = [1.0, 10.0, 20.0]
y = [0.0, -2.0, -3.0]
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)
@test for_no_track_numbers_ra(x_ra, n_ra) == for_no_track_numbers(x, n)

@test @jit(while_convergence(x_ra, y_ra)) ≈ while_convergence(x, y)
ir = sprint(show, @code_hlo optimize = "enzyme-batch" for_no_track_numbers(x_ra, n_ra))
@test contains(ir, "enzymexla.disable_min_cut")
@test contains(ir, "enzymexla.enable_checkpointing")
end

_call1(a, b) = a
Expand Down
Loading