-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[mlir][tensor] Add ValueBoundsOpInterface for ExpandShapeOp and CollapseShapeOp #173356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Add ValueBoundsOpInterface for ExpandShapeOp and CollapseShapeOp #173356
Conversation
|
@llvm/pr-subscribers-mlir Author: Zhewen Yu (Yu-Zhewen) ChangesFull diff: https://github.com/llvm/llvm-project/pull/173356.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 5bb6259dd543d..658bb0394862f 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -31,6 +31,27 @@ struct CastOpInterface
}
};
+struct CollapseShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+ CollapseShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto collapseOp = cast<CollapseShapeOp>(op);
+ assert(value == collapseOp.getResult() && "invalid value");
+
+ // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices reassocIndices =
+ collapseOp.getReassociationIndices()[dim];
+ AffineExpr productExpr =
+ cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+ for (size_t i = 1; i < reassocIndices.size(); ++i) {
+ productExpr =
+ productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+ }
+ cstr.bound(value)[dim] == productExpr;
+ }
+};
+
struct DimOpInterface
: public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
@@ -57,6 +78,17 @@ struct EmptyOpInterface
}
};
+struct ExpandShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+ ExpandShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto expandOp = cast<ExpandShapeOp>(op);
+ assert(value == expandOp.getResult() && "invalid value");
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
+ }
+};
+
struct ExtractSliceOpInterface
: public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
ExtractSliceOp> {
@@ -117,8 +149,12 @@ void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
+ tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
+ *ctx);
tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
+ tensor::ExpandShapeOp::attachInterface<tensor::ExpandShapeOpInterface>(
+ *ctx);
tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
*ctx);
tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index e526adb18cf1e..864f8c97a1ee4 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -230,3 +230,35 @@ func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
%1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
return
}
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @tensor_collapse(
+// CHECK-SAME: %[[sz0:.*]]: index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index
+// CHECK: %[[dim:.*]] = tensor.dim %{{.*}}, %[[c2]] : tensor<3x4x?x2xf32>
+// CHECK: %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
+// CHECK: return %[[c12]], %[[mul]]
+func.func @tensor_collapse(%sz0: index) -> (index, index) {
+ %0 = tensor.empty(%sz0) : tensor<3x4x?x2xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<3x4x?x2xf32> into tensor<12x?xf32>
+ %2 = "test.reify_bound"(%1) {dim = 0} : (tensor<12x?xf32>) -> (index)
+ %3 = "test.reify_bound"(%1) {dim = 1} : (tensor<12x?xf32>) -> (index)
+ return %2, %3 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_expand(
+// CHECK-SAME: %[[t:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index
+// CHECK: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: return %[[c4]], %[[sz]]
+func.func @tensor_expand(%t: tensor<?xf32>, %sz: index) -> (index, index) {
+ %0 = tensor.expand_shape %t [[0, 1]] output_shape [4, %sz] : tensor<?xf32> into tensor<4x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<4x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<4x?xf32>) -> (index)
+ return %1, %2 : index, index
+}
|
|
@llvm/pr-subscribers-mlir-tensor Author: Zhewen Yu (Yu-Zhewen) ChangesFull diff: https://github.com/llvm/llvm-project/pull/173356.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 5bb6259dd543d..658bb0394862f 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -31,6 +31,27 @@ struct CastOpInterface
}
};
+struct CollapseShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+ CollapseShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto collapseOp = cast<CollapseShapeOp>(op);
+ assert(value == collapseOp.getResult() && "invalid value");
+
+ // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices reassocIndices =
+ collapseOp.getReassociationIndices()[dim];
+ AffineExpr productExpr =
+ cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+ for (size_t i = 1; i < reassocIndices.size(); ++i) {
+ productExpr =
+ productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+ }
+ cstr.bound(value)[dim] == productExpr;
+ }
+};
+
struct DimOpInterface
: public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
@@ -57,6 +78,17 @@ struct EmptyOpInterface
}
};
+struct ExpandShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+ ExpandShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto expandOp = cast<ExpandShapeOp>(op);
+ assert(value == expandOp.getResult() && "invalid value");
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
+ }
+};
+
struct ExtractSliceOpInterface
: public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
ExtractSliceOp> {
@@ -117,8 +149,12 @@ void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
+ tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
+ *ctx);
tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
+ tensor::ExpandShapeOp::attachInterface<tensor::ExpandShapeOpInterface>(
+ *ctx);
tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
*ctx);
tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index e526adb18cf1e..864f8c97a1ee4 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -230,3 +230,35 @@ func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
%1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
return
}
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @tensor_collapse(
+// CHECK-SAME: %[[sz0:.*]]: index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index
+// CHECK: %[[dim:.*]] = tensor.dim %{{.*}}, %[[c2]] : tensor<3x4x?x2xf32>
+// CHECK: %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
+// CHECK: return %[[c12]], %[[mul]]
+func.func @tensor_collapse(%sz0: index) -> (index, index) {
+ %0 = tensor.empty(%sz0) : tensor<3x4x?x2xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<3x4x?x2xf32> into tensor<12x?xf32>
+ %2 = "test.reify_bound"(%1) {dim = 0} : (tensor<12x?xf32>) -> (index)
+ %3 = "test.reify_bound"(%1) {dim = 1} : (tensor<12x?xf32>) -> (index)
+ return %2, %3 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_expand(
+// CHECK-SAME: %[[t:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index
+// CHECK: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: return %[[c4]], %[[sz]]
+func.func @tensor_expand(%t: tensor<?xf32>, %sz: index) -> (index, index) {
+ %0 = tensor.expand_shape %t [[0, 1]] output_shape [4, %sz] : tensor<?xf32> into tensor<4x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<4x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<4x?xf32>) -> (index)
+ return %1, %2 : index, index
+}
|
Signed-off-by: Yu-Zhewen <[email protected]>
| assert(value == collapseOp.getResult() && "invalid value"); | ||
|
|
||
| // Multiply the expressions for the dimensions in the reassociation group. | ||
| const ReassociationIndices reassocIndices = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: const reference
Signed-off-by: Yu-Zhewen <[email protected]>
|
@matthias-springer, could you help merge if ok? |
…pseShapeOp (llvm#173356) Mirroring llvm#164438 and llvm#164955 --------- Signed-off-by: Yu-Zhewen <[email protected]>
|
This broke the address sanitizer build -- reverting now. |
…nd CollapseShapeOp (#173356)" This reverts commit 5154a05. It broke sanitizer build bots -- see https://lab.llvm.org/buildbot/#/builders/52/builds/13831
Mirroring #164438 and #164955