Skip to content

Commit cb78a6b

Browse files
Don't strip strict_float() from lets (#5871)
* Don't strip strict_float() from lets Bug injected in #5856: the change in Simplify_Let.cpp was inadvertently stripping `strict_float()` calls that wrapped the RHS of a Let-expr, which can change results nontrivially in some cases. I don't think a new test for this fix is practical -- it would be a little fragile, as it would rely on the specifics of simplification that could change over time. As a drive-by, also added an explicit rule to Simplify_Call to ensure that strict_float(strict_float(x)) -> strict_float(x) in *all* cases. (The existing rule didn't do this in all cases.)
1 parent 896b260 commit cb78a6b

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

src/Simplify_Call.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,18 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) {
111111
}
112112

113113
if (op->is_intrinsic(Call::strict_float)) {
114-
ScopedValue<bool> save_no_float_simplify(no_float_simplify, true);
115-
Expr arg = mutate(op->args[0], nullptr);
116-
if (arg.same_as(op->args[0])) {
117-
return op;
114+
if (Call::as_intrinsic(op->args[0], {Call::strict_float})) {
115+
// Always simplify strict_float(strict_float(x)) -> strict_float(x).
116+
Expr arg = mutate(op->args[0], nullptr);
117+
return arg.same_as(op->args[0]) ? op->args[0] : arg;
118118
} else {
119-
return strict_float(arg);
119+
ScopedValue<bool> save_no_float_simplify(no_float_simplify, true);
120+
Expr arg = mutate(op->args[0], nullptr);
121+
if (arg.same_as(op->args[0])) {
122+
return op;
123+
} else {
124+
return strict_float(arg);
125+
}
120126
}
121127
} else if (op->is_intrinsic(Call::popcount) ||
122128
op->is_intrinsic(Call::count_leading_zeros) ||

src/Simplify_Let.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
9696
const Shuffle *shuffle = f.new_value.template as<Shuffle>();
9797
const Variable *var_b = nullptr;
9898
const Variable *var_a = nullptr;
99+
const Call *tag = nullptr;
99100

100101
if (add) {
101102
var_a = add->a.as<Variable>();
@@ -174,7 +175,9 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
174175
Expr op_b = var_a ? new_var : shuffle->vectors[1];
175176
replacement = substitute(f.new_name, Shuffle::make_concat({op_a, op_b}), replacement);
176177
f.new_value = var_a ? shuffle->vectors[1] : shuffle->vectors[0];
177-
} else if (const Call *tag = Call::as_tag(f.new_value)) {
178+
} else if ((tag = Call::as_tag(f.new_value)) != nullptr && !tag->is_intrinsic(Call::strict_float)) {
179+
// Most tags should be stripped here, but not strict_float(); removing it will change the semantics
180+
// of the let-expr we are producing.
178181
replacement = substitute(f.new_name, Call::make(tag->type, tag->name, {new_var}, Call::PureIntrinsic), replacement);
179182
f.new_value = tag->args[0];
180183
} else {

test/correctness/simplify.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,6 +1763,8 @@ void check_math() {
17631763
check(Halide::trunc(-1.6f), -1.0f);
17641764
check(Halide::floor(round(x)), round(x));
17651765
check(Halide::ceil(ceil(x)), ceil(x));
1766+
1767+
check(strict_float(strict_float(x)), strict_float(x));
17661768
}
17671769

17681770
void check_overflow() {

0 commit comments

Comments
 (0)