
Commit 5eeeb4f

[compiler] Add mtrt-scf-float-strength-reduce pass to preprocessing pipeline
In the Stablehlo preprocessing pipeline, enable the `mtrt-scf-float-strength-reduce` pass in order to convert while-style loops to for-style loops where possible. For this to work on some common JAX use cases, loops also need to be detensorized more aggressively in the `convert-stablehlo-to-scf` pass.

GitOrigin-RevId: 2614ebf00b5c2d29de32b09a78e7b1a2f42c13cc
1 parent 440b3af commit 5eeeb4f
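To illustrate the intent, here is a minimal, hypothetical example of the kind of loop this targets (not taken from this change; the function name, constants, and lowering outcome are assumptions). It is a `stablehlo.while` that counts with a 0-d float tensor: once `convert-stablehlo-to-scf` detensorizes the scalar carried value, the new `mtrt-scf-float-strength-reduce` pass can rewrite such while-style loops into for-style loops where possible.

// Hypothetical sketch, not from this commit: a float-counted while loop
// whose induction variable is a 0-d float tensor.
func.func @float_counted_loop(%arg0: tensor<f32>) -> tensor<f32> {
  %limit = stablehlo.constant dense<1.000000e+01> : tensor<f32>
  %step = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.while(%iterArg = %arg0) : tensor<f32>
  cond {
    // Keep looping while the float counter is below the limit.
    %1 = stablehlo.compare LT, %iterArg, %limit, FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
    stablehlo.return %1 : tensor<i1>
  } do {
    // Advance the counter by a constant step.
    %1 = stablehlo.add %iterArg, %step : tensor<f32>
    stablehlo.return %1 : tensor<f32>
  }
  return %0 : tensor<f32>
}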

File tree

4 files changed, +37 -120 lines changed


mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ add_mlir_tensorrt_library(MLIRTensorRTCompilerStableHloToExecutable
   MLIRTensorRTTargetLua
   MLIRTensorRTTensorRTBackend
   MLIRTensorRTTensorRTToTensorRTRuntime
+  MLIRTensorRTTransformsSCFFloatStrengthReduce
   MLIRTensorRTTransformsUnrollForLoops
   StablehloLinalgTransforms
   MLIR_LIBS PUBLIC

mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipeline.cpp

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ void mtrt::compiler::buildStablehloPreProcessingPipeline(
   // `convert-stablehlo-to-scf`:
   if (opts.legalizeControlFlowToSCF) {
     pm.addNestedPass<func::FuncOp>(mlir::createConvertStablehloToScfPass());
+    pm.addNestedPass<func::FuncOp>(mtrt::createSCFFloatStrengthReducePass());
     pm.addNestedPass<func::FuncOp>(mtrt::createSCFUnrollPass(
         mtrt::SCFUnrollPassOptions{opts.unrollThreshold}));
   }

mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp

Lines changed: 19 additions & 88 deletions
@@ -357,91 +357,24 @@ struct ScalarizeWhileConditionProducers
 };
 } // namespace
 
-/// Check if the add op is a valid induction variable increment.
-static bool matchInductionVariableIncrement(stablehlo::AddOp op,
-                                            scf::WhileOp parentWhile) {
-  Value lhs = op.getLhs();
-  Value rhs = op.getRhs();
-  if (matchPattern(lhs, m_Constant()) || matchPattern(rhs, m_Constant()))
-    return true;
-  Region *whileRegion = parentWhile->getParentRegion();
-  return lhs.getParentRegion()->isAncestor(whileRegion) ||
-         rhs.getParentRegion()->isAncestor(whileRegion);
-}
-
 namespace {
 /// Scalarize any `stablehlo.add` operations in the 'after' region of
 /// a scf.while op.
-struct ScalarizeStablehloAddOp : public OpRewritePattern<stablehlo::AddOp> {
-  using OpRewritePattern<stablehlo::AddOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(stablehlo::AddOp op,
+struct ScalarizeStablehloAddOp : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->hasOneUse())
-      return rewriter.notifyMatchFailure(
-          op, "op has more than one use, cannot scalarize");
-    auto extractUser = dyn_cast<tensor::ExtractOp>(*op->user_begin());
-    if (!extractUser || !extractUser->hasOneUse() ||
-        !isa<scf::YieldOp>(*extractUser->user_begin()))
-      return rewriter.notifyMatchFailure(
-          op, "op result is not extracted and yielded from region");
-
-    auto scfWhile = extractUser->getParentOfType<scf::WhileOp>();
-    if (!scfWhile || scfWhile.getAfter() != op->getParentRegion())
-      return rewriter.notifyMatchFailure(
-          op, "op is not in the after region of a scf.while op");
-
-    // One operand must be a constant or defined above in order to be
-    // considered as the loop step.
-    if (!matchInductionVariableIncrement(op, scfWhile))
-      return rewriter.notifyMatchFailure(
-          op, "op is not a valid induction variable increment");
-
-    // Find a block argument that has been scalarized.
-    auto findBlockArgument = [](Value v) -> BlockArgument {
-      Value source{};
-      if (matchPattern(v,
-                       m_Op<tensor::FromElementsOp>(matchers::m_Any(&source))))
-        return dyn_cast<BlockArgument>(source);
-      return {};
-    };
-    BlockArgument arg = findBlockArgument(op.getLhs());
-    if (!arg)
-      arg = findBlockArgument(op.getRhs());
-    if (!arg || arg.getParentRegion() != scfWhile.getAfter())
-      return rewriter.notifyMatchFailure(
-          op, "could not find block argument in after region");
-
-    // Check that the corresponding block argument in the `before` region feeds
-    // into a comparison.
-    Region &before = scfWhile.getBefore();
-    if (arg.getArgNumber() >= before.getNumArguments() ||
-        before.getArgument(arg.getArgNumber()).getType() != arg.getType())
-      return rewriter.notifyMatchFailure(
-          op, "could not find block argument in before region");
-    auto beforeArg = before.getArgument(arg.getArgNumber());
-    if (!llvm::all_of(beforeArg.getUsers(),
-                      llvm::IsaPred<scf::ConditionOp, arith::CmpIOp>))
-      return rewriter.notifyMatchFailure(
-          op, "block argument is not consumed by a comparison op");
-
-    // Check that the before region has a block argument in the same position
-    // and is consumed by a comparison op.
-    RankedTensorType rtt = op.getType();
-    Type elementType = rtt.getElementType();
-    if (!rtt.hasStaticShape() || rtt.getNumElements() != 1 ||
-        !elementType.isSignlessIntOrIndex())
-      return rewriter.notifyMatchFailure(op, "op is not a scalar add op");
-
-    auto scalarOperands = llvm::map_to_vector(op.getOperands(), [&](Value v) {
-      return extractScalarFromTensorValue(rewriter, v);
-    });
-
-    auto scalarAdd =
-        stablehlo::StablehloOpToStdScalarOp::mapOp<stablehlo::AddOp>(
-            op, elementType, scalarOperands, &rewriter);
-    auto fromElements =
-        rewriter.create<tensor::FromElementsOp>(op.getLoc(), rtt, scalarAdd);
-    rewriter.replaceOp(op, fromElements);
+    auto addOp = op.getTensor().getDefiningOp<stablehlo::AddOp>();
+    if (!addOp || !addOp.getType().hasStaticShape() ||
+        addOp.getType().getNumElements() != 1)
+      return failure();
+    rewriter.setInsertionPoint(addOp);
+    SmallVector<Value> scalarOperands;
+    for (Value operand : addOp.getOperands())
+      scalarOperands.push_back(extractScalarFromTensorValue(rewriter, operand));
+    auto scalarAdd = stablehlo::StablehloOpToStdScalarOp::mapOp(
+        addOp, addOp.getType().getElementType(), scalarOperands, &rewriter);
+    rewriter.replaceOp(op, scalarAdd);
     return success();
   }
 };
@@ -453,9 +386,7 @@ struct ScalarizeStablehloAddOp : public OpRewritePattern<stablehlo::AddOp> {
 /// for loop. It will have a user like `stablehlo.compare` or `tensor.extract`.
 static bool shouldScalarizeWhileBeforeArg(BlockArgument arg, Value initOperand,
                                           Value yieldOperand) {
-  return cast<RankedTensorType>(arg.getType())
-             .getElementType()
-             .isSignlessIntOrIndex() &&
+  return cast<RankedTensorType>(arg.getType()).getElementType() &&
          llvm::count_if(arg.getUsers(),
                         llvm::IsaPred<stablehlo::CompareOp, arith::CmpIOp,
                                       tensor::ExtractOp>) >= 1;
@@ -473,17 +404,17 @@ static bool shouldScalarizeWhileAfterArg(BlockArgument arg, Value condOperand,
   if (before.getNumArguments() <= arg.getArgNumber() ||
       before.getArgument(arg.getArgNumber()).getType() !=
           rtt.getElementType() ||
-      !llvm::all_of(before.getArgument(arg.getArgNumber()).getUsers(),
-                    llvm::IsaPred<arith::CmpIOp, tensor::FromElementsOp>))
+      !llvm::all_of(
+          before.getArgument(arg.getArgNumber()).getUsers(),
+          llvm::IsaPred<arith::CmpIOp, arith::CmpFOp, tensor::FromElementsOp>))
     return false;
 
   auto condProducer = condOperand.getDefiningOp<tensor::FromElementsOp>();
   if (!condProducer || condProducer.getElements().size() != 1 ||
       !isa<BlockArgument>(condProducer.getElements().front()))
     return false;
 
-  return rtt.getElementType().isSignlessIntOrIndex() &&
-         llvm::count_if(arg.getUsers(),
+  return llvm::count_if(arg.getUsers(),
                         llvm::IsaPred<stablehlo::AddOp, arith::AddIOp,
                                       tensor::ExtractOp>) >= 1;
 }
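
The rewritten pattern above now keys on the `tensor.extract` consumer rather than on the `stablehlo.add` itself, scalarizing any single-element add whose result gets extracted. A rough before/after sketch of the IR shape it targets (hypothetical, not taken from this commit's tests; value names and the `arith.addi` lowering are assumptions based on the scalar mapping for integer adds):

// Before (hypothetical): a single-element tensor add feeding a tensor.extract.
%sum = stablehlo.add %a, %b : tensor<i32>
%s = tensor.extract %sum[] : tensor<i32>

// After (hypothetical): the extract is rewritten to a scalar add on the
// extracted operands; the tensor-typed add is left behind as dead code.
%a0 = tensor.extract %a[] : tensor<i32>
%b0 = tensor.extract %b[] : tensor<i32>
%s = arith.addi %a0, %b0 : i32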

mlir-tensorrt/compiler/test/Conversion/StablehloToScf/stablehlo-to-scf.mlir

Lines changed: 16 additions & 32 deletions
@@ -70,18 +70,18 @@ func.func @stablehlo_while_to_scf_while(%arg0: tensor<i64>, %arg1: tensor<i64>)
 
 func.func private @some_compute(tensor<f32>) -> tensor<1xf32>
 
-func.func @stablehlo_while_regression(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
-  %c_33 = stablehlo.constant dense<0> : tensor<i32>
+func.func @stablehlo_while_single_iteration(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
+  %c0 = stablehlo.constant dense<0> : tensor<i32>
   %cst = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>
-  %c_31 = stablehlo.constant dense<1> : tensor<i32>
+  %c1 = stablehlo.constant dense<1> : tensor<i32>
   %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
-  %5:2 = stablehlo.while(%iterArg = %c_33, %iterArg_34 = %cst) : tensor<i32>, tensor<1xf32>
+  %5:2 = stablehlo.while(%iterArg = %c0, %iterArg_34 = %cst) : tensor<i32>, tensor<1xf32>
   cond {
-    %6 = stablehlo.compare LT, %iterArg, %c_31, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %6 = stablehlo.compare LT, %iterArg, %c1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
     stablehlo.return %6 : tensor<i1>
   } do {
-    %6 = stablehlo.compare LT, %iterArg, %c_33, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %7 = stablehlo.add %iterArg, %c_31 : tensor<i32>
+    %6 = stablehlo.compare LT, %iterArg, %c0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %7 = stablehlo.add %iterArg, %c1 : tensor<i32>
     %8 = stablehlo.select %6, %7, %iterArg : tensor<i1>, tensor<i32>
     %10 = stablehlo.dynamic_slice %arg0, %8, sizes = [1] : (tensor<1xf32>, tensor<i32>) -> tensor<1xf32>
     %11 = stablehlo.reshape %10 : (tensor<1xf32>) -> tensor<f32>
@@ -94,30 +94,14 @@ func.func @stablehlo_while_regression(%arg0: tensor<1xf32>, %arg1: tensor<f32>)
   return %5#1 : tensor<1xf32>
 }
 
-// CHECK-LABEL: func.func @stablehlo_while_regression
-// CHECK: scf.while
-
-// -----
-
-
-func.func @dont_scalarize_while(%arg0: tensor<f32>) -> tensor<f32> {
-  %0 = stablehlo.while(%iterArg = %arg0) : tensor<f32>
-  cond {
-    %c0 = stablehlo.constant dense<0.0> : tensor<f32>
-    %1 = stablehlo.compare LT, %iterArg, %c0, SIGNED : (tensor<f32>, tensor<f32>) -> tensor<i1>
-    stablehlo.return %1 : tensor<i1>
-  } do {
-    %c1 = stablehlo.constant dense<1.0> : tensor<f32>
-    %2 = stablehlo.subtract %iterArg, %c1 : tensor<f32>
-    stablehlo.return %2 : tensor<f32>
-  }
-  return %0 : tensor<f32>
-}
-
-// CHECK-LABEL: @dont_scalarize_while
-// CHECK: scf.while {{.*}} (tensor<f32>) -> tensor<f32>
-// CHECK: scf.condition{{.*}} : tensor<f32>
-// CHECK: scf.yield{{.*}} : tensor<f32>
+// CHECK-LABEL: func.func @stablehlo_while_single_iteration
+// CHECK-NOT: scf.while
+// CHECK-NOT: scf.for
+// CHECK: stablehlo.compare
+// CHECK: stablehlo.add
+// CHECK: stablehlo.dynamic_slice
+// CHECK: call @some_compute
+// CHECK: return
 
 // -----
 
@@ -254,4 +238,4 @@ func.func @case_three_branches(
 // CHECK-DAG: %[[v5:.+]] = stablehlo.multiply %[[v4]], %[[arg2]] : tensor<2xi64>
 // CHECK-DAG: scf.yield %[[v5]] : tensor<2xi64>
 // CHECK: scf.yield %[[v3]] : tensor<2xi64>
-// CHECK: return %[[v1]] : tensor<2xi64>
+// CHECK: return %[[v1]] : tensor<2xi64>
