[MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors #143146
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.

Full diff: https://github.com/llvm/llvm-project/pull/143146.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+/// {in_bounds = [false, true, true, true]}
+/// : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+/// : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+/// {in_bounds = [false, true]}
+/// : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) const override {
+
+ if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+ // We handle transfers of vectors with rank >= 2 and a single scalable
+ // dimension.
+ VectorType origVT = readOp.getVectorType();
+ ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+ const int64_t origVRank = origVT.getRank();
+ if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+ // Number of trailing dimensions to collapse, including the scalable
+ // dimension. Nothing to do if the single scalable dimension is already the
+ // last one.
+ const int64_t numCollapseDims = std::distance(
+ llvm::find(origScalableDims, true), origScalableDims.end());
+ if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+ // We want a simple memref (not a tensor) with contiguous elements for at
+ // least all the trailing dimensions up to and including the scalable one.
+ auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+ if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+ // The collapsed dimensions (excluding the scalable one) of the vector and
+ // the memref must match and the corresponding indices must be in-bounds (it
+ // follows these indices would be zero). This guarantees that the operation
+ // transfers a contiguous block.
+ if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+ SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+ if (!llvm::all_of(
+ ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+ [](bool v) { return v; }))
+ return rewriter.notifyMatchFailure(readOp,
+ "out-if-bounds index to collapse");
+
+ // Collapse the trailing dimensions of the memref.
+ SmallVector<ReassociationIndices> reassoc;
+ for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+ reassoc.push_back({i});
+ for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+ ++i)
+ reassoc.back().push_back(i);
+ if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+ return failure();
+ Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+ readOp.getLoc(), readOp.getBase(), reassoc);
+
+ // Get a vector type with collapsed trailing dimensions.
+ SmallVector<int64_t> shape(origVT.getShape());
+ for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+ shape[origVRank - numCollapseDims] *= shape[i];
+ shape.pop_back_n(numCollapseDims - 1);
+ auto collapsedVT =
+ VectorType::get(shape, origVT.getElementType(),
+ origScalableDims.drop_back(numCollapseDims - 1));
+
+ // Drop the extra (zero) indices.
+ auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+ // Create the new `transfer_read`.
+ auto newReadOp = rewriter.create<vector::TransferReadOp>(
+ readOp.getLoc(), collapsedVT, collapsedMem, indices,
+ ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+ // Cast back to the original vector type.
+ auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+ origVT, newReadOp);
+
+ rewriter.replaceOp(readOp, toOrigShape);
+ return success();
+ }
+};
+
} // namespace
void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
@@ -306,7 +413,8 @@ void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
LegalizeSVEMaskAllocation<memref::AllocaOp>,
LegalizeSVEMaskAllocation<memref::AllocOp>,
LegalizeSVEMaskTypeCastConversion,
- LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
+ LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion,
+ LegalizeTransferRead>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
new file mode 100644
index 0000000000000..d12a2c11bbdba
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -0,0 +1,226 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_base_case
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
+
+func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_using_strided_layout
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
+// CHECK-SAME: memref<?x?x?xi8, strided<[?, ?, 1]>>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
+
+#s0 = strided<[?, ?, 8, 1]>
+
+func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_3d_vector
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8>
+
+#s1 = strided<[?, 16, 8, 1]>
+
+func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
+
+ return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_4d_vector
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
+// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8>
+
+#s2 = strided<[?, 16, 8, 1]>
+
+func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
+
+ return %A : vector<2x[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_non_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+ return %A : vector<8x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_0
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+ return %A : vector<[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_1
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x[8]xi8>
+
+ return %A : vector<8x[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_type_not_supported
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]x[8]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>
+
+ return %A : vector<[8]x[8]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_non_mem
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : tensor<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_0
+// CHECK-NOT: memref.collapse
+
+#s3 = strided<[?, ?, 16, 1]>
+
+func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s3>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_1
+// CHECK-NOT: memref.collapse
+
+#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>
+
+func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #layout>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_read_strided_vec
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>
+
+ return %A : vector<[4]x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_bcast_transp
+// CHECK-NOT: memref.collapse
+
+#perm = affine_map<(i, j, k, p) -> (k, 0)>
+
+func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
new file mode 100644
index 0000000000000..7f68d8f7ab848
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
@@ -0,0 +1,72 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --arm-sve-legalize-vector-storage --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --lower-affine --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func private @printVec(%v : vector<[32]xi8>) {
+ %v0 = vector.scalable.extract %v[0] : vector<[16]xi8> from vector<[32]xi8>
+ %v1 = vector.scalable.extract %v[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %v0 : vector<[16]xi8>
+ vector.print %v1 : vector<[16]xi8>
+ return
+}
+
+func.func @transfer_read_scalable_not_rightmost(%vs : i32, %M : memref<?x?x?x8xi8>) {
+ func.call @setArmVLBits(%vs) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+ %A = vector.transfer_read %M[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ %B = vector.shape_cast %A : vector<[4]x8xi8> to vector<[32]xi8>
+ func.call @printVec(%B) : (vector<[32]xi8>) -> ()
+
+ return
+}
+
+func.func @main() {
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+ %A0_cst = arith.constant dense<[[11, 12, 13, 14, 15, 16, 17, 18],
+ [21, 22, 23, 24, 25, 26, 27, 28],
+ [31, 32, 33, 34, 35, 36, 37, 38],
+ [41, 42, 43, 44, 45, 46, 47, 48]]> : vector<4x8xi8>
+
+ %A1_cst = arith.constant dense<[[51, 52, 53, 54, 55, 56, 57, 58],
+ [61, 62, 63, 64, 65, 66, 67, 68],
+ [71, 72, 73, 74, 75, 76, 77, 78],
+ [81, 82, 83, 84, 85, 86, 87, 88]]> : vector<4x8xi8>
+
+ %M = memref.alloca() : memref<1x2x4x8xi8>
+ vector.transfer_write %A0_cst, %M[%c0, %c0, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+ vector.transfer_write %A1_cst, %M[%c0, %c1, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+
+ %MM = memref.cast %M : memref<1x2x4x8xi8> to memref<?x?x?x8xi8>
+
+// CHECK:( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28 )
+// CHECK:( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+ %c128 = arith.constant 128 : i32
+ func.call @transfer_read_scalable_not_rightmost(%c128, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+// CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
+ %c256 = arith.constant 256 : i32
+ func.call @transfer_read_scalable_not_rightmost(%c256, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+ return
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
Great work, Momchil - thank you!
I've left a number of comments, but nothing major. My main high-level suggestion is to follow the guidance in MLIR's Testing Guide a bit more closely. It’s a relatively new (and long!) document, so I’ve included specific in-line suggestions to make it easier to see where things could align better.
For additional context, this RFC provides some of the rationale behind that approach.
Also - what about memrefs with dynamic dimensions?
VectorType origVT = readOp.getVectorType();
ArrayRef<bool> origScalableDims = origVT.getScalableDims();
const int64_t origVRank = origVT.getRank();
if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
[nit] getNumScalableDims would be more canonical than llvm::count
Done.
if (!readOp.getPermutationMap().isMinorIdentity())
  return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
Would supporting non-identity be a problem? It would be good to add a comment, either:
- "TODO: We haven't required this, so leaving for later.", or
- "Too complex because of ..., disabling".
Any hint for future developers would be helpful.
Done.
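For illustration (adapted from the negative test added in this patch, with the function name changed): a read whose permutation map is not a minor identity, e.g. one that broadcasts or transposes, is currently skipped by the pattern.

```mlir
// Adapted from the in-tree negative test: the permutation map is not a minor
// identity, so the pattern does not fire and the read stays as written.
#perm = affine_map<(i, j, k, p) -> (k, 0)>

func.func @skipped_non_identity(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {permutation_map = #perm, in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
  return %v : vector<[4]x8xi8>
}
```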
// We handle transfers of vectors with rank >= 2 and a single scalable
// dimension.
[nit] It would be helpful to add why:
- Don't need to worry about 1D, that's supported by default.
- More than one scalable dim is tricky (how to collapse e.g. vscale * vscale?)
Comment added.
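As a sketch of why the restriction exists (my own illustration, function names invented, not taken from the patch): a single scalable dimension can be folded into the collapsed trailing dimension, whereas collapsing two scalable dimensions would require a size proportional to vscale * vscale, which no vector type can express.

```mlir
// Handled: one scalable dimension, not rightmost. The trailing 4 x 8
// elements collapse into [32] scalable lanes, i.e. the read becomes a read
// of vector<[32]xi8> followed by a shape_cast.
func.func @single_scalable_dim(%i : index, %M : memref<?x?x8xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %M[%i, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x8xi8>, vector<[4]x8xi8>
  return %v : vector<[4]x8xi8>
}

// Not handled: two scalable dimensions. Collapsing [4]x[8]x8 would need a
// dimension that scales with vscale * vscale, so the read is left unchanged.
func.func @two_scalable_dims(%i : index, %M : memref<?x?x8xi8>) -> vector<[4]x[8]x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %M[%i, %c0, %c0], %pad {in_bounds = [true, true, true]} : memref<?x?x8xi8>, vector<[4]x[8]x8xi8>
  return %v : vector<[4]x[8]x8xi8>
}
```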
// The collapsed dimensions (excluding the scalable one) of the vector and
// the memref must match and the corresponding indices must be in-bounds (it
// follows these indices would be zero). This guarantees that the operation
// transfers a contiguous block.
// The collapsed dimensions (excluding the scalable one) of the vector and
// the memref must match
What about dynamic dim sizes in the memref? If that's not supported, is there a test?
This part wasn't tested at all. Test cases added.
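To make the dynamic-size question concrete, a hypothetical example (not one of the added tests): dynamic sizes in the leading, non-collapsed dimensions are fine (the existing tests use memref<?x?x?x8xi8>), but a dynamic size in a trailing dimension that would be collapsed cannot be proven equal to the corresponding static vector size, so the pattern is expected to bail out.

```mlir
// Hypothetical example: the trailing memref dimension to be collapsed is
// dynamic, so it cannot be matched against the static vector size 8 and the
// read is expected to be left as-is.
func.func @dynamic_collapsed_dim(%i : index, %j : index, %M : memref<?x?x?x?xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x?xi8>, vector<[4]x8xi8>
  return %v : vector<[4]x8xi8>
}
```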
        ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
        [](bool v) { return v; }))
    return rewriter.notifyMatchFailure(readOp,
                                       "out-if-bounds index to collapse");
Note, it's not really the index that's out-of-bounds, but the corresponding memory access. So, the index could be in-bounds, but we might be reading "more" than there's available to read (starting at that index). For example:
vector.transfer_read %mem[5] : memref<7xi8>, vector<7xi8>
Suggested change:
- "out-if-bounds index to collapse");
+ "out-of-bounds index to collapse");
Fixed.
#s3 = strided<[?, ?, 16, 1]>

func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
[nit] Avoid "magic" suffixes like _0.
Suggested change:
- func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+ func.func @negative_test_discont_mem_due_to_strides(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
Done.
#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>

func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
[nit] Same as above.
Suggested change:
- func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+ func.func @negative_test_discontig_mem_due_to_maps(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
Test removed, no need to test here all the possible ways a memref could be discontinuous.
func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>

  return %A : vector<[4]x4xi8>
}
What makes this a negative test? It says "strided vec", but I'm not sure what you mean?
That's garbage, deleted.
func.func @negative_test_vector_mask(
    %i : index, %j : index,
    %M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.mask %mask {
    vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
  } : vector<[4]x8xi1> -> vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}

// -----

// CHECK-LABEL: @negative_test_mask_operand
// CHECK-NOT: memref.collapse

func.func @negative_test_mask_operand(
    %i : index, %j : index,
    %M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
In the past, I would differentiate these as:
- "masked" (vector.mask { vector.transfer_read }), vs
- "with_mask" (vector.transfer_read %mask)
Would you mind following a similar convention?
Done.
This is mixing fixed-width and scalable vectors. Let's avoid that until we understand better how to mix VLA + VLS programming.
Test tweaked a bit.
This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.
LGTM, thanks for addressing my comments!
Hi @momchil-velikov, I think you added me as a reviewer for this PR? If you would definitely like me to take a look, please let me know; otherwise I'll pass, because it looks like it is ARM-specific.
No need, thank you!
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %[[C0]]], %[[PAD]] {in_bounds = [true]}
// CHECK-SAME: : memref<?x?x?xi8>, vector<[32]xi8>
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
For my own education: what does this shape_cast mean in the scalable world?
Seems like it encodes (32 * vscale) -> (4 * vscale) x 8, but I am not "seeing" it.
Is there something I can read to refresh my mental model here?
OK, so this is specifically to target Arm's i8mm (i.e. dot product instructions) for SVE. Here are two links:
You may recall that that extension requires the input data to be in a specific format. Nothing particularly fancy; e.g. packing into tiles of vector<2x8xi8> for NEON would do (vector<4x8xi8> would be unrolled to vector<2x8xi8>). At the hardware level, though, we would load vector<16xi8> rather than 2D vectors (there are no 2D load instructions).
For SVE, we simply make the N dimension "scalable", which gives us vector<[2]x8xi8>. Again, since we cannot load 2D vectors, we "interpret" that as vector<[16]xi8>.
In these cases, vector.shape_cast acts as a no-op and helps us get from the higher-level abstraction to the hardware-level representation.
Is there something I can read to refresh my mental model here?
There are quite a lot of fine details here and I want to avoid going on too much of a tangent, so just ask for more clarification if this is still unclear. Personally, I like this white paper a lot:
Thanks!
EDIT
One important design point to keep in mind:
- LLVM supports "fixed" arrays of scalable vectors (so e.g. vector<2x[8]xi8> is effectively supported), but
- it does not support "scalable" arrays of vectors (so e.g. vector<[2]x8xi8> is not supported and we need to somehow decompose/lower/re-interpret this at the MLIR level).
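To make this concrete, a small sketch adapted from the integration test in this PR (harness omitted, function name changed): the 2-D scalable read is what the new pattern legalizes, and the trailing shape_cast is the no-op re-interpretation into a 1-D form that maps onto an LLVM scalable vector.

```mlir
// Adapted from the integration test above: read a "scalable array" of 8-byte
// rows, then re-interpret it as a single scalable 1-D vector.
func.func @read_and_flatten(%M : memref<?x?x?x8xi8>) -> vector<[32]xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  // vector<[4]x8xi8> is a "scalable array of vectors", which LLVM cannot
  // express directly; the legalization rewrites this read to produce
  // vector<[32]xi8> plus a shape_cast back.
  %a = vector.transfer_read %M[%c0, %c0, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
  // No data movement: 4 * vscale rows of 8 bytes == 32 * vscale bytes.
  %b = vector.shape_cast %a : vector<[4]x8xi8> to vector<[32]xi8>
  return %b : vector<[32]xi8>
}
```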
[MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (llvm#143146): This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single trailing scalable dimension.
This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.