Skip to content

Conversation

arnab-polymage
Copy link
Contributor

Enhance unionBoundingBox utility to work with input constraints having local variables.

@llvmbot
Copy link
Member

llvmbot commented Feb 25, 2025

@llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

Changes

Enhance unionBoundingBox utility to work with input constraints having local variables.


Full diff: https://github.com/llvm/llvm-project/pull/128709.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Analysis/FlatLinearValueConstraints.h (+4-5)
  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+11-13)
  • (modified) mlir/lib/Analysis/FlatLinearValueConstraints.cpp (-2)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+18-28)
  • (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+14)
diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
index c8167014b5300..15387201affa8 100644
--- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
+++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
@@ -474,11 +474,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
   bool areVarsAlignedWithOther(const FlatLinearConstraints &other);
 
   /// Updates the constraints to be the smallest bounding (enclosing) box that
-  /// contains the points of `this` set and that of `other`, with the symbols
-  /// being treated specially. For each of the dimensions, the min of the lower
-  /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
-  /// to determine such a bounding box. `other` is expected to have the same
-  /// dimensional variables as this constraint system (in the same order).
+  /// contains the points of `this` set and that of `other`. For each of the
+  /// dimensions, the min of the lower bounds and the max of the upper bounds is
+  /// computed to determine such a bounding box. `other` is expected to have the
+  /// same dimensional variables as this constraint system (in the same order).
   ///
   /// E.g.:
   /// 1) this   = {0 <= d0 <= 127},
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index ddc18038e869c..ae45743ecc1be 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -489,11 +489,10 @@ class IntegerRelation {
   void constantFoldVarRange(unsigned pos, unsigned num);
 
   /// Updates the constraints to be the smallest bounding (enclosing) box that
-  /// contains the points of `this` set and that of `other`, with the symbols
-  /// being treated specially. For each of the dimensions, the min of the lower
-  /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
-  /// to determine such a bounding box. `other` is expected to have the same
-  /// dimensional variables as this constraint system (in the same order).
+  /// contains the points of `this` set and that of `other`. For each of the
+  /// dimensions, the min of the lower bounds and the max of the upper bounds is
+  /// computed to determine such a bounding box. `other` is expected to have the
+  /// same dimensional variables as this constraint system (in the same order).
   ///
   /// E.g.:
   /// 1) this   = {0 <= d0 <= 127},
@@ -512,14 +511,13 @@ class IntegerRelation {
   /// than or equal to 'exclusive upper bound' - 'lower bound' of the
   /// variable. This constant bound is guaranteed to be non-negative. Returns
   /// std::nullopt if it's not a constant. This method employs trivial (low
-  /// complexity / cost) checks and detection. Symbolic variables are treated
-  /// specially, i.e., it looks for constant differences between affine
-  /// expressions involving only the symbolic variables. `lb` and `ub` (along
-  /// with the `boundFloorDivisor`) are set to represent the lower and upper
-  /// bound associated with the constant difference: `lb`, `ub` have the
-  /// coefficients, and `boundFloorDivisor`, their divisor. `minLbPos` and
-  /// `minUbPos` if non-null are set to the position of the constant lower bound
-  /// and upper bound respectively (to the same if they are from an
+  /// complexity / cost) checks and detection. It looks for constant differences
+  /// between affine expressions involving symbolic and local variables. `lb`
+  /// and `ub` (along with the `boundFloorDivisor`) are set to represent the
+  /// lower and upper bound associated with the constant difference: `lb`, `ub`
+  /// have the coefficients, and `boundFloorDivisor`, their divisor. `minLbPos`
+  /// and `minUbPos` if non-null are set to the position of the constant lower
+  /// bound and upper bound respectively (to the same if they are from an
   /// equality). Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a
   /// system with three symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See
   /// comments at function definition for examples.
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 4653eca9887ce..ae9f9acd89c2e 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1303,8 +1303,6 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
                     otherMaybeValues.begin(),
                     otherMaybeValues.begin() + getNumDimVars()) &&
          "dim values mismatch");
-  assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
-  assert(getNumLocalVars() == 0 && "local vars not supported yet here");
 
   // Align `other` to this.
   if (!areVarsAligned(*this, otherCst)) {
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 74cdf567c0e56..89d3a936e8e9e 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1578,13 +1578,11 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
 
 /// Returns a non-negative constant bound on the extent (upper bound - lower
 /// bound) of the specified variable if it is found to be a constant; returns
-/// std::nullopt if it's not a constant. This methods treats symbolic variables
-/// specially, i.e., it looks for constant differences between affine
-/// expressions involving only the symbolic variables. See comments at function
-/// definition for example. 'lb', if provided, is set to the lower bound
-/// associated with the constant difference. Note that 'lb' is purely symbolic
-/// and thus will contain the coefficients of the symbolic variables and the
-/// constant coefficient.
+/// std::nullopt if it's not a constant. This methods looks for constant
+/// differences between affine expressions. See comments at function definition
+/// for example. 'lb', if provided, is set to the lower bound associated with
+/// the constant difference. `lb' will contain the coefficients of the symbolic
+/// variables, local variables and the constant coefficient.
 //  Egs: 0 <= i <= 15, return 16.
 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
@@ -1600,22 +1598,15 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
   // of the symbolic variables (+ constant).
   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
   if (eqPos != -1) {
-    auto eq = getEquality(eqPos);
-    // If the equality involves a local var, punt for now.
-    // TODO: this can be handled in the future by using the explicit
-    // representation of the local vars.
-    if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
-                     [](const DynamicAPInt &coeff) { return coeff == 0; }))
-      return std::nullopt;
-
     // This variable can only take a single value.
     if (lb) {
       // Set lb to that symbolic value.
-      lb->resize(getNumSymbolVars() + 1);
+      lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
       if (ub)
-        ub->resize(getNumSymbolVars() + 1);
-      for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
-        DynamicAPInt v = atEq(eqPos, pos);
+        ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
+      for (unsigned c = 0, f = getNumSymbolVars() + getNumLocalVars() + 1;
+           c < f; c++) {
+        MPInt v = atEq(eqPos, pos);
         // atEq(eqRow, pos) is either -1 or 1.
         assert(v * v == 1);
         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
@@ -1687,27 +1678,30 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
   }
   if (lb && minDiff) {
     // Set lb to the symbolic lower bound.
-    lb->resize(getNumSymbolVars() + 1);
+    lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
     if (ub)
-      ub->resize(getNumSymbolVars() + 1);
+      ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
     // The lower bound is the ceildiv of the lb constraint over the coefficient
     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
     *boundFloorDivisor = atIneq(minLbPosition, pos);
     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
-    for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
+    for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1; c < e;
+         c++) {
       (*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
     }
     if (ub) {
-      for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
+      for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1;
+           c < e; c++)
         (*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
     }
     // The lower bound leads to a ceildiv while the upper bound is a floordiv
     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
     // the constant term for the lower bound.
-    (*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
+    (*lb)[getNumSymbolVars() + getNumLocalVars()] +=
+        atIneq(minLbPosition, pos) - 1;
   }
   if (minLbPos)
     *minLbPos = minLbPosition;
@@ -2180,8 +2174,6 @@ static void getCommonConstraints(const IntegerRelation &a,
 LogicalResult
 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
   assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
-  assert(getNumLocalVars() == 0 && "local ids not supported yet here");
-
   // Get the constraints common to both systems; these will be added as is to
   // the union.
   IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
@@ -2211,11 +2203,9 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
     auto otherExtent = otherCst.getConstantBoundOnDimSize(
         d, &otherLb, &otherLbFloorDivisor, &otherUb);
     if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
-      // TODO: symbolic extents when necessary.
       return failure();
 
     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
-
     auto res = compareBounds(lb, otherLb);
     // Identify min.
     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..44c67a301b110 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
 }
+
+// Test union of two integer relations if they have local variable(s).
+TEST(IntegerRelationTest, unionBoundingBox) {
+  IntegerRelation relA = parseRelationFromSet(
+      "(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - N - x"
+      ">= 0, x + y + z floordiv 6 == 0)",
+      1);
+  IntegerRelation relB = parseRelationFromSet(
+      "(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - M - x"
+      ">= 0, x + y + z floordiv 7 == 0)",
+      1);
+  assert(relA.getNumLocalVars() > 0);
+  EXPECT_TRUE(relA.unionBoundingBox(relB).succeeded());
+}

@llvmbot
Copy link
Member

llvmbot commented Feb 25, 2025

@llvm/pr-subscribers-mlir-presburger

Author: Arnab Dutta (arnab-polymage)

Changes

Enhance unionBoundingBox utility to work with input constraints having local variables.


Full diff: https://github.com/llvm/llvm-project/pull/128709.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Analysis/FlatLinearValueConstraints.h (+4-5)
  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+11-13)
  • (modified) mlir/lib/Analysis/FlatLinearValueConstraints.cpp (-2)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+18-28)
  • (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+14)
diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
index c8167014b5300..15387201affa8 100644
--- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
+++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
@@ -474,11 +474,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
   bool areVarsAlignedWithOther(const FlatLinearConstraints &other);
 
   /// Updates the constraints to be the smallest bounding (enclosing) box that
-  /// contains the points of `this` set and that of `other`, with the symbols
-  /// being treated specially. For each of the dimensions, the min of the lower
-  /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
-  /// to determine such a bounding box. `other` is expected to have the same
-  /// dimensional variables as this constraint system (in the same order).
+  /// contains the points of `this` set and that of `other`. For each of the
+  /// dimensions, the min of the lower bounds and the max of the upper bounds is
+  /// computed to determine such a bounding box. `other` is expected to have the
+  /// same dimensional variables as this constraint system (in the same order).
   ///
   /// E.g.:
   /// 1) this   = {0 <= d0 <= 127},
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index ddc18038e869c..ae45743ecc1be 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -489,11 +489,10 @@ class IntegerRelation {
   void constantFoldVarRange(unsigned pos, unsigned num);
 
   /// Updates the constraints to be the smallest bounding (enclosing) box that
-  /// contains the points of `this` set and that of `other`, with the symbols
-  /// being treated specially. For each of the dimensions, the min of the lower
-  /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
-  /// to determine such a bounding box. `other` is expected to have the same
-  /// dimensional variables as this constraint system (in the same order).
+  /// contains the points of `this` set and that of `other`. For each of the
+  /// dimensions, the min of the lower bounds and the max of the upper bounds is
+  /// computed to determine such a bounding box. `other` is expected to have the
+  /// same dimensional variables as this constraint system (in the same order).
   ///
   /// E.g.:
   /// 1) this   = {0 <= d0 <= 127},
@@ -512,14 +511,13 @@ class IntegerRelation {
   /// than or equal to 'exclusive upper bound' - 'lower bound' of the
   /// variable. This constant bound is guaranteed to be non-negative. Returns
   /// std::nullopt if it's not a constant. This method employs trivial (low
-  /// complexity / cost) checks and detection. Symbolic variables are treated
-  /// specially, i.e., it looks for constant differences between affine
-  /// expressions involving only the symbolic variables. `lb` and `ub` (along
-  /// with the `boundFloorDivisor`) are set to represent the lower and upper
-  /// bound associated with the constant difference: `lb`, `ub` have the
-  /// coefficients, and `boundFloorDivisor`, their divisor. `minLbPos` and
-  /// `minUbPos` if non-null are set to the position of the constant lower bound
-  /// and upper bound respectively (to the same if they are from an
+  /// complexity / cost) checks and detection. It looks for constant differences
+  /// between affine expressions involving symbolic and local variables. `lb`
+  /// and `ub` (along with the `boundFloorDivisor`) are set to represent the
+  /// lower and upper bound associated with the constant difference: `lb`, `ub`
+  /// have the coefficients, and `boundFloorDivisor`, their divisor. `minLbPos`
+  /// and `minUbPos` if non-null are set to the position of the constant lower
+  /// bound and upper bound respectively (to the same if they are from an
   /// equality). Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a
   /// system with three symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See
   /// comments at function definition for examples.
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 4653eca9887ce..ae9f9acd89c2e 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1303,8 +1303,6 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
                     otherMaybeValues.begin(),
                     otherMaybeValues.begin() + getNumDimVars()) &&
          "dim values mismatch");
