-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] shape_cast(constant) -> constant fold for non-splats #145539
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesThe folder Potential downside: It is possible with this folder end up with, instead of 1 large constant and 1 shape_cast, 2 large constants: func.func @<!-- -->foo() -> (vector<4xi32>, vector<2x2xi32>) {
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> # 'large' constant 1
%0 = vector.shape_cast %cst : vector<4xi32> to vector<2x2xi32>
return %cst, %0 : vector<4xi32>, vector<2x2xi32>
} gets folded with this new folder to func.func @<!-- -->foo() -> (vector<4xi32>, vector<2x2xi32>) {
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> # 'large' constant 1
%cst_0 = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32> # 'large' constant 2
return %cst, %cst_0 : vector<4xi32>, vector<2x2xi32>
} Notes on the above case:
Full diff: https://github.com/llvm/llvm-project/pull/145539.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..ddc80063fd340 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5881,14 +5881,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
}
// shape_cast(constant) -> constant
- if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ if (auto denseAttr =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+ return denseAttr.reshape(getType());
// shape_cast(poison) -> poison
- if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
- }
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..a06a98ee1b93b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1219,11 +1219,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
// -----
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1233,6 +1233,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+ %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+ %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+ return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast fold's method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+ %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+ %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+ return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -1549,7 +1579,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
- ) -> vector<4x2xf32>
+ ) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1636,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
- -> vector<4x2x6xf32>
+ -> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
|
The folder
shape_cast(splat constant) -> splat constant
was first introduced here (Nov 2020). In that commit there is a comment to Only handle splat for now. Based on that I assume the intention was to, at a later time, support a generalshape_cast(constant) -> constant
folder. That is what this PR doesPotential downside: It is possible with this folder end up with, instead of 1 large constant and 1 shape_cast, 2 large constants:
gets folded with this new folder to
Notes on the above case:
DenseIntOrFPElementsAttrStorage
constructor) so no compile-time memory overhead to this folding. I think at the LLVM IR level the constant is shared, too.