Skip to content

Conversation

@Yu-Zhewen
Copy link
Contributor

@Yu-Zhewen Yu-Zhewen commented Dec 23, 2025

Mirroring #164438 and #164955

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2025

@llvm/pr-subscribers-mlir

Author: Zhewen Yu (Yu-Zhewen)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/173356.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp (+36)
  • (modified) mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir (+32)
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 5bb6259dd543d..658bb0394862f 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -31,6 +31,27 @@ struct CastOpInterface
   }
 };
 
+struct CollapseShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                   CollapseShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto collapseOp = cast<CollapseShapeOp>(op);
+    assert(value == collapseOp.getResult() && "invalid value");
+
+    // Multiply the expressions for the dimensions in the reassociation group.
+    const ReassociationIndices reassocIndices =
+        collapseOp.getReassociationIndices()[dim];
+    AffineExpr productExpr =
+        cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+    for (size_t i = 1; i < reassocIndices.size(); ++i) {
+      productExpr =
+          productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+    }
+    cstr.bound(value)[dim] == productExpr;
+  }
+};
+
 struct DimOpInterface
     : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
   void populateBoundsForIndexValue(Operation *op, Value value,
@@ -57,6 +78,17 @@ struct EmptyOpInterface
   }
 };
 
+struct ExpandShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                   ExpandShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto expandOp = cast<ExpandShapeOp>(op);
+    assert(value == expandOp.getResult() && "invalid value");
+    cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
+  }
+};
+
 struct ExtractSliceOpInterface
     : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    ExtractSliceOp> {
@@ -117,8 +149,12 @@ void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
+    tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
+        *ctx);
     tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
     tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
+    tensor::ExpandShapeOp::attachInterface<tensor::ExpandShapeOpInterface>(
+        *ctx);
     tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
         *ctx);
     tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index e526adb18cf1e..864f8c97a1ee4 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -230,3 +230,35 @@ func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
   %1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
   return
 }
+
+// -----
+
+//       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @tensor_collapse(
+//  CHECK-SAME:     %[[sz0:.*]]: index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c12:.*]] = arith.constant 12 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %{{.*}}, %[[c2]] : tensor<3x4x?x2xf32>
+//       CHECK:   %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
+//       CHECK:   return %[[c12]], %[[mul]]
+func.func @tensor_collapse(%sz0: index) -> (index, index) {
+  %0 = tensor.empty(%sz0) : tensor<3x4x?x2xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<3x4x?x2xf32> into tensor<12x?xf32>
+  %2 = "test.reify_bound"(%1) {dim = 0} : (tensor<12x?xf32>) -> (index)
+  %3 = "test.reify_bound"(%1) {dim = 1} : (tensor<12x?xf32>) -> (index)
+  return %2, %3 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_expand(
+//  CHECK-SAME:     %[[t:[a-zA-Z0-9]+]]: tensor<?xf32>
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9]+]]: index
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[c4]], %[[sz]]
+func.func @tensor_expand(%t: tensor<?xf32>, %sz: index) -> (index, index) {
+  %0 = tensor.expand_shape %t [[0, 1]] output_shape [4, %sz] : tensor<?xf32> into tensor<4x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<4x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<4x?xf32>) -> (index)
+  return %1, %2 : index, index
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2025

@llvm/pr-subscribers-mlir-tensor

Author: Zhewen Yu (Yu-Zhewen)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/173356.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp (+36)
  • (modified) mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir (+32)
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 5bb6259dd543d..658bb0394862f 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -31,6 +31,27 @@ struct CastOpInterface
   }
 };
 
+struct CollapseShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                   CollapseShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto collapseOp = cast<CollapseShapeOp>(op);
+    assert(value == collapseOp.getResult() && "invalid value");
+
+    // Multiply the expressions for the dimensions in the reassociation group.
+    const ReassociationIndices reassocIndices =
+        collapseOp.getReassociationIndices()[dim];
+    AffineExpr productExpr =
+        cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+    for (size_t i = 1; i < reassocIndices.size(); ++i) {
+      productExpr =
+          productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+    }
+    cstr.bound(value)[dim] == productExpr;
+  }
+};
+
 struct DimOpInterface
     : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
   void populateBoundsForIndexValue(Operation *op, Value value,
@@ -57,6 +78,17 @@ struct EmptyOpInterface
   }
 };
 
+struct ExpandShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                   ExpandShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto expandOp = cast<ExpandShapeOp>(op);
+    assert(value == expandOp.getResult() && "invalid value");
+    cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
+  }
+};
+
 struct ExtractSliceOpInterface
     : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    ExtractSliceOp> {
@@ -117,8 +149,12 @@ void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
+    tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
+        *ctx);
     tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
     tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
+    tensor::ExpandShapeOp::attachInterface<tensor::ExpandShapeOpInterface>(
+        *ctx);
     tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
         *ctx);
     tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index e526adb18cf1e..864f8c97a1ee4 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -230,3 +230,35 @@ func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
   %1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
   return
 }
+
+// -----
+
+//       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @tensor_collapse(
+//  CHECK-SAME:     %[[sz0:.*]]: index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c12:.*]] = arith.constant 12 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %{{.*}}, %[[c2]] : tensor<3x4x?x2xf32>
+//       CHECK:   %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
+//       CHECK:   return %[[c12]], %[[mul]]
+func.func @tensor_collapse(%sz0: index) -> (index, index) {
+  %0 = tensor.empty(%sz0) : tensor<3x4x?x2xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<3x4x?x2xf32> into tensor<12x?xf32>
+  %2 = "test.reify_bound"(%1) {dim = 0} : (tensor<12x?xf32>) -> (index)
+  %3 = "test.reify_bound"(%1) {dim = 1} : (tensor<12x?xf32>) -> (index)
+  return %2, %3 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_expand(
+//  CHECK-SAME:     %[[t:[a-zA-Z0-9]+]]: tensor<?xf32>
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9]+]]: index
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[c4]], %[[sz]]
+func.func @tensor_expand(%t: tensor<?xf32>, %sz: index) -> (index, index) {
+  %0 = tensor.expand_shape %t [[0, 1]] output_shape [4, %sz] : tensor<?xf32> into tensor<4x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<4x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<4x?xf32>) -> (index)
+  return %1, %2 : index, index
+}

Signed-off-by: Yu-Zhewen <[email protected]>
assert(value == collapseOp.getResult() && "invalid value");

// Multiply the expressions for the dimensions in the reassociation group.
const ReassociationIndices reassocIndices =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: const reference

Signed-off-by: Yu-Zhewen <[email protected]>
@Yu-Zhewen
Copy link
Contributor Author

@matthias-springer, could you help merge if ok?

@matthias-springer matthias-springer merged commit 5154a05 into llvm:main Dec 24, 2025
10 checks passed
valadaptive pushed a commit to valadaptive/llvm-project that referenced this pull request Dec 24, 2025
@cota
Copy link
Contributor

cota commented Dec 25, 2025

This broke the address sanitizer build -- reverting now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants