Skip to content

[mlir][memref] Add a new ReifyResultShapes pass #145927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

nicolasvasilache
Copy link
Contributor

@nicolasvasilache nicolasvasilache commented Jun 26, 2025

This patch introduces the ReifyResultShapes pass. This pass reifies the shapes of every ReifyRankedShapedTypeOpInterface operation with ranked memref and tensor results. Replacing the operations with their reified versions, and inserting casts when results shapes are updated.

Example:

#map = affine_map<(d0) -> (-d0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
  %0 = affine.apply #map(%arg1)
  %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %arg0 : f32
  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
  return %padded : tensor<1x?x64xf32>
}

// mlir-opt --reify-result-shapes
#map = affine_map<()[s0] -> (-s0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
  %0 = affine.apply #map()[%arg1]
  %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %arg0 : f32
  } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
  %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
  return %cast : tensor<1x?x64xf32>
}

@nicolasvasilache
Copy link
Contributor Author

@fabianmcg I rewrote the shape inference part as a separate pass, not tested yet and it still requires a safe listener-based rewriter that cleans up after itself on scope_guard exit.
We also discussed with @matthias-springer to get such a rewriter and to forbid modification of newly created IR.

This still requires testing and I prob. made mistakes in my generalization to multi-result rankreify stuff from tensor.pad but this is the best I could do before leaving for the day.

Implementation that worked for tensor.pad specifically is still here: #145732 in case it is useful.

@nicolasvasilache nicolasvasilache force-pushed the users/nico/infer-static-shapes branch from 475e919 to b8539b7 Compare June 26, 2025 16:58
@nicolasvasilache nicolasvasilache force-pushed the users/nico/infer-static-shapes branch from b8539b7 to 465c660 Compare June 26, 2025 17:12
@nicolasvasilache nicolasvasilache changed the title [mlir][memref] Add a new InderStaticShapes pass for ReifyRankedShaped… [mlir][memref] Add a new InferStaticShapes pass for ReifyRankedShaped… Jun 26, 2025
@fabianmcg fabianmcg changed the title [mlir][memref] Add a new InferStaticShapes pass for ReifyRankedShaped… [mlir][memref] Add a new ReifyResultShapes pass Jun 26, 2025
@fabianmcg fabianmcg force-pushed the users/nico/infer-static-shapes branch from 52d5c1f to 1f4fba7 Compare June 26, 2025 19:48
@fabianmcg fabianmcg marked this pull request as ready for review June 26, 2025 19:48
@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

Changes

This patch introduces the ReifyResultShapes pass. This pass reifies the shapes of every ReifyRankedShapedTypeOpInterface operation with ranked memref and tensor results. Replacing the operations with their reified versions, and inserting casts when results shapes are updated.

Example:

#map = affine_map&lt;(d0) -&gt; (-d0 + 256)&gt;
func.func @<!-- -->func(%arg0: f32, %arg1: index, %arg2: tensor&lt;64x?x64xf32&gt;) -&gt; tensor&lt;1x?x64xf32&gt; {
  %0 = affine.apply #map(%arg1)
  %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor&lt;64x?x64xf32&gt; to tensor&lt;1x?x64xf32&gt;
  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %arg0 : f32
  } : tensor&lt;1x?x64xf32&gt; to tensor&lt;1x?x64xf32&gt;
  return %padded : tensor&lt;1x?x64xf32&gt;
}

// mlir-opt --reify-result-shapes
#map = affine_map&lt;()[s0] -&gt; (-s0 + 256)&gt;
func.func @<!-- -->func(%arg0: f32, %arg1: index, %arg2: tensor&lt;64x?x64xf32&gt;) -&gt; tensor&lt;1x?x64xf32&gt; {
  %0 = affine.apply #map()[%arg1]
  %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor&lt;64x?x64xf32&gt; to tensor&lt;1x?x64xf32&gt;
  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %arg0 : f32
  } : tensor&lt;1x?x64xf32&gt; to tensor&lt;1x256x64xf32&gt;
  %cast = tensor.cast %padded : tensor&lt;1x256x64xf32&gt; to tensor&lt;1x?x64xf32&gt;
  return %cast : tensor&lt;1x?x64xf32&gt;
}

---
Full diff: https://github.com/llvm/llvm-project/pull/145927.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td (+40) 
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+12) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp (+144) 
- (added) mlir/test/Dialect/Tensor/reify-shapes.mlir (+31) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..4645d49cab2be 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,46 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
 ];
}

