From 465c6606d2f30eb9a8654e07b9da7d6d64e65b97 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 26 Jun 2025 18:48:11 +0200 Subject: [PATCH 1/3] [mlir][memref] Add a new InferStaticShapes pass for ReifyRankedShapedTypeOpInterface --- .../mlir/Dialect/MemRef/Transforms/Passes.td | 13 ++ .../Dialect/MemRef/Transforms/Transforms.h | 4 + .../ResolveShapedTypeResultDims.cpp | 126 ++++++++++++++++++ .../Dialect/Tensor/infer-static-shapes.mlir | 18 +++ 4 files changed, 161 insertions(+) create mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index a8d135caa74f0..2406b47538ddc 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> { ]; } +def InferStaticShapesPass : Pass<"infer-static-shapes"> { + let summary = "Resolve memref.dim of result values"; + let description = [{ + The pass resolves memref.dim of result of operations that + implement the `InferShapedTypeOpInterface` or + `ReifyRankedShapedTypeOpInterface` in terms of shapes of its + operands. + }]; + let dependentDialects = [ + "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect" + ]; +} + def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> { let summary = "Expand memref operations into easier to analyze constructs"; let description = [{ diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index c2b8cb05be922..b069d5f284597 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns( /// terms of shapes of its input operands. 
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); +/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops +/// shapes more static. +void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns); + /// Appends patterns for expanding memref operations that modify the metadata /// (sizes, offset, strides) of a memref into easier to analyze constructs. void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 89a3895d06ba5..919b3fbc95479 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -20,13 +20,22 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/InterleavedRange.h" + +#define DEBUG_TYPE "resolve-shaped-type" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") namespace mlir { namespace memref { #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS +#define GEN_PASS_DEF_INFERSTATICSHAPESPASS #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace memref } // namespace mlir @@ -105,6 +114,99 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern { } }; +struct ReifyToInferStaticShapePattern + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op, + PatternRewriter 
&rewriter) const override { + LLVM_DEBUG( + { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; }); + + bool rewriteToMoreStatic = false; + ReifiedRankedShapedTypeDims reifiedResultShapes; + if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) || + reifiedResultShapes.empty()) { + LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; }); + return failure(); + } + + SmallVector newTypes; + for (auto [t, reifiedShape] : + llvm::zip(op->getResultTypes(), reifiedResultShapes)) { + ShapedType st = dyn_cast(t); + if (!st) + continue; + + SmallVector newShape; + for (const auto &[s, ofr] : + llvm::zip_equal(st.getShape(), reifiedShape)) { + std::optional maybeCst = getConstantIntValue(ofr); + // Reification does not add static information, just use existing shape. + if (!maybeCst.has_value()) { + newShape.push_back(s); + continue; + } + int64_t cst = *maybeCst; + assert((ShapedType::isDynamic(s) || s == cst) && + "constants must agree!"); + newShape.push_back(cst); + } + + if (newShape == st.getShape()) { + newTypes.push_back(t); + continue; + } + + rewriteToMoreStatic = true; + Type newType = st.cloneWith(newShape, st.getElementType()); + newTypes.push_back(newType); + } + + LLVM_DEBUG({ + DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes()) + << " \n"; + DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n"; + }); + if (!rewriteToMoreStatic) { + LLVM_DEBUG({ DBGS() << "not more static\n"; }); + return failure(); + } + + // We now have newTypes that need to be turned to tensor::CastOp. 
+ Location loc = op->getLoc(); + SmallVector newResults; + Operation *newOp = rewriter.clone(*op); + for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) { + Type ot = oldVal.getType(); + OpResult newResult = newOp->getResult(oldVal.getResultNumber()); + if (ot == nt) { + newResults.push_back(newResult); + continue; + } + newResult.setType(nt); + if (isa(nt)) { + newResults.push_back( + rewriter.create(loc, ot, newResult)); + } else if (isa(nt)) { + newResults.push_back( + rewriter.create(loc, ot, newResult)); + } else { + llvm_unreachable("expected RankedTensorType or MemRefType"); + } + } + + LLVM_DEBUG({ + op->getParentOp()->dump(); + DBGS() << "replace op " << *op << "\n"; + DBGS() << "with newResults " << llvm::interleaved_array(newResults) + << "\n\n\n\n"; + }); + rewriter.replaceAllOpUsesWith(op, newResults); + return success(); + } +}; + /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: /// /// ``` @@ -175,6 +277,11 @@ struct ResolveShapedTypeResultDimsPass final void runOnOperation() override; }; +struct InferStaticShapesPass final + : public memref::impl::InferStaticShapesPassBase { + void runOnOperation() override; +}; + } // namespace void memref::populateResolveRankedShapedTypeResultDimsPatterns( @@ -192,6 +299,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns( patterns.getContext()); } +void memref::populateReifyToInferStaticShapePatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); @@ -206,3 +318,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() { if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } + +void InferStaticShapesPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + 
FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + SmallVector opsToSimplify; + getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) { + opsToSimplify.push_back(op); + }); + (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, + GreedyRewriteConfig().setStrictness( + GreedyRewriteStrictness::ExistingOps)); +} diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir new file mode 100644 index 0000000000000..1712ce7df38b1 --- /dev/null +++ b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @pad_reification +func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) + -> tensor<1x?x64xf32> { + %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx) + %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] + : tensor<64x?x64xf32> to tensor<1x?x64xf32> + +// CHECK: tensor.pad +// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32> + %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] { + ^bb0(%a: index, %b: index, %c: index): + tensor.yield %cst : f32 + } : tensor<1x?x64xf32> to tensor<1x?x64xf32> + + return %padded : tensor<1x?x64xf32> +} From 1f4fba78948dec85d558236f15678c10cd289fcc Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Thu, 26 Jun 2025 19:42:35 +0000 Subject: [PATCH 2/3] rename transform to reify-result-shapes --- .../mlir/Dialect/MemRef/Transforms/Passes.td | 39 ++++- .../Dialect/MemRef/Transforms/Transforms.h | 16 +- .../Dialect/MemRef/Transforms/CMakeLists.txt | 1 + .../MemRef/Transforms/ReifyResultShapes.cpp | 144 ++++++++++++++++++ .../ResolveShapedTypeResultDims.cpp | 126 --------------- .../Dialect/Tensor/infer-static-shapes.mlir | 18 --- mlir/test/Dialect/Tensor/reify-shapes.mlir | 31 ++++ 7 files changed, 221 insertions(+), 154 deletions(-) create mode 100644 
mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp delete mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir create mode 100644 mlir/test/Dialect/Tensor/reify-shapes.mlir diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index 2406b47538ddc..4645d49cab2be 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -182,13 +182,40 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> { ]; } -def InferStaticShapesPass : Pass<"infer-static-shapes"> { - let summary = "Resolve memref.dim of result values"; +def ReifyResultShapesPass : Pass<"reify-result-shapes"> { + let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations"; let description = [{ - The pass resolves memref.dim of result of operations that - implement the `InferShapedTypeOpInterface` or - `ReifyRankedShapedTypeOpInterface` in terms of shapes of its - operands. + This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface` + operation with ranked `memref` and `tensor` results. Replacing the + operations with their reified versions, and inserting casts when results + shapes are updated. 
+ + Example: + ```mlir + #map = affine_map<(d0) -> (-d0 + 256)> + func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> { + %0 = affine.apply #map(%arg1) + %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32> + %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] { + ^bb0(%arg3: index, %arg4: index, %arg5: index): + tensor.yield %arg0 : f32 + } : tensor<1x?x64xf32> to tensor<1x?x64xf32> + return %padded : tensor<1x?x64xf32> + } + + // mlir-opt --reify-result-shapes + #map = affine_map<()[s0] -> (-s0 + 256)> + func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> { + %0 = affine.apply #map()[%arg1] + %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32> + %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] { + ^bb0(%arg3: index, %arg4: index, %arg5: index): + tensor.yield %arg0 : f32 + } : tensor<1x?x64xf32> to tensor<1x256x64xf32> + %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32> + return %cast : tensor<1x?x64xf32> + } + ``` }]; let dependentDialects = [ "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect" diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index b069d5f284597..5f9f09d7992ca 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -23,6 +23,7 @@ class RewritePatternSet; class RewriterBase; class Value; class ValueRange; +class ReifyRankedShapedTypeOpInterface; namespace arith { class WideIntEmulationConverter; @@ -57,10 +58,6 @@ void populateResolveRankedShapedTypeResultDimsPatterns( /// terms of shapes of its input operands. 
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); -/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops -/// shapes more static. -void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns); - /// Appends patterns for expanding memref operations that modify the metadata /// (sizes, offset, strides) of a memref into easier to analyze constructs. void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns); @@ -213,6 +210,17 @@ memref::AllocaOp allocToAlloca( RewriterBase &rewriter, memref::AllocOp alloc, function_ref filter = nullptr); +/// Reifies the results of `op`, potentially replacing `op` with a reified +/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure, +/// otherwise it always succeeds. Users of this transform should always expect +/// it to modify the IR, even when it fails. If any of the result types changes, +/// the transform will insert cast operations to the old type to keep the IR +/// consistent. +/// +/// Note: This transform only works on ranked `memref` or `tensor` results, +/// other types are ignored. 
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter, + ReifyRankedShapedTypeOpInterface op); } // namespace memref } // namespace mlir diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index 637f5ec1c9f9b..9049faccadef3 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms IndependenceTransforms.cpp MultiBuffer.cpp NormalizeMemRefs.cpp + ReifyResultShapes.cpp ResolveShapedTypeResultDims.cpp RuntimeOpVerification.cpp diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp new file mode 100644 index 0000000000000..dcb601577f88f --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp @@ -0,0 +1,144 @@ +//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface` +// operations with ranked `memref` and `tensor` results. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/Support/InterleavedRange.h" + +#define DEBUG_TYPE "reify-result-shapes" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") + +namespace mlir { +namespace memref { +#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +} // namespace memref +} // namespace mlir + +using namespace mlir; + +LogicalResult +mlir::memref::reifyOpResultShapes(RewriterBase &rewriter, + ReifyRankedShapedTypeOpInterface op) { + LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; }); + // Get the reified out shapes. + ReifiedRankedShapedTypeDims reifiedResultShapes; + if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) || + reifiedResultShapes.empty()) { + return op.emitError() << "failed to get the reified shapes"; + } + + bool modified = false; + // Compute the new output types. + SmallVector outTypes; + for (const auto &[oldTy, reifiedShape] : + llvm::zip(op->getResultTypes(), reifiedResultShapes)) { + // Skip if it's not a memref or tensor type. + if (!isa(oldTy)) { + outTypes.push_back(oldTy); + continue; + } + + ShapedType shapedTy = dyn_cast(oldTy); + + SmallVector shape = llvm::to_vector(shapedTy.getShape()); + for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) { + std::optional maybeCst = getConstantIntValue(ofr); + // If the reified dim is dynamic set it appropriately. + if (!maybeCst.has_value()) { + dim = ShapedType::kDynamic; + continue; + } + // Set the static dim. + dim = *maybeCst; + } + + // If the shape didn't change continue. 
+ if (shape == shapedTy.getShape()) { + outTypes.push_back(oldTy); + continue; + } + modified = true; + outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType())); + } + + // Return if we don't need to update. + if (!modified) { + LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; }); + return success(); + } + + LLVM_DEBUG({ + DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes()) + << " \n"; + DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n"; + }); + + // We now have outTypes that need to be turned to cast ops. + Location loc = op->getLoc(); + SmallVector newResults; + Operation *newOp = rewriter.clone(*op); + for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) { + OpResult newRes = newOp->getResult(oldRes.getResultNumber()); + Type oldTy = oldRes.getType(); + // Continue if the type remained invariant or is not shaped. + if (oldTy == reifiedTy || !isa(oldTy)) { + newResults.push_back(newRes); + continue; + } + + // Update the type. 
+ newRes.setType(reifiedTy); + if (isa(reifiedTy)) { + newResults.push_back(rewriter.create(loc, oldTy, newRes)); + } else { + assert(isa(reifiedTy) && "expected a memref type"); + newResults.push_back(rewriter.create(loc, oldTy, newRes)); + } + } + + LLVM_DEBUG({ + DBGS() << "- reified results " << llvm::interleaved_array(newResults) + << "\n"; + }); + rewriter.replaceOp(op, newResults); + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +struct ReifyResultShapesPass final + : public memref::impl::ReifyResultShapesPassBase { + void runOnOperation() override; +}; +} // namespace + +void ReifyResultShapesPass::runOnOperation() { + SmallVector ops; + getOperation()->walk( + [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); }); + IRRewriter rewriter(&getContext()); + for (ReifyRankedShapedTypeOpInterface op : ops) { + rewriter.setInsertionPoint(op); + if (failed(memref::reifyOpResultShapes(rewriter, op))) + return signalPassFailure(); + } +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 919b3fbc95479..89a3895d06ba5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -20,22 +20,13 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Value.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/InterleavedRange.h" - -#define DEBUG_TYPE "resolve-shaped-type" -#define DBGS() 
(llvm::dbgs() << "[" DEBUG_TYPE << "]: ") namespace mlir { namespace memref { #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS -#define GEN_PASS_DEF_INFERSTATICSHAPESPASS #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace memref } // namespace mlir @@ -114,99 +105,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern { } }; -struct ReifyToInferStaticShapePattern - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op, - PatternRewriter &rewriter) const override { - LLVM_DEBUG( - { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; }); - - bool rewriteToMoreStatic = false; - ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) || - reifiedResultShapes.empty()) { - LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; }); - return failure(); - } - - SmallVector newTypes; - for (auto [t, reifiedShape] : - llvm::zip(op->getResultTypes(), reifiedResultShapes)) { - ShapedType st = dyn_cast(t); - if (!st) - continue; - - SmallVector newShape; - for (const auto &[s, ofr] : - llvm::zip_equal(st.getShape(), reifiedShape)) { - std::optional maybeCst = getConstantIntValue(ofr); - // Reification does not add static information, just use existing shape. 
- if (!maybeCst.has_value()) { - newShape.push_back(s); - continue; - } - int64_t cst = *maybeCst; - assert((ShapedType::isDynamic(s) || s == cst) && - "constants must agree!"); - newShape.push_back(cst); - } - - if (newShape == st.getShape()) { - newTypes.push_back(t); - continue; - } - - rewriteToMoreStatic = true; - Type newType = st.cloneWith(newShape, st.getElementType()); - newTypes.push_back(newType); - } - - LLVM_DEBUG({ - DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes()) - << " \n"; - DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n"; - }); - if (!rewriteToMoreStatic) { - LLVM_DEBUG({ DBGS() << "not more static\n"; }); - return failure(); - } - - // We now have newTypes that need to be turned to tensor::CastOp. - Location loc = op->getLoc(); - SmallVector newResults; - Operation *newOp = rewriter.clone(*op); - for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) { - Type ot = oldVal.getType(); - OpResult newResult = newOp->getResult(oldVal.getResultNumber()); - if (ot == nt) { - newResults.push_back(newResult); - continue; - } - newResult.setType(nt); - if (isa(nt)) { - newResults.push_back( - rewriter.create(loc, ot, newResult)); - } else if (isa(nt)) { - newResults.push_back( - rewriter.create(loc, ot, newResult)); - } else { - llvm_unreachable("expected RankedTensorType or MemRefType"); - } - } - - LLVM_DEBUG({ - op->getParentOp()->dump(); - DBGS() << "replace op " << *op << "\n"; - DBGS() << "with newResults " << llvm::interleaved_array(newResults) - << "\n\n\n\n"; - }); - rewriter.replaceAllOpUsesWith(op, newResults); - return success(); - } -}; - /// Fold dim ops of iter_args to dim ops of their respective init args. 
E.g.: /// /// ``` @@ -277,11 +175,6 @@ struct ResolveShapedTypeResultDimsPass final void runOnOperation() override; }; -struct InferStaticShapesPass final - : public memref::impl::InferStaticShapesPassBase { - void runOnOperation() override; -}; - } // namespace void memref::populateResolveRankedShapedTypeResultDimsPatterns( @@ -299,11 +192,6 @@ void memref::populateResolveShapedTypeResultDimsPatterns( patterns.getContext()); } -void memref::populateReifyToInferStaticShapePatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} - void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); @@ -318,17 +206,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() { if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } - -void InferStaticShapesPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - - SmallVector opsToSimplify; - getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) { - opsToSimplify.push_back(op); - }); - (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, - GreedyRewriteConfig().setStrictness( - GreedyRewriteStrictness::ExistingOps)); -} diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir deleted file mode 100644 index 1712ce7df38b1..0000000000000 --- a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s - -// CHECK-LABEL: func.func @pad_reification -func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) - -> tensor<1x?x64xf32> { - %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx) - %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 
1, 1] - : tensor<64x?x64xf32> to tensor<1x?x64xf32> - -// CHECK: tensor.pad -// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32> - %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] { - ^bb0(%a: index, %b: index, %c: index): - tensor.yield %cst : f32 - } : tensor<1x?x64xf32> to tensor<1x?x64xf32> - - return %padded : tensor<1x?x64xf32> -} diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir new file mode 100644 index 0000000000000..5569d90f8b731 --- /dev/null +++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt -reify-result-shapes %s | FileCheck %s + +// The test below checks concat op reification. In the first case, no cast is inserted while on the second a cast gets inserted. +// CHECK-LABEL: func.func @concat_reification +func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor) + -> (tensor<4x11x3xf32>, tensor) { + // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32> + %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32> + // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor) -> tensor<4x7x?xf32> + // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor + %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor) -> tensor + // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor + return %1, %2 : tensor<4x11x3xf32>, tensor +} + +// CHECK-LABEL: func.func @pad_reification +func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> { + %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx) + %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] + : tensor<64x?x64xf32> to tensor<1x?x64xf32> + + // CHECK: tensor.pad + // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32> + // CHECK: tensor.cast 
%{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32> + %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] { + ^bb0(%a: index, %b: index, %c: index): + tensor.yield %cst : f32 + } : tensor<1x?x64xf32> to tensor<1x?x64xf32> + + return %padded : tensor<1x?x64xf32> +} From ba51026882343871ebb06ced928b657c82c3a2f7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 27 Jun 2025 11:33:21 +0200 Subject: [PATCH 3/3] Don't fail the pass if we can't make shapes more static. Also, only allow tensor::PadOp and tensor::ConcatOp for now as more extensive testing showed that other ops are not ready yet (e.g. at least tensor::ExtractSliceOp / tensor::InsertSliceOp). --- .../MemRef/Transforms/ReifyResultShapes.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp index dcb601577f88f..b00a0f2103d43 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp @@ -40,7 +40,7 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter, ReifiedRankedShapedTypeDims reifiedResultShapes; if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) || reifiedResultShapes.empty()) { - return op.emitError() << "failed to get the reified shapes"; + return op->emitWarning() << "failed to get the reified shapes"; } bool modified = false; @@ -133,12 +133,16 @@ struct ReifyResultShapesPass final void ReifyResultShapesPass::runOnOperation() { SmallVector ops; - getOperation()->walk( - [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); }); + getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) { + // Some ops have rigid type checkers and need to update their operands. + // Only admit the ones that are explicitly supported for now. 
+ if (!isa(op.getOperation())) + return; + ops.push_back(op); + }); IRRewriter rewriter(&getContext()); for (ReifyRankedShapedTypeOpInterface op : ops) { rewriter.setInsertionPoint(op); - if (failed(memref::reifyOpResultShapes(rewriter, op))) - return signalPassFailure(); + (void)memref::reifyOpResultShapes(rewriter, op); } }