Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -660,10 +660,7 @@ OrderedAssignmentRewriter::generateYieldedEntity(
return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
}

assert(region.hasOneBlock() && "region must contain one block");
auto oldYield = getYield(region);
mlir::Block::OpListType &ops = region.back().getOperations();

// Inside Forall, scalars that do not depend on forall indices can be hoisted
// here because their evaluation is required to only call pure procedures, and
// if they depend on a variable previously assigned to in a forall assignment,
Expand All @@ -674,24 +671,24 @@ OrderedAssignmentRewriter::generateYieldedEntity(
bool hoistComputation = false;
if (fir::isa_trivial(oldYield.getEntity().getType()) &&
!constructStack.empty()) {
hoistComputation = true;
for (mlir::Operation &op : ops)
if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
return isForallIndex(value);
})) {
hoistComputation = false;
break;
}
mlir::WalkResult walkResult =
region.walk([&](mlir::Operation *op) -> mlir::WalkResult {
if (llvm::any_of(op->getOperands(), [](mlir::Value value) {
return isForallIndex(value);
}))
return mlir::WalkResult::interrupt();
return mlir::WalkResult::advance();
});
hoistComputation = !walkResult.wasInterrupted();
}
auto insertionPoint = builder.saveInsertionPoint();
if (hoistComputation)
builder.setInsertionPoint(constructStack[0]);

// Clone all operations except the final hlfir.yield.
assert(!ops.empty() && "yield block cannot be empty");
auto end = ops.end();
for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
(void)builder.clone(*opIt, mapper);
assert(region.hasOneBlock() && "region must contain one block");
for (auto &op : region.back().without_terminator())
(void)builder.clone(op, mapper);
// Get the value for the yielded entity, it may be the result of an operation
// that was cloned, or it may be the same as the previous value if the yield
// operand was created before the ordered assignment tree.
Expand Down
64 changes: 64 additions & 0 deletions flang/test/HLFIR/order_assignments/forall-issue120190.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Regression test for https://github.com/llvm/llvm-project/issues/120190
// Verify that hlfir.forall lowering does not try to hoist a mask evaluation
// that refers to the forall index only inside nested regions.
// RUN: fir-opt %s --lower-hlfir-ordered-assignments | FileCheck %s

// Test input: a masked forall whose index runs from %c1 to %c100 and whose
// body assigns %cst to element %forall_index of %array. The mask value is
// produced by a fir.if on %cdt, so every use of %forall_index in the mask
// region sits inside a nested fir.if branch rather than being an operand of
// an operation placed directly in the hlfir.forall_mask block.
func.func @issue120190(%array: !fir.ref<!fir.array<100xf32>>, %cdt: i1) {
%cst = arith.constant 0.000000e+00 : f32
%c1 = arith.constant 1 : i64
%c50 = arith.constant 50 : i64
%c100 = arith.constant 100 : i64
// Lower bound region yields %c1, upper bound region yields %c100.
hlfir.forall lb {
hlfir.yield %c1 : i64
} ub {
hlfir.yield %c100 : i64
} (%forall_index: i64) {
hlfir.forall_mask {
// The mask is the result of a fir.if: index < 50 when %cdt is true,
// index > 50 otherwise.
%mask = fir.if %cdt -> i1 {
// Reference to %forall_index is not directly in
// hlfir.forall_mask region, but is nested.
%res = arith.cmpi slt, %forall_index, %c50 : i64
fir.result %res : i1
} else {
%res = arith.cmpi sgt, %forall_index, %c50 : i64
fir.result %res : i1
}
// The yielded mask has no direct %forall_index operand; the dependency
// exists only through the nested regions above.
hlfir.yield %mask : i1
} do {
hlfir.region_assign {
// Value being assigned: the constant %cst.
hlfir.yield %cst : f32
} to {
// Assignment target: %array indexed by the forall index.
%6 = hlfir.designate %array (%forall_index) : (!fir.ref<!fir.array<100xf32>>, i64) -> !fir.ref<f32>
hlfir.yield %6 : !fir.ref<f32>
}
}
}
return
}

// CHECK-LABEL: func.func @issue120190(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<100xf32>>,
// CHECK-SAME: %[[VAL_1:.*]]: i1) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_4:.*]] = arith.constant 50 : i64
// CHECK: %[[VAL_5:.*]] = arith.constant 100 : i64
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] {
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (index) -> i64
// CHECK: %[[VAL_11:.*]] = fir.if %[[VAL_1]] -> (i1) {
// CHECK: %[[VAL_12:.*]] = arith.cmpi slt, %[[VAL_10]], %[[VAL_4]] : i64
// CHECK: fir.result %[[VAL_12]] : i1
// CHECK: } else {
// CHECK: %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_10]], %[[VAL_4]] : i64
// CHECK: fir.result %[[VAL_13]] : i1
// CHECK: }
// CHECK: fir.if %[[VAL_11]] {
// CHECK: %[[VAL_14:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_10]]) : (!fir.ref<!fir.array<100xf32>>, i64) -> !fir.ref<f32>
// CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_14]] : f32, !fir.ref<f32>
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK: }
Loading