1- // RUN: mlir-tensorrt-opt %s --stablehlo-ext-canonicalize-scatter --stablehlo-aggressive-simplification - split-input-file | FileCheck %s
1+ // RUN: mlir-tensorrt-opt %s --stablehlo-ext-canonicalize-scatter -split-input-file | FileCheck %s
22
33
44func.func @whisper_jax_scatter (%arg0: tensor <1 x51865 xf32 >) -> tensor <1 x51865 xf32 > {
@@ -22,22 +22,19 @@ func.func @whisper_jax_scatter(%arg0: tensor<1x51865xf32>) -> tensor<1x51865xf32
2222}
2323
2424
25- // CHECK-LABEL: @whisper_jax_scatter
26- // CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>)
27- // CHECK-DAG: %[[v0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<1x1xf32>
25+ // CHECK-LABEL: func.func @whisper_jax_scatter
26+ // CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>) -> tensor<1x51865xf32> {
2827// CHECK-DAG: %[[cst:.+]] = arith.constant dense<50257> : tensor<1x1xi32>
29- // CHECK: %[[v1:.+]] = stablehlo.reshape %[[arg0]]
30- // CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v1]], %[[cst]], %[[v0]])
31- // CHECK-SAME: indices_are_sorted = false
32- // CHECK-SAME: #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
33-
34- // CHECK-SAME: unique_indices = false
35- // CHECK-NEXT: ^bb0(%[[arg1:.+]]: tensor<f32>, %[[arg2:.+]]: tensor<f32>):
36- // CHECK-NEXT: stablehlo.return %[[arg2]] : tensor<f32>
37- // CHECK-NEXT: }) {tensorrt.canonicalized_scatter}
38- // CHECK-SAME: : (tensor<51865x1xf32>, tensor<1x1xi32>, tensor<1x1xf32>) -> tensor<51865x1xf32>
39-
40- // CHECK: %[[v3:.+]] = stablehlo.reshape
28+ // CHECK-DAG: %[[cst_0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<1x1x1xf32>
29+ // CHECK-DAG: %[[v0:.+]] = stablehlo.transpose %[[arg0]], dims = [1, 0] : (tensor<1x51865xf32>) -> tensor<51865x1xf32>
30+ // CHECK-DAG: %[[v1:.+]] = stablehlo.reshape %[[cst_0]] : (tensor<1x1x1xf32>) -> tensor<1x1xf32>
31+ // CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v0]], %[[cst]], %[[v1]])
32+ // CHECK-SAME: <{indices_are_sorted = false,
33+ // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1],
34+ // CHECK-SAME: inserted_window_dims = [0], scatter_dims_to_operand_dims = [0],
35+ // CHECK-SAME: index_vector_dim = 1>, unique_indices = false}> ({
36+ // CHECK: }) {tensorrt.canonicalized_scatter} : (tensor<51865x1xf32>, tensor<1x1xi32>, tensor<1x1xf32>) -> tensor<51865x1xf32>
37+ // CHECK: %[[v3:.+]] = stablehlo.transpose %[[v2]], dims = [1, 0] : (tensor<51865x1xf32>) -> tensor<1x51865xf32>
4138// CHECK: return %[[v3]] : tensor<1x51865xf32>
4239
4340// -----
@@ -60,20 +57,21 @@ func.func @whisper_jax_scatter2(%arg0: tensor<1x51865xf32>, %arg1: tensor<88x1xi
6057 return %3 : tensor <1 x51865 xf32 >
6158}
6259
63- // CHECK-LABEL: @whisper_jax_scatter2
60+ // CHECK-LABEL: func.func @whisper_jax_scatter2
6461// CHECK-SAME: (%[[arg0:.+]]: tensor<1x51865xf32>, %[[arg1:.+]]: tensor<88x1xi32>) -> tensor<1x51865xf32> {
65- // CHECK: %[[v0:.+]] = stablehlo.constant dense<0xFF800000> : tensor<88x1xf32>
66- // CHECK: %[[v1:.+]] = stablehlo.reshape
67- // CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v1]], %[[arg1]], %[[v0]])
68- // CHECK-SAME: indices_are_sorted = false
69- // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
70- // CHECK-SAME: unique_indices = false
71- // CHECK-NEXT: ^bb0(%[[arg2:.+]]: tensor<f32>, %[[arg3:.+]]: tensor<f32>):
72- // CHECK-NEXT: stablehlo.return %[[arg3]] : tensor<f32>
73- // CHECK-NEXT: }) {tensorrt.canonicalized_scatter}
74- // CHECK-SAME: (tensor<51865x1xf32>, tensor<88x1xi32>, tensor<88x1xf32>) -> tensor<51865x1xf32>
75- // CHECK: %[[v3:.+]] = stablehlo.reshape
76- // CHECK: return %[[v3]] : tensor<1x51865xf32>
62+ // CHECK: %[[cst:.+]] = stablehlo.constant dense<0xFF800000> : tensor<88x1x1xf32>
63+ // CHECK: %[[v0:.+]] = stablehlo.transpose %[[arg0]], dims = [1, 0] : (tensor<1x51865xf32>) -> tensor<51865x1xf32>
64+ // CHECK: %[[v1:.+]] = stablehlo.reshape %[[cst]] : (tensor<88x1x1xf32>) -> tensor<88x1xf32>
65+ // CHECK: %[[v2:.+]] = "stablehlo.scatter"(%[[v0]], %[[arg1]], %[[v1]])
66+ // CHECK-SAME: <{indices_are_sorted = false,
67+ // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1],
68+ // CHECK-SAME: inserted_window_dims = [0], scatter_dims_to_operand_dims = [0],
69+ // CHECK-SAME: index_vector_dim = 1>, unique_indices = false}>
70+ // CHECK: }) {tensorrt.canonicalized_scatter}
71+ // CHECK-SAME: : (tensor<51865x1xf32>, tensor<88x1xi32>, tensor<88x1xf32>) -> tensor<51865x1xf32>
72+ // CHECK: %[[v3:.+]] = stablehlo.transpose %[[v2]], dims = [1, 0] : (tensor<51865x1xf32>) -> tensor<1x51865xf32>
73+ // CHECK: return %[[v3]] : tensor<1x51865xf32>
74+
7775// -----
7876
7977func.func @stablehlo_scatter_canonicalize (%arg0: tensor <3 x3 xf32 >, %arg1: tensor <3 x3 xf32 >, %arg2: tensor <2 xi32 >, %arg3: tensor <2 x3 xf32 >, %arg4: tensor <2 x3 xf32 >) -> tensor <3 x3 xf32 > {
@@ -95,18 +93,22 @@ func.func @stablehlo_scatter_canonicalize(%arg0: tensor<3x3xf32>, %arg1: tensor<
9593 return %0#0 : tensor <3 x3 xf32 >
9694}
9795
98- // CHECK-LABEL: @stablehlo_scatter_canonicalize
96+ // CHECK-LABEL: func.func @stablehlo_scatter_canonicalize
9997// CHECK-SAME: (%[[arg0:.+]]: tensor<3x3xf32>, %[[arg1:.+]]: tensor<3x3xf32>, %[[arg2:.+]]: tensor<2xi32>, %[[arg3:.+]]: tensor<2x3xf32>, %[[arg4:.+]]: tensor<2x3xf32>) -> tensor<3x3xf32> {
100- // CHECK: %[[v0:.+]] = stablehlo.reshape %[[arg2]] : (tensor<2xi32>) -> tensor<2x1xi32>
101- // CHECK: %[[v2:.+]]:2 = "stablehlo.scatter"(%[[arg0]], %[[arg1]], %[[v0]], %[[arg3]], %[[arg4]])
102- // CHECK-SAME: indices_are_sorted = false
103- // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
104- // CHECK-SAME: unique_indices = false
105- // CHECK-NEXT: ^bb0(%[[arg5:.+]]: tensor<f32>, %[[arg6:.+]]: tensor<f32>, %[[arg7:.+]]: tensor<f32>, %[[arg8:.+]]: tensor<f32>):
106- // CHECK-NEXT: stablehlo.return %[[arg5]], %[[arg7]] : tensor<f32>, tensor<f32>
107- // CHECK-NEXT: }) {tensorrt.canonicalized_scatter}
108- // CHECK-SAME: : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<3x3xf32>, tensor<3x3xf32>)
109- // CHECK: return %[[v2]]#0 : tensor<3x3xf32>
98+ // CHECK-DAG: %[[v0:.+]] = stablehlo.reshape %[[arg2]] : (tensor<2xi32>) -> tensor<2x1xi32>
99+ // CHECK-DAG: %[[v1:.+]] = stablehlo.reshape %[[arg3]] : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
100+ // CHECK-DAG: %[[v2:.+]] = stablehlo.reshape %[[arg4]] : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
101+ // CHECK-DAG: %[[v3:.+]] = stablehlo.reshape %[[v1]] : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
102+ // CHECK-DAG: %[[v4:.+]] = stablehlo.reshape %[[v2]] : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
103+ // CHECK: %[[v5:.+]]:2 = "stablehlo.scatter"(%[[arg0]], %[[arg1]], %[[v0]], %[[v3]], %[[v4]])
104+ // CHECK-SAME: <{indices_are_sorted = false,
105+ // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1],
106+ // CHECK-SAME: inserted_window_dims = [0],
107+ // CHECK-SAME: scatter_dims_to_operand_dims = [0],
108+ // CHECK-SAME: index_vector_dim = 1>, unique_indices = false}>
109+ // CHECK: }) {tensorrt.canonicalized_scatter}
110+ // CHECK-SAME: : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<3x3xf32>, tensor<3x3xf32>)
111+ // CHECK: return %[[v5]]#0 : tensor<3x3xf32>
110112
111113// -----
112114
0 commit comments