+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+  let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
+  let description = [{
+    This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
+    operation with ranked `memref` and `tensor` results. Replacing the
+    operations with their reified versions, and inserting casts when results
+    shapes are updated.
+
+    Example:
+    ```mlir
+    #map = affine_map<(d0) -> (-d0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+      return %padded : tensor<1x?x64xf32>
+    }
+
+    // mlir-opt --reify-result-shapes
+    #map = affine_map<()[s0] -> (-s0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map()[%arg1]
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+      %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+      return %cast : tensor<1x?x64xf32>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
 let summary = "Expand memref operations into easier to analyze constructs";
 let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..5f9f09d7992ca 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class RewritePatternSet;
class RewriterBase;
class Value;
class ValueRange;
+class ReifyRankedShapedTypeOpInterface;

namespace arith {
class WideIntEmulationConverter;
@@ -209,6 +210,17 @@ memref::AllocaOp allocToAlloca(
   RewriterBase &rewriter, memref::AllocOp alloc,
   function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);

+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
+/// otherwise it always succeeds. Users of this transform should always expect
+/// it to modify the IR, even when it fails. If any of the result types changes,
+/// the transform will insert cast operations to the old type to keep the IR
+/// consistent.
+///
+/// Note: This transform only works on ranked `memref` or `tensor` results,
+/// other types are ignored.
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op);
} // namespace memref
} // namespace mlir

diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 637f5ec1c9f9b..9049faccadef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
 IndependenceTransforms.cpp
 MultiBuffer.cpp
 NormalizeMemRefs.cpp
+  ReifyResultShapes.cpp
 ResolveShapedTypeResultDims.cpp
 RuntimeOpVerification.cpp

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
new file mode 100644
index 0000000000000..dcb601577f88f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -0,0 +1,144 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified out shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+      reifiedResultShapes.empty()) {
+    return op.emitError() << "failed to get the reified shapes";
+  }
+
+  bool modified = false;
+  // Compute the new output types.
+  SmallVector<Type> outTypes;
+  for (const auto &[oldTy, reifiedShape] :
+       llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+    // Skip if it's not a memref or tensor type.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // If the reified dim is dynamic set it appropriately.
+      if (!maybeCst.has_value()) {
+        dim = ShapedType::kDynamic;
+        continue;
+      }
+      // Set the static dim.
+      dim = *maybeCst;
+    }
+
+    // If the shape didn't change continue.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << " \n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+  });
+
+  // We now have outTypes that need to be turned to cast ops.
+  Location loc = op->getLoc();
+  SmallVector<Value> newResults;
+  Operation *newOp = rewriter.clone(*op);
+  for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
+    OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+    Type oldTy = oldRes.getType();
+    // Continue if the type remained invariant or is not shaped.
+    if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
+      newResults.push_back(newRes);
+      continue;
+    }
+
+    // Update the type.
+    newRes.setType(reifiedTy);
+    if (isa<RankedTensorType>(reifiedTy)) {
+      newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+    } else {
+      assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+      newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+           << "\n";
+  });
+  rewriter.replaceOp(op, newResults);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ReifyResultShapesPass final
+    : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+  getOperation()->walk(
+      [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+  IRRewriter rewriter(&getContext());
+  for (ReifyRankedShapedTypeOpInterface op : ops) {
+    rewriter.setInsertionPoint(op);
+    if (failed(memref::reifyOpResultShapes(rewriter, op)))
+      return signalPassFailure();
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir
new file mode 100644
index 0000000000000..5569d90f8b731
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -reify-result-shapes  %s | FileCheck %s
+
+// The test below checks concat op reification. In the first case, no cast is inserted while on the second a cast gets inserted.
+// CHECK-LABEL:  func.func @concat_reification
+func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
+  -> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
+  // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+  // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+  return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] 
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: tensor.pad
+  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  // CHECK: tensor.cast %{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fabianmcg
Copy link
Contributor

This seems to be similar to https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp . What is this doing on top of that pass?

They are not exactly similar.
ResolveShapedTypeResultDims looks for dims and reifies when they are present.
In this pass we reify ops independent if there's a dim or not. That's why the pad test in tests is not modified by ResolveShapedTypeResultDims, but is correctly handled by this pass.

Also, just noticed that this pattern is wrong: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp#L57-L59

It has the potential of returning failure after having modified the IR, violating the contracts stablished with the rewriter. That's also the main reason this pass doesn't use patterns, as we would need a rollback rewriter.

@MaheshRavishankar
Copy link
Contributor

Looking at

#map = affine_map<(d0) -> (-d0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
  %0 = affine.apply #map(%arg1)
  %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %arg0 : f32
  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
  return %padded : tensor<1x?x64xf32>
}

// mlir-opt --reify-result-shapes
#map = affine_map<()[s0] -> (-s0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
  %0 = affine.apply #map()[%arg1]
  %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
  %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %arg0 : f32
  } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
  %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
  return %cast : tensor<1x?x64xf32>
}

This has nothing to do with reifying result shapes. This is just a const shape propagation that should be covered by the "cast folding" pattern here

struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
(or the one below it).

There isnt a tensor.dim in the input IR, so there is nothing to "reify" . Might be a mismatch in what you expect that interface to do. It is explicitly meant to resolve tensor.dim operations.

@fabianmcg
Copy link
Contributor

fabianmcg commented Jun 26, 2025

We might need to update the docs, because at least I cannot infer that from the docs:

def ReifyRankedShapedTypeOpInterface :
    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
  let description = [{
    Interface to compute the shape of the result of an operation when
    the result is a ranked shape type, i.e. `RankedTensorType` or
    `MemRefType`.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Reify the shape of the result of an operation (typically in terms of the
        shape of its operands).

        `reifiedReturnShapes` is populated with one vector per op result. Each
        of those vectors contains an OpFoldResult for each dimension of the
        shaped type. The given builder may be used to insert ops that compute
        result shapes.

        If the shape of a particular result cannot be computed it must be empty.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"reifyResultShapes",
      /*args=*/(ins "::mlir::OpBuilder &":$builder,
        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
    >
  ];
}

Also, there's no cast to fold into or from in this case. So no way to augment those patterns.

Finally, I'd argue that it is reification because we are coming from dynamic dimensions and inferring that the output type is static.

@MaheshRavishankar
Copy link
Contributor

We might need to update the docs, because at least I cannot infer that from the docs:

def ReifyRankedShapedTypeOpInterface :
    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
  let description = [{
    Interface to compute the shape of the result of an operation when
    the result is a ranked shape type, i.e. `RankedTensorType` or
    `MemRefType`.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Reify the shape of the result of an operation (typically in terms of the
        shape of its operands).

        `reifiedReturnShapes` is populated with one vector per op result. Each
        of those vectors contains an OpFoldResult for each dimension of the
        shaped type. The given builder may be used to insert ops that compute
        result shapes.

        If the shape of a particular result cannot be computed it must be empty.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"reifyResultShapes",
      /*args=*/(ins "::mlir::OpBuilder &":$builder,
        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
    >
  ];
}

Fair point. Document needs to be updated.

Also, there's no cast to fold into or from in this case. So no way to augment those patterns.

Ok, maybe I misread what it the intent was and got thrown off by the use of the ReifyRankedShapedTypeOpInterface. (the reify was meant to say "generate the shape of the output" implicitly saying generate the code for tensor.dim <result>, <dim>.).

For other operations this is done by just canonicalizing the op itself to go to a more "static" version of the op, and introduce tensor.cast operations to fix up the tasks, and the tensor.cast folding patterns that are added absorb these into producer/consumers and propagate the "static" information through the program. AFAICS, the first pattern is missing and that has nothing to do with the ReifyRankedShapedTypeOpInterface. We need to add a pattern by itself that is similar to what is done for tensor.extract_slice here. Sorry I think you intended to do that from the get go, but I was thrown off by the ReifyRankedShapedTypeOpInterface angle.

Finally, I'd argue that it is reification because we are coming from dynamic dimensions and inferring that the output type is static.

@nicolasvasilache
Copy link
Contributor Author

We need to add a pattern by itself that is similar to what is done for tensor.extract_slice here. Sorry I think you intended to do that from the get go, but I was thrown off by the ReifyRankedShapedTypeOpInterface angle.

Well, the tensor.extract_slice points to a canonicalization pattern but in #145732 you write """Bundling that in a canonicalization that allows no control of what is run is a huge flaw that in my view we should work ourselves out of."""
So which is it ? :)

@nicolasvasilache
Copy link
Contributor Author

Ok, maybe I misread what it the intent was and got thrown off by the use of the ReifyRankedShapedTypeOpInterface. (the reify was meant to say "generate the shape of the output" implicitly saying generate the code for tensor.dim <result>, <dim>.).

There is no need to limit the interface to only exist in conjunction with tensor.dim ops, see my explanation here #145732 (comment)

I also don't understand "implicitly saying generate the code for tensor.dim <result>, <dim>": this is really a shape question, which may or may not involve tensor.dim ops. The more we fold away such ops the better to obtain static shapes.

Happy to rename the functionality if that is what causes hiccups (naming is still hard in 2025).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants