Skip to content

Commit e086494

Browse files
shelkesagar29 and Copybara Bot authored
Move internal changes (#411)
This is a combination of the following changes: [compiler/Dialect/Plan] Drop `getClusterKindName` from `ClusterKindAttrInterface` Removes the concept of having a cluster "name" from the `plan::ClusterKindAttrInterface`. This concept is unnecessary since all concrete implementations of the interface have their own mnemonic. NFC: [compiler/StableHloExt] Update tests to avoid relying on other canonicalization patterns Updates `stablehlo-ext-canonicalize-scatter` pass test cases to void relying on other canonicalization patterns and just tests the effects of the 'stablehlo.scatter' canonicalization patterns. Co-authored-by: Copybara Bot <[email protected]>
1 parent 24da898 commit e086494

File tree

5 files changed

+111
-107
lines changed

5 files changed

+111
-107
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@ def TensorBoundsAttrInterface : AttrInterface<"TensorBoundsAttrInterface"> {
3838
def ClusterKindAttrInterface : AttrInterface<"ClusterKindAttrInterface"> {
3939
let cppNamespace = "::mlir::plan";
4040
let methods = [
41-
InterfaceMethod<
42-
/*desc=*/"Return the name of the cluster",
43-
/*retTy-*/"std::string",
44-
"getClusterKindName",
45-
/*args=*/(ins ),
46-
/*body=*/"",
47-
""
48-
>,
4941
InterfaceMethod<
5042
/*desc=*/"Return the clustering options for the cluster",
5143
/*retTy-*/"::mlir::ClusteringOpts",

mlir-tensorrt/compiler/lib/Dialect/Plan/IR/BuiltinClusterKinds.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ using namespace mlir::plan;
5151
// HostClusterKindAttr
5252
//===----------------------------------------------------------------------===//
5353

54-
std::string HostClusterKindAttr::getClusterKindName() const { return "host"; }
55-
5654
int64_t HostClusterKindAttr::getClusterBenefit() const { return getBenefit(); }
5755

5856
/// ClusteringOpts that identifies groups of `stablehlo` ops that can be
@@ -124,10 +122,6 @@ HostClusterKindAttr::getClusterFilter() const {
124122
// TensorRTClusterKindAttr
125123
//===----------------------------------------------------------------------===//
126124

127-
std::string TensorRTClusterKindAttr::getClusterKindName() const {
128-
return "tensorrt";
129-
}
130-
131125
static ShapeInfoCallbacks getShapeInfoCallbacks() {
132126
ShapeInfoCallbacks shapeInfoCallbacks{};
133127
shapeInfoCallbacks.isElementValueEqualToConstant =

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/StablehloClustering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ applyClusteringToFunc(RewriterBase &rewriter, func::FuncOp func,
9191
const StablehloClusteringPassOptions &opts) {
9292
ClusteringPatternSet<ClusteringRewriter> patterns;
9393
for (const auto &[idx, target] : llvm::enumerate(clusters)) {
94-
if (target.getClusterKindName() == "tensorrt") {
94+
if (isa<TensorRTClusterKindAttr>(target)) {
9595
patterns.add(target.getClusterKindOptions(solver, opts.trtMajorVersion),
9696
createInlineGroupOp, isOpInClusterRegion,
9797
target.getClusterFilter(),

mlir-tensorrt/test/Dialect/StableHloExt/canonicalize-scatter-nd.mlir

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-tensorrt-opt %s --stablehlo-ext-canonicalize-scatter --stablehlo-aggressive-simplification -split-input-file | FileCheck %s
1+
// RUN: mlir-tensorrt-opt %s --stablehlo-ext-canonicalize-scatter -split-input-file | FileCheck %s
22

33

44
func.func @whisper_jax_scatter(%arg0: tensor<1x51865xf32>) -> tensor<1x51865xf32> {
@@ -22,22 +22,19 @@ func.func @whisper_jax_scatter(%arg0: tensor<1x51865xf32>) -> tensor<1x51865xf32
2222
}
2323

2424

25-
// CHECK-LABEL: @whisper_jax_scatter
26-
// CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>)
27-
// CHECK-DAG: %[[v0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<1x1xf32>
25+
// CHECK-LABEL: func.func @whisper_jax_scatter
26+
// CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>) -> tensor<1x51865xf32> {
2827
// CHECK-DAG: %[[cst:.+]] = arith.constant dense<50257> : tensor<1x1xi32>
29-
// CHECK: %[[v1:.+]] = stablehlo.reshape %[[arg0]]
30-
// CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v1]], %[[cst]], %[[v0]])
31-
// CHECK-SAME: indices_are_sorted = false
32-
// CHECK-SAME: #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
33-
34-
// CHECK-SAME: unique_indices = false
35-
// CHECK-NEXT: ^bb0(%[[arg1:.+]]: tensor<f32>, %[[arg2:.+]]: tensor<f32>):
36-
// CHECK-NEXT: stablehlo.return %[[arg2]] : tensor<f32>
37-
// CHECK-NEXT: }) {tensorrt.canonicalized_scatter}
38-
// CHECK-SAME: : (tensor<51865x1xf32>, tensor<1x1xi32>, tensor<1x1xf32>) -> tensor<51865x1xf32>
39-
40-
// CHECK: %[[v3:.+]] = stablehlo.reshape
28+
// CHECK-DAG: %[[cst_0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<1x1x1xf32>
29+
// CHECK-DAG: %[[v0:.+]] = stablehlo.transpose %[[arg0]], dims = [1, 0] : (tensor<1x51865xf32>) -> tensor<51865x1xf32>
30+
// CHECK-DAG: %[[v1:.+]] = stablehlo.reshape %[[cst_0]] : (tensor<1x1x1xf32>) -> tensor<1x1xf32>
31+
// CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v0]], %[[cst]], %[[v1]])
32+
// CHECK-SAME: <{indices_are_sorted = false,
33+
// CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1],
34+
// CHECK-SAME: inserted_window_dims = [0], scatter_dims_to_operand_dims = [0],
35+
// CHECK-SAME: index_vector_dim = 1>, unique_indices = false}> ({
36+
// CHECK: }) {tensorrt.canonicalized_scatter} : (tensor<51865x1xf32>, tensor<1x1xi32>, tensor<1x1xf32>) -> tensor<51865x1xf32>
37+
// CHECK: %[[v3:.+]] = stablehlo.transpose %[[v2]], dims = [1, 0] : (tensor<51865x1xf32>) -> tensor<1x51865xf32>
4138
// CHECK: return %[[v3]] : tensor<1x51865xf32>
4239

4340
// -----
@@ -60,20 +57,21 @@ func.func @whisper_jax_scatter2(%arg0: tensor<1x51865xf32>, %arg1: tensor<88x1xi
6057
return %3 : tensor<1x51865xf32>
6158
}
6259

63-
// CHECK-LABEL: @whisper_jax_scatter2
60+
// CHECK-LABEL: func.func @whisper_jax_scatter2
6461
// CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>, %[[arg1:.+]]: tensor<88x1xi32>) -> tensor<1x51865xf32> {
65-
// CHECK: %[[v0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<88x1xf32>
66-
// CHECK: %[[v1:.+]] = stablehlo.reshape
67-
// CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v1]], %[[arg1]], %[[v0]])
68-
// CHECK-SAME: indices_are_sorted = false
69-
// CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
70-
// CHECK-SAME: unique_indices = false
71-
// CHECK-NEXT: ^bb0(%[[arg2:.+]]: tensor<f32>, %[[arg3:.+]]: tensor<f32>):
72-
// CHECK-NEXT: stablehlo.return %[[arg3]] : tensor<f32>
73-
// CHECK-NEXT: }) {tensorrt.canonicalized_scatter}
74-
// CHECK-SAME: (tensor<51865x1xf32>, tensor<88x1xi32>, tensor<88x1xf32>) -> tensor<51865x1xf32>
75-
// CHECK: %[[v3:.+]] = stablehlo.reshape
76-
// CHECK: return %[[v3]] : tensor<1x51865xf32>
62+
// CHECK: %[[cst:.+]] = stablehlo.constant dense<0xFF800000> : tensor<88x1x1xf32>
63+
// CHECK: %[[v0:.+]] = stablehlo.transpose %[[arg0]], dims = [1, 0] : (tensor<1x51865xf32>) -> tensor<51865x1xf32>
64+
// CHECK: %[[v1:.+]] = stablehlo.reshape %[[cst]] : (tensor<88x1x1xf32>) -> tensor<88x1xf32>
65+
// CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v0]], %[[arg1]], %[[v1]])
66+
// CHECK-SAME: <{indices_are_sorted = false,
67+
// CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1],
68+
// CHECK-SAME: inserted_window_dims = [0], scatter_dims_to_operand_dims = [0],
69+
// CHECK-SAME: index_vector_dim = 1>, unique_indices = false}>
70+
// CHECK: }) {tensorrt.canonicalized_scatter}
71+
// CHECK-SAME: : (tensor<51865x1xf32>, tensor<88x1xi32>, tensor<88x1xf32>) -> tensor<51865x1xf32>
72+
// CHECK: %[[v3:.+]] = stablehlo.transpose %[[v2]], dims = [1, 0] : (tensor<51865x1xf32>) -> tensor<1x51865xf32>
73+
// CHECK: return %[[v3]] : tensor<1x51865xf32>
74+
7775
// -----
7876

7977
func.func @stablehlo_scatter_canonicalize(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<2xi32>, %arg3: tensor<2x3xf32>, %arg4: tensor<2x3xf32>) -> tensor<3x3xf32> {
@@ -95,18 +93,22 @@ func.func @stablehlo_scatter_canonicalize(%arg0: tensor<3x3xf32>, %arg1: tensor<
9593
return %0#0 : tensor<3x3xf32>
9694
}
9795

98-
// CHECK-LABEL: @stablehlo_scatter_canonicalize
96+
// CHECK-LABEL: func.func @stablehlo_scatter_canonicalize
9997
// CHECK-SAME: (%[[arg0:.+]]: tensor<3x3xf32>, %[[arg1:.+]]: tensor<3x3xf32>, %[[arg2:.+]]: tensor<2xi32>, %[[arg3:.+]]: tensor<2x3xf32>, %[[arg4:.+]]: tensor<2x3xf32>) -> tensor<3x3xf32> {
100-
// CHECK: %[[v0:.+]] = stablehlo.reshape %[[arg2]] : (tensor<2xi32>) -> tensor<2x1xi32>
101-
// CHECK: %[[v2:.+]]:2 = "stablehlo.scatter"(%[[arg0]], %[[arg1]], %[[v0]], %[[arg3]], %[[arg4]])
102-
// CHECK-SAME: indices_are_sorted = false
103-
// CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
104-
// CHECK-SAME: unique_indices = false
105-
// CHECK-NEXT: ^bb0(%[[arg5:.+]]: tensor<f32>, %[[arg6:.+]]: tensor<f32>, %[[arg7:.+]]: tensor<f32>, %[[arg8:.+]]: tensor<f32>):
106-
// CHECK-NEXT: stablehlo.return %[[arg5]], %[[arg7]] : tensor<f32>, tensor<f32>
107-
// CHECK-NEXT: }) {tensorrt.canonicalized_scatter}
108-
// CHECK-SAME: : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<3x3xf32>, tensor<3x3xf32>)
109-
// CHECK: return %[[v2]]#0 : tensor<3x3xf32>
98+
// CHECK-DAG: %[[v0:.+]] = stablehlo.reshape %[[arg2]] : (tensor<2xi32>) -> tensor<2x1xi32>
99+
// CHECK-DAG: %[[v1:.+]] = stablehlo.reshape %[[arg3]] : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
100+
// CHECK-DAG: %[[v2:.+]] = stablehlo.reshape %[[arg4]] : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
101+
// CHECK-DAG: %[[v3:.+]] = stablehlo.reshape %[[v1]] : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
102+
// CHECK-DAG: %[[v4:.+]] = stablehlo.reshape %[[v2]] : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
103+
// CHECK: %[[v5:.+]]:2 = "stablehlo.scatter"(%[[arg0]], %[[arg1]], %[[v0]], %[[v3]], %[[v4]])
104+
// CHECK-SAME: <{indices_are_sorted = false,
105+
// CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1],
106+
// CHECK-SAME: inserted_window_dims = [0],
107+
// CHECK-SAME: scatter_dims_to_operand_dims = [0],
108+
// CHECK-SAME: index_vector_dim = 1>, unique_indices = false}>
109+
// CHECK: }) {tensorrt.canonicalized_scatter}
110+
// CHECK-SAME: : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<3x3xf32>, tensor<3x3xf32>)
111+
// CHECK: return %[[v5]]#0 : tensor<3x3xf32>
110112

111113
// -----
112114

0 commit comments

Comments (0)