From a25509263063f973013159d7e4b25937f65f6ed7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 7 Feb 2024 23:31:50 +0100 Subject: [PATCH 1/6] [mlir][vector] ND vectors linearization pass Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors as `array>` and SPIR-V doesn't handle them as all at the moment. Sometime it's prefferable to tread multidim vectors as linearized. Add pass to do this. Only constants and simple elementwise ops are supported for now. Also, move generic op return type utility to common place and add ConversionPattern operating on traits. --- .../mlir/Dialect/Vector/Transforms/Passes.td | 9 ++ .../Vector/Transforms/VectorRewritePatterns.h | 6 + .../mlir/Transforms/DialectConversion.h | 23 ++++ .../Dialect/Math/Transforms/LegalizeToF32.cpp | 20 +-- .../Dialect/Vector/Transforms/CMakeLists.txt | 1 + .../Vector/Transforms/VectorLinearize.cpp | 122 ++++++++++++++++++ .../Transforms/Utils/DialectConversion.cpp | 21 +++ mlir/test/Dialect/Vector/linearize.mlir | 15 +++ 8 files changed, 204 insertions(+), 13 deletions(-) create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp create mode 100644 mlir/test/Dialect/Vector/linearize.mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td index 4911a61ab3c25..71f412507457c 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td @@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> { let constructor = "mlir::vector::createLowerVectorMaskPass()"; } +def VectorLinearize : Pass<"vector-linearize"> { + let summary = "Linearize ND vectors into !d"; + let description = [{ + Linearizes ND vectors for N >= 2 into 1D vectors. + }]; + let dependentDialects = ["vector::VectorDialect"]; + } + + #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index f5941d32e683f..45f54fc70e326 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -20,7 +20,9 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc" namespace mlir { +class ConversionTarget; class RewritePatternSet; +class TypeConverter; namespace arith { class AndIOp; @@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); +void populateVectorLinearizeTypeConversionsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target); + } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 51e3e413b516f..5081b4c06a617 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern { using ConversionPattern::matchAndRewrite; }; +/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows +/// for matching and rewriting against instances of an operation that possess a +/// given trait. +template