Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/ApplySplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ vector<ApplySplitResult> apply_split(const Split &split, bool is_update, const s
} else if (is_const_one(split.factor)) {
// The split factor trivially divides the old extent,
// but we know nothing new about the outer dimension.
} else if (tail == TailStrategy::GuardWithIf) {
} else if (tail == TailStrategy::GuardWithIf ||
tail == TailStrategy::Predicate) {
// It's an exact split but we failed to prove that the
// factor divides the extent. Use predication to avoid
// running off the end of the original loop.
Expand All @@ -70,6 +71,10 @@ vector<ApplySplitResult> apply_split(const Split &split, bool is_update, const s
// Inject the if condition *after* doing the substitution
// for the guarded version.
Expr cond = likely(old_var <= old_max);
if (tail == TailStrategy::Predicate) {
// Add the hint for predication.
cond = predicate(cond);
}
result.emplace_back(cond);

} else if (tail == TailStrategy::ShiftInwards) {
Expand Down
10 changes: 3 additions & 7 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1232,9 +1232,7 @@ class Bounds : public IRVisitor {

interval.min = Interval::make_max(interval.min, lower.min);
interval.max = Interval::make_min(interval.max, upper.max);
} else if (op->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost)) {
internal_assert(op->args.size() == 1);
} else if (Call::as_tag(op)) {
op->args[0].accept(this);
} else if (op->is_intrinsic(Call::return_second)) {
internal_assert(op->args.size() == 2);
Expand Down Expand Up @@ -2477,10 +2475,8 @@ class BoxesTouched : public IRGraphVisitor {
for (const auto &pair : cases) {
Expr c = pair.first;
Stmt body = pair.second;
const Call *call = c.as<Call>();
if (call && (call->is_intrinsic(Call::likely) ||
call->is_intrinsic(Call::likely_if_innermost) ||
call->is_intrinsic(Call::strict_float))) {
const Call *call = Call::as_tag(c);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (const Call *call = Call::as_tag(c)) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call is used below :) I made this change at first and had to revert it for that.

if (call) {
c = call->args[0];
}

Expand Down
2 changes: 1 addition & 1 deletion src/Derivative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,7 @@ void ReverseAccumulationVisitor::visit(const Call *op) {
accumulate(op->args[0], adjoint * (make_one(op->type) - op->args[2]));
accumulate(op->args[1], adjoint * op->args[2]);
accumulate(op->args[2], adjoint * (op->args[1] - op->args[0]));
} else if (op->is_intrinsic(Call::likely)) {
} else if (Call::as_tag(op)) {
accumulate(op->args[0], adjoint);
} else if (op->is_intrinsic(Call::return_second)) {
accumulate(op->args[0], make_const(op->type, 0.0));
Expand Down
4 changes: 2 additions & 2 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,9 +1057,9 @@ void Stage::split(const string &old, const string &outer, const string &inner, c
}

if (exact) {
user_assert(tail == TailStrategy::GuardWithIf)
user_assert(tail == TailStrategy::GuardWithIf || tail == TailStrategy::Predicate)
<< "When splitting Var " << old_name
<< " the tail strategy must be GuardWithIf or Auto. "
<< " the tail strategy must be GuardWithIf, Predicate or Auto. "
<< "Anything else may change the meaning of the algorithm\n";
}

Expand Down
1 change: 1 addition & 0 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ const char *const intrinsic_op_names[] = {
"mulhi_shr",
"mux",
"popcount",
"predicate",
"prefetch",
"promise_clamped",
"random",
Expand Down
5 changes: 5 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ struct Call : public ExprNode<Call> {
mulhi_shr, // Compute high_half(arg[0] * arg[1]) >> arg[3]. Note that this is a shift in addition to taking the upper half of multiply result. arg[3] must be an unsigned integer immediate.
mux,
popcount,
predicate,
prefetch,
promise_clamped,
random,
Expand Down Expand Up @@ -661,6 +662,10 @@ struct Call : public ExprNode<Call> {
return nullptr;
}

/** Check whether an expression is a call to one of the "tag"
 * intrinsics: likely, likely_if_innermost, predicate, or
 * strict_float. The lookup is delegated to as_intrinsic, and its
 * result is returned unchanged.
 * NOTE(review): presumably that is the matching Call node on
 * success and nullptr otherwise -- confirm against as_intrinsic. */
static const Call *as_tag(const Expr &e) {
    const Call *tagged = as_intrinsic(e, {Call::likely,
                                          Call::likely_if_innermost,
                                          Call::predicate,
                                          Call::strict_float});
    return tagged;
}

bool is_extern() const {
return (call_type == Extern ||
call_type == ExternCPlusPlus ||
Expand Down
13 changes: 13 additions & 0 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,19 @@ Stmt remove_likelies(const Stmt &s) {
return RemoveLikelies().mutate(s);
}

// Strip every enclosing tag intrinsic (as recognized by
// Call::as_tag) from an expression and return what is left. An
// expression that is not a tag call is returned as-is.
Expr unwrap_tags(const Expr &e) {
    Expr inner = e;
    // Peel one tag per iteration until the outermost node is no
    // longer a tag call.
    while (const Call *tag = Call::as_tag(inner)) {
        // Copy the argument out before reassigning, so the node that
        // `tag` points into stays alive while we read from it.
        Expr arg = tag->args[0];
        inner = arg;
    }
    return inner;
}

// Wrap an expression in the Call::predicate pure intrinsic. The
// resulting call has the same type as the wrapped expression.
Expr predicate(Expr e) {
    // Read the type before e is moved into the argument list below;
    // the evaluation order of function arguments is unspecified, so
    // querying it inline could observe a moved-from Expr.
    const Type tagged_type = e.type();
    return Internal::Call::make(tagged_type,
                                Internal::Call::predicate,
                                {std::move(e)},
                                Internal::Call::PureIntrinsic);
}

Expr requirement_failed_error(Expr condition, const std::vector<Expr> &args) {
return Internal::Call::make(Int(32),
"halide_error_requirement_failed",
Expand Down
10 changes: 10 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,16 @@ Expr remove_likelies(const Expr &e);
* all calls to likely() and likely_if_innermost() removed. */
Stmt remove_likelies(const Stmt &s);

/** If the expression is wrapped in one or more tag helper calls,
* strip all of them and return the innermost tagged expression.
* If not, returns the expression unchanged. */
Expr unwrap_tags(const Expr &e);

/** Expressions tagged with this intrinsic are suggestions that
* vectorization of loops with guard ifs should be implemented with
* non-faulting predicated loads and stores, instead of scalarizing
* an if statement. */
Expr predicate(Expr e);

// Secondary args to print can be Exprs or const char *
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args) {
}
Expand Down
3 changes: 3 additions & 0 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ std::ostream &operator<<(std::ostream &out, const TailStrategy &t) {
case TailStrategy::GuardWithIf:
out << "GuardWithIf";
break;
case TailStrategy::Predicate:
out << "Predicate";
break;
case TailStrategy::ShiftInwards:
out << "ShiftInwards";
break;
Expand Down
4 changes: 1 addition & 3 deletions src/LICM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ class LiftLoopInvariants : public IRMutator {
}
}
if (const Call *call = e.as<Call>()) {
if (call->is_intrinsic(Call::strict_float) ||
call->is_intrinsic(Call::likely) ||
call->is_intrinsic(Call::likely_if_innermost) ||
if (Call::as_tag(call) ||
call->is_intrinsic(Call::reinterpret)) {
// Don't lift these intrinsics. They're free.
return should_lift(call->args[0]);
Expand Down
21 changes: 9 additions & 12 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,18 +408,6 @@ Module lower(const vector<Function> &output_funcs,
debug(1) << "Lowering after final simplification:\n"
<< s << "\n\n";

if (t.arch != Target::Hexagon && t.has_feature(Target::HVX)) {
debug(1) << "Splitting off Hexagon offload...\n";
s = inject_hexagon_rpc(s, t, result_module);
debug(2) << "Lowering after splitting off Hexagon offload:\n"
<< s << "\n";
} else {
debug(1) << "Skipping Hexagon offload...\n";
}

// TODO: Several tests depend on these custom passes running before
// inject_gpu_offload. We should either make this consistent with
// inject_hexagon_rpc above, or find a way to avoid this dependency.
if (!custom_passes.empty()) {
for (size_t i = 0; i < custom_passes.size(); i++) {
debug(1) << "Running custom lowering pass " << i << "...\n";
Expand All @@ -429,6 +417,15 @@ Module lower(const vector<Function> &output_funcs,
}
}

if (t.arch != Target::Hexagon && t.has_feature(Target::HVX)) {
debug(1) << "Splitting off Hexagon offload...\n";
s = inject_hexagon_rpc(s, t, result_module);
debug(2) << "Lowering after splitting off Hexagon offload:\n"
<< s << "\n";
} else {
debug(1) << "Skipping Hexagon offload...\n";
}

if (t.has_gpu_feature()) {
debug(1) << "Offloading GPU loops...\n";
s = inject_gpu_offload(s, t);
Expand Down
3 changes: 1 addition & 2 deletions src/Monotonic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,7 @@ class DerivativeBounds : public IRVisitor {

void visit(const Call *op) override {
// Some functions are known to be monotonic
if (op->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost) ||
if (Call::as_tag(op) ||
op->is_intrinsic(Call::return_second)) {
op->args.back().accept(this);
return;
Expand Down
5 changes: 2 additions & 3 deletions src/RegionCosts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,8 @@ class ExprCost : public IRVisitor {
call->is_intrinsic(Call::count_leading_zeros) ||
call->is_intrinsic(Call::count_trailing_zeros)) {
arith += 5;
} else if (call->is_intrinsic(Call::likely) ||
call->is_intrinsic(Call::likely_if_innermost)) {
// Likely does not result in actual operations.
} else if (Call::as_tag(call)) {
// Tags do not result in actual operations.
} else {
// For other intrinsics, use 1 for the arithmetic cost.
arith += 1;
Expand Down
10 changes: 10 additions & 0 deletions src/Schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ enum class TailStrategy {
* case to handle the if statement. */
GuardWithIf,

/** Guard the inner loop with an if statement that prevents
* evaluation beyond the original extent, with a hint that the
* if statement should be implemented with predicated operations.
* Always legal. The if statement is treated like a boundary
* condition, and factored out into a loop epilogue if possible.
* Pros: no redundant re-evaluation; does not constrain input or
* output sizes. Cons: increases code size due to separate
* tail-case handling. */
Predicate,

/** Prevent evaluation beyond the original extent by shifting
* the tail case inwards, re-evaluating some points near the
* end. Only legal for pure variables in pure definitions. If
Expand Down
24 changes: 7 additions & 17 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,9 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {
learn_upper_bound(v, i.max - 1);
}
}
} else if (const Call *c = fact.as<Call>()) {
if (c->is_intrinsic(Call::likely) || c->is_intrinsic(Call::likely_if_innermost)) {
learn_false(c->args[0]);
return;
}
} else if (const Call *c = Call::as_tag(fact)) {
learn_false(c->args[0]);
return;
} else if (const Or *o = fact.as<Or>()) {
// Both must be false
learn_false(o->a);
Expand Down Expand Up @@ -286,11 +284,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
learn_lower_bound(v, i.min);
}
}
} else if (const Call *c = fact.as<Call>()) {
if (c->is_intrinsic(Call::likely) || c->is_intrinsic(Call::likely_if_innermost)) {
learn_true(c->args[0]);
return;
}
} else if (const Call *c = Call::as_tag(fact)) {
learn_true(c->args[0]);
return;
} else if (const And *a = fact.as<And>()) {
// Both must be true
learn_true(a->a);
Expand Down Expand Up @@ -421,13 +417,7 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) {
}
s[p.second] = make_const(p.first, (int)(rng() & 0xffff) - 0x7fff);
}
Expr probe = simplify(substitute(s, e));
if (const Call *c = probe.as<Call>()) {
if (c->is_intrinsic(Call::likely) ||
c->is_intrinsic(Call::likely_if_innermost)) {
probe = c->args[0];
}
}
Expr probe = unwrap_tags(simplify(substitute(s, e)));
if (!is_const_one(probe)) {
// Found a counter-example, or something that fails to fold
return false;
Expand Down
13 changes: 3 additions & 10 deletions src/Simplify_Call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) {
{arg, lower, upper},
Call::Intrinsic);
}
} else if (op->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost)) {
} else if (Call::as_tag(op)) {
// The bounds of the result are the bounds of the arg
internal_assert(op->args.size() == 1);
Expr arg = mutate(op->args[0], bounds);
Expand All @@ -605,14 +604,8 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) {
internal_assert(op->args.size() == 3);
Expr cond_value = mutate(op->args[0], nullptr);

// Ignore likelies for our purposes here
Expr cond = cond_value;
if (const Call *c = cond.as<Call>()) {
if (c->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost)) {
cond = c->args[0];
}
}
// Ignore tags for our purposes here
Expr cond = unwrap_tags(cond_value);

if (is_const_one(cond)) {
return mutate(op->args[1], bounds);
Expand Down
6 changes: 3 additions & 3 deletions src/Simplify_Let.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
Expr op_b = var_a ? new_var : shuffle->vectors[1];
replacement = substitute(f.new_name, Shuffle::make_concat({op_a, op_b}), replacement);
f.new_value = var_a ? shuffle->vectors[1] : shuffle->vectors[0];
} else if (const Call *likely = Call::as_intrinsic(f.new_value, {Call::likely, Call::likely_if_innermost})) {
replacement = substitute(f.new_name, Call::make(likely->type, likely->name, {new_var}, Call::PureIntrinsic), replacement);
f.new_value = likely->args[0];
} else if (const Call *tag = Call::as_tag(f.new_value)) {
replacement = substitute(f.new_name, Call::make(tag->type, tag->name, {new_var}, Call::PureIntrinsic), replacement);
f.new_value = tag->args[0];
} else {
break;
}
Expand Down
8 changes: 2 additions & 6 deletions src/Simplify_Stmts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ using std::vector;
Stmt Simplify::visit(const IfThenElse *op) {
Expr condition = mutate(op->condition, nullptr);

// If (likely(true)) ...
const Call *likely = Call::as_intrinsic(condition, {Call::likely, Call::likely_if_innermost});
Expr unwrapped_condition = condition;
if (likely) {
unwrapped_condition = likely->args[0];
}
// Remove tags
Expr unwrapped_condition = unwrap_tags(condition);

// If (true) ...
if (is_const_one(unwrapped_condition)) {
Expand Down
4 changes: 1 addition & 3 deletions src/Solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,7 @@ class SolveExpression : public IRMutator {

Expr visit(const Call *op) override {
// Ignore intrinsics that shouldn't affect the results.
if (op->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost) ||
op->is_intrinsic(Call::promise_clamped)) {
if (Call::as_tag(op)) {
return mutate(op->args[0]);
} else {
return IRMutator::visit(op);
Expand Down
4 changes: 1 addition & 3 deletions src/TrimNoOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ class StripIdentities : public IRMutator {
using IRMutator::visit;

Expr visit(const Call *op) override {
if (op->is_intrinsic(Call::return_second) ||
op->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost)) {
if (Call::as_tag(op) || op->is_intrinsic(Call::return_second)) {
return mutate(op->args.back());
} else {
return IRMutator::visit(op);
Expand Down
Loading