-  assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
-  assert(getNumLocalVars() == 0 && "local vars not supported yet here");
 
   // Align `other` to this.
   if (!areVarsAligned(*this, otherCst)) {
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 74cdf567c0e56..89d3a936e8e9e 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1578,13 +1578,11 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
 
 /// Returns a non-negative constant bound on the extent (upper bound - lower
 /// bound) of the specified variable if it is found to be a constant; returns
-/// std::nullopt if it's not a constant. This methods treats symbolic variables
-/// specially, i.e., it looks for constant differences between affine
-/// expressions involving only the symbolic variables. See comments at function
-/// definition for example. 'lb', if provided, is set to the lower bound
-/// associated with the constant difference. Note that 'lb' is purely symbolic
-/// and thus will contain the coefficients of the symbolic variables and the
-/// constant coefficient.
+/// std::nullopt if it's not a constant. This methods looks for constant
+/// differences between affine expressions. See comments at function definition
+/// for example. 'lb', if provided, is set to the lower bound associated with
+/// the constant difference. `lb' will contain the coefficients of the symbolic
+/// variables, local variables and the constant coefficient.
 //  Egs: 0 <= i <= 15, return 16.
 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
@@ -1600,22 +1598,15 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
   // of the symbolic variables (+ constant).
   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
   if (eqPos != -1) {
-    auto eq = getEquality(eqPos);
-    // If the equality involves a local var, punt for now.
-    // TODO: this can be handled in the future by using the explicit
-    // representation of the local vars.
-    if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
-                     [](const DynamicAPInt &coeff) { return coeff == 0; }))
-      return std::nullopt;
-
     // This variable can only take a single value.
     if (lb) {
       // Set lb to that symbolic value.
-      lb->resize(getNumSymbolVars() + 1);
+      lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
       if (ub)
-        ub->resize(getNumSymbolVars() + 1);
-      for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
-        DynamicAPInt v = atEq(eqPos, pos);
+        ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
+      for (unsigned c = 0, f = getNumSymbolVars() + getNumLocalVars() + 1;
+           c < f; c++) {
+        MPInt v = atEq(eqPos, pos);
         // atEq(eqRow, pos) is either -1 or 1.
         assert(v * v == 1);
         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
@@ -1687,27 +1678,30 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
   }
   if (lb && minDiff) {
     // Set lb to the symbolic lower bound.
-    lb->resize(getNumSymbolVars() + 1);
+    lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
     if (ub)
-      ub->resize(getNumSymbolVars() + 1);
+      ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
     // The lower bound is the ceildiv of the lb constraint over the coefficient
     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
     *boundFloorDivisor = atIneq(minLbPosition, pos);
     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
-    for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
+    for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1; c < e;
+         c++) {
       (*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
     }
     if (ub) {
-      for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
+      for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1;
+           c < e; c++)
         (*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
     }
     // The lower bound leads to a ceildiv while the upper bound is a floordiv
     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
     // the constant term for the lower bound.
-    (*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
+    (*lb)[getNumSymbolVars() + getNumLocalVars()] +=
+        atIneq(minLbPosition, pos) - 1;
   }
   if (minLbPos)
     *minLbPos = minLbPosition;
@@ -2180,8 +2174,6 @@ static void getCommonConstraints(const IntegerRelation &a,
 LogicalResult
 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
   assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
-  assert(getNumLocalVars() == 0 && "local ids not supported yet here");
-
   // Get the constraints common to both systems; these will be added as is to
   // the union.
   IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
@@ -2211,11 +2203,9 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
     auto otherExtent = otherCst.getConstantBoundOnDimSize(
         d, &otherLb, &otherLbFloorDivisor, &otherUb);
     if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
-      // TODO: symbolic extents when necessary.
       return failure();
 
     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
-
     auto res = compareBounds(lb, otherLb);
     // Identify min.
     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..44c67a301b110 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
 }
+
+// Test union of two integer relations if they have local variable(s).
+TEST(IntegerRelationTest, unionBoundingBox) {
+  IntegerRelation relA = parseRelationFromSet(
+      "(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - N - x"
+      ">= 0, x + y + z floordiv 6 == 0)",
+      1);
+  IntegerRelation relB = parseRelationFromSet(
+      "(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - M - x"
+      ">= 0, x + y + z floordiv 7 == 0)",
+      1);
+  assert(relA.getNumLocalVars() > 0);
+  EXPECT_TRUE(relA.unionBoundingBox(relB).succeeded());
+}

@arnab-polymage arnab-polymage force-pushed the ornib/union_bounding_box branch from daaad70 to f736501 Compare February 25, 2025 12:58
@bondhugula bondhugula changed the title Enhance unionBoundingBox utility [MLIR][Affine] Enhance unionBoundingBox utility Feb 26, 2025
@bondhugula bondhugula changed the title [MLIR][Affine] Enhance unionBoundingBox utility [MLIR] Enhance unionBoundingBox utility Feb 26, 2025
@bondhugula
Copy link
Contributor

Please fix the build failure.

@arnab-polymage arnab-polymage force-pushed the ornib/union_bounding_box branch 2 times, most recently from b96a128 to 9843043 Compare February 26, 2025 05:23
Enhance `unionBoundingBox` utility to work with input
constraints having local variables.
@arnab-polymage arnab-polymage force-pushed the ornib/union_bounding_box branch from 9843043 to bda7cd5 Compare February 26, 2025 06:53
Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is making IntegerRelation::getConstantBoundOnDimSize incorrect. Also, the test isn't really testing the output of the union but simply that it didn't fail.

/// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
/// to determine such a bounding box. `other` is expected to have the same
/// dimensional variables as this constraint system (in the same order).
/// contains the points of `this` set and that of `other`. For each of the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The symbols are still expected to be treated specially. Why is that removed from this comment now? Aren't the users of this method expecting symbolic lower and upper bound functions?

// This variable can only take a single value.
if (lb) {
// Set lb to that symbolic value.
lb->resize(getNumSymbolVars() + 1);
lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
Copy link
Contributor

@bondhugula bondhugula Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually not right. The local variable itself may not be symbolic, and as a result, you can't conclude that the pos^th variable can only take a single value contrary to the comment above. You need to find out what the local variable is.

Eg:

d1 - s0 - l0 = 1,
d0 - l0 = 0
0 <= d0 <= s1
0 <= d1 < s2

getConstantBoundOnDimSize on d1 for the above set should return nothing, it's not bounded by a constant size, but you are returning a bounded size of one post this change and that's wrong.

Comment on lines +614 to +622
IntegerRelation relA = parseRelationFromSet(
"(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - N - x"
">= 0, x + y + z floordiv 6 == 0)",
1);
IntegerRelation relB = parseRelationFromSet(
"(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - M - x"
">= 0, x + y + z floordiv 7 == 0)",
1);
assert(relA.getNumLocalVars() > 0);
Copy link
Contributor

@bondhugula bondhugula Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this isn't testing whether the bounding box is indeed fine! This isn't really a test for the changes. Add a test that checks that the right bounding box is found.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants