Skip to content

Commit 114d0cd

Browse files
[StablehloExt] Refactor simplification patterns into separate files
Extract simplification patterns from ConstantFolding.cpp into a new Simplifications.cpp file to improve code organization and maintainability. Also create CastPatterns.cpp to separate cast-related patterns. Changes: - Create Simplifications.cpp with simplification patterns (transpose combining, min/max simplification, etc.) - Create CastPatterns.cpp for cast-related pattern matching - Update ConstantFolding.cpp to remove extracted patterns - Reorganize tests: move simplification tests to simplifications.mlir, rename constant-folding-scatter.mlir to simplifications-scatter.mlir - Update pass registration and CMakeLists.txt accordingly In addition, adds a generic pattern for `stablehlo.dot_general` to `stablehlo.multiply` along with a number of test cases. This replaces the previous pattern that was under the "stablehlo-ext-canonicalize-dot-general" pass and was less robust. GitOrigin-RevId: 13535beff08e9693ab0e3c2e16317c93cef1f97e
1 parent df0a3f8 commit 114d0cd

File tree

27 files changed

+1451
-953
lines changed

27 files changed

+1451
-953
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
add_subdirectory(mlir-tensorrt)
2-

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ add_mlir_tensorrt_doc(ConversionPasses
1515
OUTPUT_FILE docs/Passes/ConversionPasses.md
1616
COMMAND -gen-pass-doc ${_TABLEGEN_ARGS}
1717
)
18+

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ endif()
55
add_subdirectory(CUDA)
66
add_subdirectory(Plan)
77
add_subdirectory(TensorRTRuntime)
8+

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/Transforms/Passes.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,30 @@
2222

2323
include "mlir/Pass/PassBase.td"
2424

25+
//===----------------------------------------------------------------------===//
26+
// StablehloExtSimplificationsPass
27+
//===----------------------------------------------------------------------===//
28+
29+
def StablehloExtSimplificationsPass : Pass<"stablehlo-ext-simplifications"> {
30+
let summary = "Applies Stablehlo device-agnostic simplifications along with additional patterns";
31+
let description = [{
32+
This pass runs the patterns defined by
33+
`stablehlo-aggressive-simplification` along with additional custom
34+
patterns not yet upstreamed.
35+
}];
36+
37+
let dependentDialects = [
38+
"::mlir::stablehlo::StablehloDialect",
39+
"::mlir::arith::ArithDialect",
40+
"::mlir::tensor::TensorDialect"
41+
];
42+
43+
let options = [
44+
Option<"foldOpElementLimit", "fold-op-element-limit", "int64_t", "1024",
45+
"The computation size limit for constant folding">
46+
];
47+
}
48+
2549
//===----------------------------------------------------------------------===//
2650
// StablehloRaiseQDQPass
2751
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/Transforms/Patterns.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ namespace stablehlo_ext {
3333
/// `tensor.cast` producers.
3434
void populateStableHloAbsorbTensorCastPatterns(RewritePatternSet &patterns);
3535

36+
/// Populate patterns that perform simplifications.
37+
void populateStableHloExtSimplificationsPatterns(
38+
RewritePatternSet &patterns,
39+
const stablehlo::StablehloAggressiveSimplificationPassOptions &options,
40+
PatternBenefit benefit = 1);
41+
42+
/// Populate patterns that simplify `stablehlo.dot_general` to
43+
/// `stablehlo.multiply`.
44+
void populateStablehloDotGeneralToMultiplyPatterns(RewritePatternSet &patterns);
45+
3646
/// Populate patterns that perform target-independent simplifications.
3747
/// The `sizeLimit` is the maximum tensor volume beyond which constant folding
3848
/// is not attempted.

mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ endif()
99

1010
mlir_tablegen(Passes.h.inc ${_TABLEGEN_ARGS})
1111
mtrt_add_public_tablegen_target(MLIRTensorRTGenericTransformPassIncGen)
12+

mlir-tensorrt/compiler/lib/Compiler/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ add_mlir_tensorrt_library(MLIRTensorRTInitAllPasses
4848
InitAllPasses.cpp
4949
PARTIAL_SOURCES_INTENDED
5050

51+
DEPENDS
52+
mtrt-headers
53+
5154
LINK_LIBS PUBLIC
5255
MLIRTensorRTCompilerIncludes
5356
${MLIR_TENSORRT_DIALECT_LIBS}

mlir-tensorrt/compiler/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ add_subdirectory(TensorRTRuntimeToExecutor)
2424
add_subdirectory(TensorRTRuntimeToLLVM)
2525
add_subdirectory(TensorRTToEmitC)
2626
add_subdirectory(TensorRTToTensorRTRuntime)
27+

mlir-tensorrt/compiler/lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ add_subdirectory(TensorRTRuntime)
55
if(MLIR_TRT_ENABLE_HLO)
66
add_subdirectory(StablehloExt)
77
endif()
8+
Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
add_mlir_tensorrt_library(MLIRTensorRTStableHloExtTransforms
2-
CanonicalizeConvolution.cpp
3-
CanonicalizeDotGeneral.cpp
4-
CanonicalizeGather.cpp
5-
CanonicalizeScatter.cpp
6-
CanonicalizeShapes.cpp
7-
ConstantFolding.cpp
8-
ExpandTuples.cpp
9-
GatherToSlice.cpp
10-
LowerSpecialCustomCalls.cpp
11-
MaterializeDenseResourceElementsAttr.cpp
12-
StablehloRaiseQDQ.cpp
13-
TargetSpecificOptimizations.cpp
2+
CastPatterns.cpp
3+
CanonicalizeConvolution.cpp
4+
CanonicalizeDotGeneral.cpp
5+
CanonicalizeGather.cpp
6+
CanonicalizeScatter.cpp
7+
CanonicalizeShapes.cpp
8+
ConstantFolding.cpp
9+
ExpandTuples.cpp
10+
GatherToSlice.cpp
11+
LowerSpecialCustomCalls.cpp
12+
MaterializeDenseResourceElementsAttr.cpp
13+
Simplifications.cpp
14+
StablehloRaiseQDQ.cpp
15+
TargetSpecificOptimizations.cpp
1416

15-
LINK_LIBS PUBLIC
16-
ChloOps
17-
MLIRExecutorCommonUtils
18-
MLIRTensorRTConstantFoldingUtils
19-
MLIRTensorRTShapeUtils
20-
MLIRTensorRTStableHloExtUtils
21-
StablehloOps
22-
StablehloPasses
23-
StablehloOptimizationPasses
24-
MLIR_LIBS PUBLIC
25-
MLIRRewrite
26-
MLIRTensorDialect
27-
DEPENDS
17+
LINK_LIBS PUBLIC
18+
ChloOps
19+
MLIRExecutorCommonUtils
20+
MLIRTensorRTConstantFoldingUtils
21+
MLIRTensorRTShapeUtils
22+
MLIRTensorRTStableHloExtUtils
23+
StablehloOps
24+
StablehloPasses
25+
StablehloOptimizationPasses
26+
MLIR_LIBS PUBLIC
27+
MLIRRewrite
28+
MLIRTensorDialect
29+
DEPENDS
2830
MLIRTensorRTStableHloExtTransformsPassIncGen
2931
)

0 commit comments

Comments
 (0)