Skip to content

Commit 357cb5d

Browse files
committed
Fix Comments
1 parent d463fa9 commit 357cb5d

File tree

5 files changed

+27
-37
lines changed

5 files changed

+27
-37
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
128128
let summary = "Convert SCF parallel for loops to nested SCF for loops";
129129
let constructor = "mlir::createParallelForToNestedForsPass()";
130130
let description = [{
131-
This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp
131+
This pass transforms SCF::ParallelOp operations into a nest of SCF::ForOp
132132
operations. The transformation is useful for cases where the parallel loop
133133
can be expressed as a series of sequential iterations, allowing for more
134134
fine-grained control over the loop execution.

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TRANSFORMS_H_
1414
#define MLIR_DIALECT_SCF_TRANSFORMS_TRANSFORMS_H_
1515

16+
#include "mlir/Dialect/SCF/IR/SCF.h"
1617
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
1718
#include "mlir/Support/LLVM.h"
1819
#include "llvm/ADT/ArrayRef.h"
@@ -43,10 +44,9 @@ LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
4344
ParallelOp *result = nullptr);
4445

4546
/// Try converting scf.forall into an scf.parallel loop.
46-
/// The conversion is only supported for forall operations with no results.
47-
LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
48-
ParallelOp parallelOp,
49-
ForOp *result = nullptr);
47+
/// The conversion is only supported for parallel operations with no results.
48+
FailureOr<scf::LoopNest> parallelForToNestedFors(RewriterBase &rewriter,
49+
ParallelOp parallelOp);
5050

5151
/// Fuses all adjacent scf.parallel operations with identical bounds and step
5252
/// into one scf.parallel operations. Uses a naive aliasing and dependency

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,15 @@ DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
176176
return diag;
177177
}
178178

179-
scf::ForOp opResult;
180-
if (failed(scf::parallelForToNestedFors(rewriter, target, &opResult))) {
179+
FailureOr<scf::LoopNest> loopNest =
180+
scf::parallelForToNestedFors(rewriter, target);
181+
if (failed(loopNest)) {
181182
DiagnosedSilenceableFailure diag =
182183
emitSilenceableError() << "failed to convert parallel into nested fors";
183184
return diag;
184185
}
185186

186-
results.set(cast<OpResult>(getTransformed()[0]), {opResult});
187+
results.set(cast<OpResult>(getTransformed()[0]), {loopNest->loops.front()});
187188
return DiagnosedSilenceableFailure::success();
188189
}
189190

mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,51 +23,40 @@ namespace mlir {
2323

2424
using namespace mlir;
2525

26-
LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
27-
scf::ParallelOp parallelOp,
28-
scf::ForOp *result) {
26+
FailureOr<scf::LoopNest>
27+
mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
28+
scf::ParallelOp parallelOp) {
2929

3030
if (!parallelOp.getResults().empty()) {
31-
parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
32-
"doesn't support ScfParallel with results.");
31+
parallelOp->emitError("Currently scf.parallel to scf.for conversion "
32+
"doesn't support scf.parallel with results.");
3333
return failure();
3434
}
3535

3636
rewriter.setInsertionPoint(parallelOp);
3737

3838
Location loc = parallelOp.getLoc();
39-
auto lowerBounds = parallelOp.getLowerBound();
40-
auto upperBounds = parallelOp.getUpperBound();
41-
auto steps = parallelOp.getStep();
39+
SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
40+
SmallVector<Value> upperBounds = parallelOp.getUpperBound();
41+
SmallVector<Value> steps = parallelOp.getStep();
4242

4343
assert(lowerBounds.size() == upperBounds.size() &&
4444
lowerBounds.size() == steps.size() &&
4545
"Mismatched parallel loop bounds");
4646

4747
SmallVector<Value> ivs;
48-
auto loopNest =
48+
scf::LoopNest loopNest =
4949
scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
5050

51-
auto oldInductionVars = parallelOp.getInductionVars();
52-
auto newInductionVars = llvm::map_to_vector(
51+
SmallVector<Value> newInductionVars = llvm::map_to_vector(
5352
loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
54-
assert(oldInductionVars.size() == newInductionVars.size() &&
55-
"Mismatched induction variables");
56-
for (auto [oldIV, newIV] : llvm::zip(oldInductionVars, newInductionVars))
57-
oldIV.replaceAllUsesWith(newIV);
58-
59-
auto *linearizedBody = loopNest.loops.back().getBody();
60-
Block &parallelBody = *parallelOp.getBody();
61-
for (Operation &op : llvm::make_early_inc_range(parallelBody)) {
62-
// Skip the terminator of the parallelOp body.
63-
if (&op == parallelBody.getTerminator())
64-
continue;
65-
op.moveBefore(linearizedBody->getTerminator());
66-
}
53+
Block *linearizedBody = loopNest.loops.back().getBody();
54+
Block *parallelBody = parallelOp.getBody();
55+
rewriter.eraseOp(parallelBody->getTerminator());
56+
rewriter.inlineBlockBefore(parallelBody, linearizedBody->getTerminator(),
57+
newInductionVars);
6758
rewriter.eraseOp(parallelOp);
68-
if (result)
69-
*result = loopNest.loops.front();
70-
return success();
59+
return loopNest;
7160
}
7261

7362
namespace {

mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func.func private @callee(%i: index, %j: index) -> i32
6767

6868
func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
6969
%c0 = arith.constant 0 : i32
70-
// expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
70+
// expected-error@+1 {{Currently scf.parallel to scf.for conversion doesn't support scf.parallel with results}}
7171
%0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
7272
%curr = func.call @callee(%i, %j) : (index, index) -> i32
7373
scf.reduce(%curr : i32) {
@@ -77,4 +77,4 @@ func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: in
7777
}
7878
}
7979
return %0 : i32
80-
}
80+
}

0 commit comments

Comments
 (0)