Skip to content

Commit 3e59294

Browse files
Add TailStrategy::Predicate (#5856)
* Add TailStrategy::Predicate * Add some tests for TailStrategy::Predicate. * Fix missing override. * Fix comment. * Tweak target behavior. * Remove all heuristics * clang-format. * clang-tidy. * TailStrategy::GuardWithIf isn't always faster than scalar code :( * Use TailStrategy::Predicate in the predicated store/load test. * What is this test * Fix test bug. * Revert x86 behavior. * Move predicate to Internal namespace. * Recursively strip tags. * trigger buildbots * strip_tags -> unwrap_tags * Fix comment. Co-authored-by: Steven Johnson <[email protected]>
1 parent 7bbe2fd commit 3e59294

27 files changed

+293
-226
lines changed

src/ApplySplit.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ vector<ApplySplitResult> apply_split(const Split &split, bool is_update, const s
4949
} else if (is_const_one(split.factor)) {
5050
// The split factor trivially divides the old extent,
5151
// but we know nothing new about the outer dimension.
52-
} else if (tail == TailStrategy::GuardWithIf) {
52+
} else if (tail == TailStrategy::GuardWithIf ||
53+
tail == TailStrategy::Predicate) {
5354
// It's an exact split but we failed to prove that the
5455
// factor divides the extent. Use predication to avoid
5556
// running off the end of the original loop.
@@ -70,6 +71,10 @@ vector<ApplySplitResult> apply_split(const Split &split, bool is_update, const s
7071
// Inject the if condition *after* doing the substitution
7172
// for the guarded version.
7273
Expr cond = likely(old_var <= old_max);
74+
if (tail == TailStrategy::Predicate) {
75+
// Add the hint for predication.
76+
cond = predicate(cond);
77+
}
7378
result.emplace_back(cond);
7479

7580
} else if (tail == TailStrategy::ShiftInwards) {

src/Bounds.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,9 +1232,7 @@ class Bounds : public IRVisitor {
12321232

12331233
interval.min = Interval::make_max(interval.min, lower.min);
12341234
interval.max = Interval::make_min(interval.max, upper.max);
1235-
} else if (op->is_intrinsic(Call::likely) ||
1236-
op->is_intrinsic(Call::likely_if_innermost)) {
1237-
internal_assert(op->args.size() == 1);
1235+
} else if (Call::as_tag(op)) {
12381236
op->args[0].accept(this);
12391237
} else if (op->is_intrinsic(Call::return_second)) {
12401238
internal_assert(op->args.size() == 2);
@@ -2477,10 +2475,8 @@ class BoxesTouched : public IRGraphVisitor {
24772475
for (const auto &pair : cases) {
24782476
Expr c = pair.first;
24792477
Stmt body = pair.second;
2480-
const Call *call = c.as<Call>();
2481-
if (call && (call->is_intrinsic(Call::likely) ||
2482-
call->is_intrinsic(Call::likely_if_innermost) ||
2483-
call->is_intrinsic(Call::strict_float))) {
2478+
const Call *call = Call::as_tag(c);
2479+
if (call) {
24842480
c = call->args[0];
24852481
}
24862482

src/Derivative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,7 @@ void ReverseAccumulationVisitor::visit(const Call *op) {
11661166
accumulate(op->args[0], adjoint * (make_one(op->type) - op->args[2]));
11671167
accumulate(op->args[1], adjoint * op->args[2]);
11681168
accumulate(op->args[2], adjoint * (op->args[1] - op->args[0]));
1169-
} else if (op->is_intrinsic(Call::likely)) {
1169+
} else if (Call::as_tag(op)) {
11701170
accumulate(op->args[0], adjoint);
11711171
} else if (op->is_intrinsic(Call::return_second)) {
11721172
accumulate(op->args[0], make_const(op->type, 0.0));

src/Func.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,9 +1057,9 @@ void Stage::split(const string &old, const string &outer, const string &inner, c
10571057
}
10581058

10591059
if (exact) {
1060-
user_assert(tail == TailStrategy::GuardWithIf)
1060+
user_assert(tail == TailStrategy::GuardWithIf || tail == TailStrategy::Predicate)
10611061
<< "When splitting Var " << old_name
1062-
<< " the tail strategy must be GuardWithIf or Auto. "
1062+
<< " the tail strategy must be GuardWithIf, Predicate or Auto. "
10631063
<< "Anything else may change the meaning of the algorithm\n";
10641064
}
10651065

src/IR.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ const char *const intrinsic_op_names[] = {
617617
"mulhi_shr",
618618
"mux",
619619
"popcount",
620+
"predicate",
620621
"prefetch",
621622
"promise_clamped",
622623
"random",

src/IR.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ struct Call : public ExprNode<Call> {
529529
mulhi_shr, // Compute high_half(arg[0] * arg[1]) >> arg[2]. Note that this is a shift in addition to taking the upper half of multiply result. arg[2] must be an unsigned integer immediate.
530530
mux,
531531
popcount,
532+
predicate,
532533
prefetch,
533534
promise_clamped,
534535
random,
@@ -661,6 +662,10 @@ struct Call : public ExprNode<Call> {
661662
return nullptr;
662663
}
663664

665+
static const Call *as_tag(const Expr &e) {
666+
return as_intrinsic(e, {Call::likely, Call::likely_if_innermost, Call::predicate, Call::strict_float});
667+
}
668+
664669
bool is_extern() const {
665670
return (call_type == Extern ||
666671
call_type == ExternCPlusPlus ||

src/IROperator.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,19 @@ Stmt remove_likelies(const Stmt &s) {
10351035
return RemoveLikelies().mutate(s);
10361036
}
10371037

1038+
Expr unwrap_tags(const Expr &e) {
1039+
if (const Call *tag = Call::as_tag(e)) {
1040+
return unwrap_tags(tag->args[0]);
1041+
}
1042+
return e;
1043+
}
1044+
1045+
Expr predicate(Expr e) {
1046+
Type t = e.type();
1047+
return Internal::Call::make(t, Internal::Call::predicate,
1048+
{std::move(e)}, Internal::Call::PureIntrinsic);
1049+
}
1050+
10381051
Expr requirement_failed_error(Expr condition, const std::vector<Expr> &args) {
10391052
return Internal::Call::make(Int(32),
10401053
"halide_error_requirement_failed",

src/IROperator.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,16 @@ Expr remove_likelies(const Expr &e);
307307
* all calls to likely() and likely_if_innermost() removed. */
308308
Stmt remove_likelies(const Stmt &s);
309309

310+
/** If the expression is a tag helper call, remove it and return
311+
* the tagged expression, recursively unwrapping any nested tags. If not, returns the expression unchanged. */
312+
Expr unwrap_tags(const Expr &e);
313+
314+
/** Expressions tagged with this intrinsic are suggestions that
315+
* vectorization of loops with guard ifs should be implemented with
316+
* non-faulting predicated loads and stores, instead of scalarizing
317+
* an if statement. */
318+
Expr predicate(Expr e);
319+
310320
// Secondary args to print can be Exprs or const char *
311321
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args) {
312322
}

src/IRPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ std::ostream &operator<<(std::ostream &out, const TailStrategy &t) {
147147
case TailStrategy::GuardWithIf:
148148
out << "GuardWithIf";
149149
break;
150+
case TailStrategy::Predicate:
151+
out << "Predicate";
152+
break;
150153
case TailStrategy::ShiftInwards:
151154
out << "ShiftInwards";
152155
break;

src/LICM.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ class LiftLoopInvariants : public IRMutator {
9797
}
9898
}
9999
if (const Call *call = e.as<Call>()) {
100-
if (call->is_intrinsic(Call::strict_float) ||
101-
call->is_intrinsic(Call::likely) ||
102-
call->is_intrinsic(Call::likely_if_innermost) ||
100+
if (Call::as_tag(call) ||
103101
call->is_intrinsic(Call::reinterpret)) {
104102
// Don't lift these intrinsics. They're free.
105103
return should_lift(call->args[0]);

0 commit comments

Comments
 (0)