@@ -328,8 +328,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller or equal to `targetVectorBitwidth`.
void populateFlattenVectorTransferPatterns(
RewritePatternSet &patterns,
unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
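A minimal usage sketch (not part of this diff) of the updated entry point above, assuming a caller that wants flattening restricted to transfers whose trailing vector dimension is narrower than 128 bits. The helper name `applyLimitedFlattening` is illustrative only.

static void applyLimitedFlattening(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Transfers whose trailing vector dimension is 128 bits or wider are left
  // untouched; narrower ones are collapsed into 1-D transfers.
  vector::populateFlattenVectorTransferPatterns(
      patterns, /*targetVectorBitwidth=*/128);
  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}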
@@ -19,7 +19,6 @@
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
public:
FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
unsigned vectorBitwidth,
PatternBenefit benefit)
: OpRewritePattern<vector::TransferReadOp>(context, benefit),
targetVectorBitwidth(vectorBitwidth) {}

LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
// If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
return failure();
if (!vectorType.getElementType().isSignlessIntOrFloat())
return failure();
unsigned trailingVectorDimBitwidth =
vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
if (trailingVectorDimBitwidth >= targetVectorBitwidth)
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
Contributor: There could be cases where only a partial slice is contiguous. In that case, we could flatten the trailing dims. I wonder if we will relax this a little more in the near future?

Contributor Author: Yes, that's the ultimate goal: partial and more targeted flattening, but one step at a time. We first have to flatten the producers/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.

return failure();
// TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
transferReadOp, cast<VectorType>(vector.getType()), flatRead);
return success();
}

private:
// Minimum bitwidth that the trailing vector dimension should have after
// flattening.
unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
public:
FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
unsigned vectorBitwidth,
PatternBenefit benefit)
: OpRewritePattern<vector::TransferWriteOp>(context, benefit),
targetVectorBitwidth(vectorBitwidth) {}

LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
if (!vectorType.getElementType().isSignlessIntOrFloat())
return failure();
unsigned trailingVectorDimBitwidth =
vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
if (trailingVectorDimBitwidth >= targetVectorBitwidth)
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
rewriter.eraseOp(transferWriteOp);
return success();
}

private:
// Minimum bitwidth that the trailing vector dimension should have after
// flattening.
unsigned targetVectorBitwidth;
Comment on lines +736 to +738
Contributor: "after flattening"? That doesn't seem right given the implementation. With targetVectorBitwidth = 128 and a vector<1x2x6xi32> type, it becomes vector<12xi32> after flattening, and the bitwidth of the trailing dim is then 384. By that wording, it sounds like we should flatten it under this configuration.

How about updating it to "Maximum bitwidth that ... before flattening"?
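A worked version of the reviewer's numbers (the concrete values are taken from the comment above, purely for illustration):

// targetVectorBitwidth = 128, vector type = vector<1x2x6xi32>.
constexpr unsigned targetVectorBitwidth = 128;
constexpr unsigned trailingDimBits = 6 * 32;        // 192: what the pattern actually checks
constexpr unsigned flattenedBits = 1 * 2 * 6 * 32;  // 384: trailing dim after flattening
// 192 >= 128, so the pattern bails out: the limit constrains the trailing
// dimension *before* flattening, not after.
static_assert(trailingDimBits >= targetVectorBitwidth, "pattern returns failure()");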

};

/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
}

void mlir::vector::populateFlattenVectorTransferPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, unsigned targetVectorBitwidth,
PatternBenefit benefit) {
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
FlattenContiguousRowMajorTransferWritePattern>(
patterns.getContext(), benefit);
patterns.getContext(), targetVectorBitwidth, benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
36 changes: 34 additions & 2 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,7 +66,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
%m_out: memref<1x2x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
Comment on lines +79 to 82
Contributor: Something for the future ...

This example, on a machine with 128-bit-wide vectors (e.g. Arm), would actually benefit from flattening. With 6 elements, we'd use 1.5 vector registers, and with 12 elements, we'd use 3. That would be better utilization.

Would that make sense as a TODO? (Not asking for it in this patch.)

Contributor Author: Yes, that's the ultimate goal: partial and more targeted flattening, but one step at a time. We first have to flatten the producers/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.

@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
%m_out: memref<1x2x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,35 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>

// -----

func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
%arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i32
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
return %v : vector<5x4x3x20xi32>
}

// CHECK-LABEL: func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
// CHECK-NOT: tensor.collapse_shape

// -----

func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
%arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
%arg1 : vector<5x4x3x20xi32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
return
}

// CHECK-LABEL: func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
// CHECK-NOT: tensor.collapse_shape




3 changes: 2 additions & 1 deletion mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateFlattenVectorTransferPatterns(patterns);
constexpr unsigned targetVectorBitwidth = 512;
Contributor: Could you make this a pass option and then toggle it in the test to demonstrate how it impacts the pattern?

populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
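One possible shape for the pass option requested above, sketched only: the option name, description, and default are assumptions, not part of this patch, and existing members such as getArgument()/getDescription() are elided.

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  TestFlattenVectorTransferPatterns() = default;
  // A copy constructor is required once the pass carries options.
  TestFlattenVectorTransferPatterns(
      const TestFlattenVectorTransferPatterns &pass)
      : PassWrapper(pass) {}

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc("Maximum bitwidth of the trailing vector dimension for a "
                     "transfer to be flattened"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

A RUN line in the test could then pass e.g. -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 (flag spelling assumed) to exercise both the flattened and the rejected cases.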