[mlir][tensor][memref] Enhance collapse(expand(src)) canonicalization pattern. #145995
Conversation
… pattern. The expand_shape op takes dynamic output shape values, and we need to take them into account when we compose the ops. Otherwise, it fails to create the new expand_shape op. Signed-off-by: hanhanW <[email protected]>
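For illustration, here is a minimal sketch of the kind of IR this change targets, mirroring the new tensor test added below (the SSA names are taken from that test):

    %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32]
        : tensor<?xf16> into tensor<4x2x?x?x32xf16>
    %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]]
        : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>

Before this change, the pattern did not supply any output_shape values for the replacement expand_shape, so creating it failed when the composed result has dynamic sizes.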
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-memref

Author: Han-Chung Wang (hanhanW)

Changes: The expand_shape op takes mixed values for its output shape, and we need to take them into account when we compose the ops. Otherwise, it fails to create the new expand_shape op.

Full diff: https://github.com/llvm/llvm-project/pull/145995.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c2a50e514ca..7f946f739baf9 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -305,8 +306,42 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
rewriter.replaceOpWithNewOp<CollapseOpTy>(
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
} else if (srcRank < resultRank) {
+ // Compute the dynamic output shape for the new expand_shape op.
+ Location loc = collapseOp.getLoc();
+ SmallVector<OpFoldResult> origOutputShape =
+ expandOp.getMixedOutputShape();
+ SmallVector<OpFoldResult> newOutputShape;
+ for (auto indices : collapseOp.getReassociationIndices()) {
+ int64_t numStaticElems = 1;
+ SmallVector<Value> dynamicSizes;
+ for (auto idx : indices) {
+ OpFoldResult size = origOutputShape[idx];
+ if (auto maybeCst = getConstantIntValue(size)) {
+ numStaticElems *= maybeCst.value();
+ continue;
+ }
+ dynamicSizes.push_back(cast<Value>(size));
+ }
+ if (dynamicSizes.empty()) {
+ newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
+ continue;
+ }
+
+      // There is at least one dynamic size, so we can initialize `result` to
+ // the first dynamic size.
+ Value result = dynamicSizes[0];
+ for (auto v : llvm::drop_begin(dynamicSizes))
+ result = rewriter.create<arith::MulIOp>(loc, result, v);
+ if (numStaticElems != 1) {
+ result = rewriter.create<arith::MulIOp>(
+ loc, result,
+ rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
+ }
+ newOutputShape.push_back(result);
+ }
rewriter.replaceOpWithNewOp<ExpandOpTy>(
- collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+ collapseOp, resultType, expandOp.getSrc(), composedReassociation,
+ newOutputShape);
} else {
// Collapses/expansions that do not change the rank are not allowed. Use
// a cast instead.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7a267ae8a2c95..decc85a9af3c9 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
// -----
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
+ %expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
+ %collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
+ return %collapsed : memref<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @do_not_compose_collapse_of_expand_non_identity_layout(
%arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
-> memref<?xf32, strided<[?], offset: 0>> {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bfd..ed87bdafe80c9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
// -----
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
+ %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
+ %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
+ return %collapsed : tensor<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
-> tensor<1x1x1x1xf32> {
%0 = tensor.collapse_shape %arg0 []
I don't know, this canonicalization pattern has always seemed broken to me (I actually do not follow this logic at all). I have been trying to adapt
Signed-off-by: hanhanW <[email protected]>
The expand_shape op takes mixed values for its output shape, and we need to take them into account when we compose the ops. Otherwise, it fails to create the new expand_shape op.
The revision does not use affine ops because that would introduce a circular dependency: the affine dialect already depends on the memref dialect today.
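As a sketch of what the pattern now emits (using the shapes from the new tests), the size of a collapsed group is rebuilt with plain arith ops; for the group whose expanded sizes are [%arg2, 32], the composed op looks roughly like:

    %c32 = arith.constant 32 : index
    %new_d2 = arith.muli %arg2, %c32 : index
    %result = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [8, %arg1, %new_d2]
        : tensor<?xf16> into tensor<8x?x?xf16>

An affine.apply with an equivalent map could express the same product, but emitting it from ReshapeOpsUtils would pull in the affine dialect, hence the arith-only approach.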