[mlir][tensor][memref] Enhance collapse(expand(src)) canonicalization pattern. #145995

Merged
merged 2 commits into llvm:main from fix-collapse-of-expand-pattern on Jun 27, 2025

Conversation

hanhanW
Contributor

@hanhanW hanhanW commented Jun 26, 2025

The expand_shape op takes a mix of static and dynamic values for its output shape, and we need to take them into account when composing the ops. Otherwise, the pattern fails to create the new expand_shape op.

The revision does not use affine ops because that would introduce a circular dependency: the affine dialect already depends on the memref dialect today.
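
For concreteness, the new tests below exercise exactly this case. A minimal sketch of the rewrite (the SSA names are illustrative): the collapse(expand) pair composes into a single expand_shape whose dynamic output size is recomputed with arith.muli.

  // Before: expand to a mixed static/dynamic shape, then partially collapse.
  %e = tensor.expand_shape %src [[0, 1, 2, 3, 4]] output_shape [4, 2, %d2, %d3, 32]
      : tensor<?xf16> into tensor<4x2x?x?x32xf16>
  %c = tensor.collapse_shape %e [[0, 1], [2], [3, 4]]
      : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>

  // After: one expand_shape; the collapsed [3, 4] group gets size %d3 * 32.
  %c32 = arith.constant 32 : index
  %sz  = arith.muli %d3, %c32 : index
  %r   = tensor.expand_shape %src [[0, 1, 2]] output_shape [8, %d2, %sz]
      : tensor<?xf16> into tensor<8x?x?xf16>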

… pattern.

The expand_shape op takes dynamic output shape values, and we need to
take them into account when we compose the ops. Otherwise, it fails to
create the new expand_shape op.

Signed-off-by: hanhanW <[email protected]>
@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Han-Chung Wang (hanhanW)

Changes

The expand_shape op takes a mix of static and dynamic values for its output shape, and we need to take them into account when composing the ops. Otherwise, the pattern fails to create the new expand_shape op.


Full diff: https://github.com/llvm/llvm-project/pull/145995.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+36-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+18)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+18)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c2a50e514ca..7f946f739baf9 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -305,8 +306,42 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
     } else if (srcRank < resultRank) {
+      // Compute the dynamic output shape for the new expand_shape op.
+      Location loc = collapseOp.getLoc();
+      SmallVector<OpFoldResult> origOutputShape =
+          expandOp.getMixedOutputShape();
+      SmallVector<OpFoldResult> newOutputShape;
+      for (auto indices : collapseOp.getReassociationIndices()) {
+        int64_t numStaticElems = 1;
+        SmallVector<Value> dynamicSizes;
+        for (auto idx : indices) {
+          OpFoldResult size = origOutputShape[idx];
+          if (auto maybeCst = getConstantIntValue(size)) {
+            numStaticElems *= maybeCst.value();
+            continue;
+          }
+          dynamicSizes.push_back(cast<Value>(size));
+        }
+        if (dynamicSizes.empty()) {
+          newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
+          continue;
+        }
+
+        // There is at least one dynamic size, so we can initialize `result` to
+        // the first dynamic size.
+        Value result = dynamicSizes[0];
+        for (auto v : llvm::drop_begin(dynamicSizes))
+          result = rewriter.create<arith::MulIOp>(loc, result, v);
+        if (numStaticElems != 1) {
+          result = rewriter.create<arith::MulIOp>(
+              loc, result,
+              rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
+        }
+        newOutputShape.push_back(result);
+      }
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+          collapseOp, resultType, expandOp.getSrc(), composedReassociation,
+          newOutputShape);
     } else {
       // Collapses/expansions that do not change the rank are not allowed. Use
       // a cast instead.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7a267ae8a2c95..decc85a9af3c9 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
+  %expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
+  %collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
+  return %collapsed : memref<8x?x?xf16>
+}
+//       CHECK: func @compose_collapse_of_expand_partially_dynamic
+//  CHECK-SAME:   %[[SRC:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D2:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D3:.[a-zA-Z0-9]+]]
+//   CHECK-DAG:   %[[C32:.+]] = arith.constant 32
+//       CHECK:   %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+//       CHECK:   %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
+//  CHECK-SAME:     [0, 1, 2]
+//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @do_not_compose_collapse_of_expand_non_identity_layout(
     %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
     -> memref<?xf32, strided<[?], offset: 0>> {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bfd..ed87bdafe80c9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
+  %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
+  return %collapsed : tensor<8x?x?xf16>
+}
+//       CHECK: func @compose_collapse_of_expand_partially_dynamic
+//  CHECK-SAME:   %[[SRC:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D2:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D3:.[a-zA-Z0-9]+]]
+//   CHECK-DAG:   %[[C32:.+]] = arith.constant 32
+//       CHECK:   %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
+//  CHECK-SAME:     [0, 1, 2]
+//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
     -> tensor<1x1x1x1xf32> {
   %0 = tensor.collapse_shape %arg0 []

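In short, for each reassociation group of the collapse op, the new code folds the static sizes of the original expand_shape output into a single product and chains arith.muli over the dynamic sizes; a group with no dynamic size becomes a plain index attribute. A minimal sketch of the IR the pattern emits for a group whose original output sizes are [%d0, %d1, 4] (the SSA names are illustrative, not from the tests):

  %m  = arith.muli %d0, %d1 : index
  %c4 = arith.constant 4 : index
  %sz = arith.muli %m, %c4 : index
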
@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir-memref

@MaheshRavishankar
Contributor

I don't know, this canonicalization pattern has always seemed broken to me (I actually do not follow this logic at all). I have been trying to adapt struct BubbleUpExpandThroughParallelCollapse to see if it can cover everything covered by the canonicalization, since I understand that pattern a bit more. Just an FYI more than anything.

Signed-off-by: hanhanW <[email protected]>
@hanhanW hanhanW merged commit 0515449 into llvm:main Jun 27, 2025
7 checks passed
@hanhanW hanhanW deleted the fix-collapse-of-expand-pattern branch June 27, 2025 02:39
rupprecht added a commit to rupprecht/llvm-project that referenced this pull request Jun 27, 2025
rupprecht added a commit that referenced this pull request Jun 27, 2025