diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7c943f07066c7..46bb3ddec0baf 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -330,8 +330,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is strictly smaller than `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed88897..04e5a816dd91e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the bitwidth of the trailing dimension of the vector read is smaller than
+/// the provided bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Flattening is only applied when the bitwidth of the trailing vector
+  // dimension is strictly smaller than this threshold.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Flattening is only applied when the bitwidth of the trailing vector
+  // dimension is strictly smaller than this threshold.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b..1775b5fa4a346 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
 func.func @transfer_read_dims_match_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
@@ -16,6 +17,9 @@ func.func @transfer_read_dims_match_contiguous(
 // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK: return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
@@ -27,13 +31,16 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
   return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
 // CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK: return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 // The shape of the memref and the vector don't match, but the vector is a
@@ -57,6 +64,9 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -66,7 +76,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -87,6 +97,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
 // CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 // The input memref has a dynamic trailing shape and hence is not flattened.
@@ -99,7 +112,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -115,6 +128,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 // CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
 // CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous(
@@ -130,6 +146,9 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
@@ -141,10 +160,13 @@ func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
   return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_match_contiguous(
@@ -155,13 +177,16 @@ func.func @transfer_write_dims_match_contiguous(
   return
 }
 
-// CHECK-LABEL: func @transfer_write_dims_match_contiguous
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
 // CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_contiguous(
@@ -182,6 +207,9 @@ func.func @transfer_write_dims_mismatch_contiguous(
 // CHECK: return
 // CHECK: }
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_non_contiguous(
@@ -196,6 +224,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_0d(%arg : memref, %vec : vector) {
@@ -207,6 +238,10 @@ func.func @transfer_write_0d(%arg : memref, %vec : vector) {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_0d(
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
 // -----
 
 func.func @transfer_read_0d(%arg : memref) -> vector {
@@ -219,6 +254,10 @@ func.func @transfer_read_0d(%arg : memref) -> vector {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
 // -----
 
 func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
@@ -241,6 +280,9 @@ func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memre
 // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK: return %[[VEC2D]] : vector<8x4xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref>, %arg1 : index, %arg2 : index) {
@@ -260,6 +302,9 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 // CHECK-SAME: {in_bounds = [true]}
 // CHECK-SAME: : vector<32xi8>, memref
 
+// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
+// CHECK-128B: memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative(
@@ -274,6 +319,9 @@ func.func @transfer_read_flattenable_negative(
 // CHECK-LABEL: func @transfer_read_flattenable_negative
 // CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative2(
@@ -288,6 +336,9 @@ func.func @transfer_read_flattenable_negative2(
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 // CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
@@ -302,6 +353,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
 // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
 // CHECK: return %[[VAL_4]] : vector<1x8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_basic(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
@@ -316,6 +370,9 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
 // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
 // CHECK: return %[[VAL_4]] : vector<1x8x1xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
@@ -334,6 +391,9 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
 // CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
 // CHECK: return %[[VAL_4]] : vector<8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
@@ -352,6 +412,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 // CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
 // CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_mulf(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
@@ -367,6 +430,9 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
 // CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
 // CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_sitofp(
+// CHECK-128B-NOT: memref.collapse_shape
+
 // -----
 
 // All shape casts are folded away
@@ -389,3 +455,7 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
+// CHECK-128B-NOT: memref.collapse_shape
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index acd38980514a5..178a58e796b24 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -466,21 +466,35 @@ struct TestFlattenVectorTransferPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestFlattenVectorTransferPatterns)
 
+  TestFlattenVectorTransferPatterns() = default;
+  TestFlattenVectorTransferPatterns(
+      const TestFlattenVectorTransferPatterns &pass)
+      : PassWrapper(pass) {}
+
   StringRef getArgument() const final {
     return "test-vector-transfer-flatten-patterns";
   }
+
   StringRef getDescription() const final {
     return "Test patterns to rewrite contiguous row-major N-dimensional "
            "vector.transfer_{read,write} ops into 1D transfers";
   }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
     registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
   }
+
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
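Note (not part of the patch above): a minimal sketch of how a consumer other than the test pass might wire up the new `targetVectorBitwidth` parameter. The helper name `flattenNarrowTransfers` and the 128-bit threshold are illustrative assumptions; only `populateFlattenVectorTransferPatterns` and the greedy pattern driver come from MLIR itself.

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: flatten contiguous row-major transfers, but only those
// whose trailing vector dimension is narrower than 128 bits.
static void flattenNarrowTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateFlattenVectorTransferPatterns(
      patterns, /*targetVectorBitwidth=*/128);
  // Transfers with a trailing dimension of 128 bits or more are left as-is,
  // matching the CHECK-128B expectations in the test file.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```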