[mlir][Vector] Add vector bitwidth target to xfer op flattening #81966
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}

   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };

 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}

   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };

 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`

Review comment on lines +736 to +738:
How about update it to
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }

 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
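For context on how a client would pick up the new parameter: a minimal sketch (not part of this patch) of registering the flattening patterns with a bitwidth cap. The helper name and the 256-bit figure are invented for illustration, and the header path is an assumption; only populateFlattenVectorTransferPatterns itself comes from the change above.

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" // assumed location of the declaration
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: transfer ops whose trailing vector dimension is
// narrower than 256 bits (e.g. 4 x i32 = 128 bits) remain candidates for
// flattening, while wider trailing dims (e.g. 20 x i32 = 640 bits) are
// skipped, subject to the usual contiguity checks in the patterns above.
static void addGatedFlatteningPatterns(RewritePatternSet &patterns) {
  vector::populateFlattenVectorTransferPatterns(
      patterns, /*targetVectorBitwidth=*/256);
}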
@@ -66,7 +66,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
Review comment on lines +79 to 82:
Something for the future ... This example, on a machine with vectors which are 128 bits wide (e.g. Arm), would actually benefit from flattening. With 6 elements, we'd use 1.5 vector registers, and with 12 elements, we'd use 3. That would be better utilization. Would that make sense as a TODO? (not asking for it in this patch)

Reply:
Yes, that's the ultimate goal: partial and more targeted flattening, but one step at a time. We first have to flatten the producers/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.
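For concreteness, the register arithmetic behind that observation, assuming the vector<1x2x6xi32> shape from the hunk above and 128-bit hardware vectors:

6 x 32 bits = 192 bits, i.e. 1.5 x 128-bit registers for the unflattened trailing dimension;
1x2x6 = 12 elements, 12 x 32 bits = 384 bits, i.e. exactly 3 x 128-bit registers once flattened.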
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,35 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+    %arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i32
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
+  return %v : vector<5x4x3x20xi32>
+}
+
+// CHECK-LABEL: func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+// CHECK-NOT: tensor.collapse_shape
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+    %arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
+    %arg1 : vector<5x4x3x20xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
+    vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+// CHECK-NOT: tensor.collapse_shape
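Both negative tests above hinge on the 512-bit target hard-coded into the test pass further down: the trailing dimension of vector<5x4x3x20xi32> is 20 x 32 bits = 640 bits, which is >= 512, so the new bitwidth guard makes the patterns bail out before any collapse_shape is created.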
@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    constexpr unsigned targetVectorBitwidth = 512;
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

Review comment:
Could you make this a pass option and then toggle it in the test to demonstrate how it impacts the pattern?
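A possible shape for that pass option, sketched with MLIR's standard Option mechanism for test passes; the flag name, description wording, and 512 default below are placeholders rather than anything from this patch:

// Sketch only: inside struct TestFlattenVectorTransferPatterns.
Option<unsigned> targetVectorBitwidth{
    *this, "target-vector-bitwidth",
    llvm::cl::desc("Skip flattening when the trailing vector dimension is at "
                   "least this many bits wide"),
    llvm::cl::init(512)};

void runOnOperation() override {
  RewritePatternSet patterns(&getContext());
  // The option value replaces the hard-coded constant above and can be
  // toggled from a RUN line to exercise both sides of the bitwidth guard.
  populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}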
Review comment:
There could be cases where only a partial slice is contiguous. In that case, we could flatten the trailing dims. I wonder if we will relax this a little more in the near future?

Reply:
Yes, that's the ultimate goal: partial and more targeted flattening, but one step at a time. We first have to flatten the producers/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.