[mlir][vector] shape_cast(constant) -> constant fold for non-splats #145539

Open
wants to merge 1 commit into main

Conversation

newling
Contributor

@newling newling commented Jun 24, 2025

The folder shape_cast(splat constant) -> splat constant was first introduced here (Nov 2020). That commit carries the comment "Only handle splat for now", so I assume the intention was to later support a general shape_cast(constant) -> constant folder. That is what this PR does.

Potential downside: with this folder it is possible to end up with 2 large constants instead of 1 large constant and 1 shape_cast:

  func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) {
    %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> // 'large' constant 1
    %0 = vector.shape_cast %cst : vector<4xi32> to vector<2x2xi32>
    return %cst, %0 : vector<4xi32>, vector<2x2xi32>
  }

gets folded by this new folder to:

  func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) {
    %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> // 'large' constant 1
    %cst_0 = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32> // 'large' constant 2
    return %cst, %cst_0 : vector<4xi32>, vector<2x2xi32>
  }

Notes on the above case:

  1. This only affects the textual IR; the actual values share the same context storage (I've verified this by checking pointer values in the DenseIntOrFPElementsAttrStorage constructor), so there is no compile-time memory overhead from this folding. I think the constant is shared at the LLVM IR level, too. See the sketch after this list.
  2. This only happens when the pre-folded constant cannot be dead-code eliminated (i.e. when it has 2+ uses), which I don't think is common.

@llvmbot
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Full diff: https://github.com/llvm/llvm-project/pull/145539.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4-5)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+34-4)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..ddc80063fd340 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5881,14 +5881,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   }
 
   // shape_cast(constant) -> constant
-  if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+  if (auto denseAttr =
+          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+    return denseAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
-  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
     return ub::PoisonAttr::get(getContext());
-  }
 
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..a06a98ee1b93b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1219,11 +1219,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
 
 // -----
 
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
 //       CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
   %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
   %cst_1 = arith.constant dense<1> : vector<12x2xi32>
   %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1233,6 +1233,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+//               CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+//               CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+  %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+  %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+  return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+//  CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+//  CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+//      CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+  %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+  %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+  return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_poison
 //       CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -1549,7 +1579,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32> 
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1636,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 //       CHECK:   vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32> 
+  -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
