
[MLIR] Legalize certain vector.transfer_read ops of scalable vectors #143146


Merged

Conversation

momchil-velikov
Collaborator

This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so that we obtain a vector type with a single scalable dimension in the rightmost position.
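
For illustration, a minimal before/after sketch of the rewrite, using the shapes from the base-case test added in this patch:

```mlir
// Before: the scalable dimension [4] is not the trailing one, so the vector
// type has no direct LLVM equivalent.
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]}
    : memref<?x?x?x8xi8>, vector<[4]x8xi8>

// After: the trailing memref dimensions are collapsed, the read produces a
// vector with a single (trailing) scalable dimension, and a shape_cast
// recovers the original type.
%collapsed = memref.collapse_shape %M [[0], [1], [2, 3]]
    : memref<?x?x?x8xi8> into memref<?x?x?xi8>
%flat = vector.transfer_read %collapsed[%i, %j, %c0], %c0_i8 {in_bounds = [true]}
    : memref<?x?x?xi8>, vector<[32]xi8>
%res = vector.shape_cast %flat : vector<[32]xi8> to vector<[4]x8xi8>
```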

@llvmbot
Member

llvmbot commented Jun 6, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

This patch adds a transform of the `transfer_read` operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so that we obtain a vector type with a single scalable dimension in the rightmost position.


Full diff: https://github.com/llvm/llvm-project/pull/143146.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp (+109-1)
  • (added) mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir (+226)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir (+72)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
 
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read  %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension.  Nothing to do if the single scalable dimension is already the
+    // last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds (it
+    // follows these indices would be zero). This guarantees that the operation
+    // transfers a contiguous block.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(readOp,
+                                         "out-if-bounds index to collapse");
+
+    // Collapse the trailing dimensions of the memref.
+    SmallVector<ReassociationIndices> reassoc;
+    for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+      reassoc.push_back({i});
+    for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+         ++i)
+      reassoc.back().push_back(i);
+    if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+      return failure();
+    Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+        readOp.getLoc(), readOp.getBase(), reassoc);
+
+    // Get a vector type with collapsed trailing dimensions.
+    SmallVector<int64_t> shape(origVT.getShape());
+    for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+      shape[origVRank - numCollapseDims] *= shape[i];
+    shape.pop_back_n(numCollapseDims - 1);
+    auto collapsedVT =
+        VectorType::get(shape, origVT.getElementType(),
+                        origScalableDims.drop_back(numCollapseDims - 1));
+
+    // Drop the extra (zero) indices.
+    auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+    // Create the new `transfer_read`.
+    auto newReadOp = rewriter.create<vector::TransferReadOp>(
+        readOp.getLoc(), collapsedVT, collapsedMem, indices,
+        ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+    // Cast back to the original vector type.
+    auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+                                                            origVT, newReadOp);
+
+    rewriter.replaceOp(readOp, toOrigShape);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
@@ -306,7 +413,8 @@ void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
                LegalizeSVEMaskAllocation<memref::AllocaOp>,
                LegalizeSVEMaskAllocation<memref::AllocOp>,
                LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
+               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion,
+               LegalizeTransferRead>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
new file mode 100644
index 0000000000000..d12a2c11bbdba
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -0,0 +1,226 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL:       @test_base_case
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
+// CHECK:               %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1], [2, 3]]
+// CHECK-SAME:            : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT:          %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:            : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT:          %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:          return %[[T1]] : vector<[4]x8xi8>
+
+func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL:       @test_using_strided_layout
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:               %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1], [2, 3]]
+// CHECK-SAME:            : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
+// CHECK-SAME:              memref<?x?x?xi8, strided<[?, ?, 1]>>
+// CHECK-NEXT:          %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:            : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
+// CHECK-NEXT:          %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:          return %[[T1]] : vector<[4]x8xi8>
+
+#s0 = strided<[?, ?, 8, 1]>
+
+func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL:       @test_3d_vector
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:               %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2, 3]]
+// CHECK-SAME:            : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:              memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:          %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:            : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
+// CHECK-NEXT:          %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT:          return %[[T1]] : vector<[4]x2x8xi8>
+
+#s1 = strided<[?, 16, 8, 1]>
+
+func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
+
+  return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL:       @test_4d_vector
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:               %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2, 3]]
+// CHECK-SAME:           : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:             memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:         %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
+// CHECK-SAME:           : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
+// CHECK-NEXT:         %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+// CHECK-NEXT:         return %[[T1]] : vector<2x[4]x2x8xi8>
+
+#s2 = strided<[?, 16, 8, 1]>
+
+func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
+
+  return %A : vector<2x[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_non_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+  return %A : vector<8x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_0
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+  return %A : vector<[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_1
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x[8]xi8>
+
+  return %A : vector<8x[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_type_not_supported
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]x[8]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>
+
+  return %A : vector<[8]x[8]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_non_mem
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : tensor<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_0
+// CHECK-NOT: memref.collapse
+
+#s3 = strided<[?, ?, 16, 1]>
+
+func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s3>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_1
+// CHECK-NOT: memref.collapse
+
+#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>
+
+func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #layout>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_read_strided_vec
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>
+
+  return %A : vector<[4]x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_bcast_transp
+// CHECK-NOT: memref.collapse
+
+#perm = affine_map<(i, j, k, p) -> (k, 0)>
+
+func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
new file mode 100644
index 0000000000000..7f68d8f7ab848
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
@@ -0,0 +1,72 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --arm-sve-legalize-vector-storage --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata    --lower-affine --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func private @printVec(%v : vector<[32]xi8>) {
+  %v0 = vector.scalable.extract %v[0] : vector<[16]xi8> from vector<[32]xi8>
+  %v1 = vector.scalable.extract %v[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %v0 : vector<[16]xi8>
+  vector.print %v1 : vector<[16]xi8>
+  return
+}
+
+func.func @transfer_read_scalable_not_rightmost(%vs : i32, %M : memref<?x?x?x8xi8>) {
+  func.call @setArmVLBits(%vs) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+  %A = vector.transfer_read %M[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  %B = vector.shape_cast %A : vector<[4]x8xi8> to vector<[32]xi8>
+  func.call @printVec(%B) : (vector<[32]xi8>) -> ()
+
+  return
+}
+
+func.func @main() {
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  %A0_cst = arith.constant dense<[[11, 12, 13, 14, 15, 16, 17, 18],
+                                  [21, 22, 23, 24, 25, 26, 27, 28],
+                                  [31, 32, 33, 34, 35, 36, 37, 38],
+                                  [41, 42, 43, 44, 45, 46, 47, 48]]> : vector<4x8xi8>
+
+  %A1_cst = arith.constant dense<[[51, 52, 53, 54, 55, 56, 57, 58],
+                                  [61, 62, 63, 64, 65, 66, 67, 68],
+                                  [71, 72, 73, 74, 75, 76, 77, 78],
+                                  [81, 82, 83, 84, 85, 86, 87, 88]]> : vector<4x8xi8>
+
+  %M = memref.alloca() : memref<1x2x4x8xi8>
+  vector.transfer_write %A0_cst, %M[%c0, %c0, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+  vector.transfer_write %A1_cst, %M[%c0, %c1, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+
+  %MM = memref.cast %M : memref<1x2x4x8xi8> to memref<?x?x?x8xi8>
+
+// CHECK:( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28 )
+// CHECK:( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+  %c128 = arith.constant 128 : i32
+  func.call @transfer_read_scalable_not_rightmost(%c128, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+// CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
+  %c256 = arith.constant 256 : i32
+  func.call @transfer_read_scalable_not_rightmost(%c256, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+  return
+}


github-actions bot commented Jun 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from f60e73d to 1210d59 on June 6, 2025 15:27
@momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch 2 times, most recently from 4d13aa2 to 413d9dc on June 9, 2025 16:42
@momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from 1210d59 to 3b17c94 on June 9, 2025 16:42
@momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch from 413d9dc to 5496f97 on June 13, 2025 16:42
Contributor

@banach-space left a comment

Great work, Momchil - thank you!

I've left a number of comments, but nothing major. My main high-level suggestion is to follow the guidance in MLIR's Testing Guide a bit more closely. It’s a relatively new (and long!) document, so I’ve included specific in-line suggestions to make it easier to see where things could align better.

For additional context, this RFC provides some of the rationale behind that approach.

Also - what about memrefs with dynamic dimensions?

VectorType origVT = readOp.getVectorType();
ArrayRef<bool> origScalableDims = origVT.getScalableDims();
const int64_t origVRank = origVT.getRank();
if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
Contributor

[nit] getNumScalableDims would be more canonical than llvm::count

Collaborator Author

Done.

Comment on lines +342 to +352
if (!readOp.getPermutationMap().isMinorIdentity())
return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
Contributor

Would supporting non-identity be a problem? It would be good to add a comment, either:

  • TODO: We haven't required this, so leaving for later. or
  • "Too complex because of , disabling".

Any hint for future developers would be helpful.

Collaborator Author

Done.

Comment on lines 345 to 346
// We handle transfers of vectors with rank >= 2 and a single scalable
// dimension.
Contributor

[nit] It would be helpful to add why:

  • Don't need to worry about 1D, that's supported by default.
  • More than 1 scalable dims are tricky (how to collapse e.g. vscale * vscale?)

Collaborator Author

Comment added.

Comment on lines 369 to 372
// The collapsed dimensions (excluding the scalable one) of the vector and
// the memref must match and the corresponding indices must be in-bounds (it
// follows these indices would be zero). This guarantees that the operation
// transfers a contiguous block.
Contributor

// The collapsed dimensions (excluding the scalable one) of the vector and

// the memref must match

What about dynamic dim sizes in the memref? If that's not supported, is there a test?

Collaborator Author

This part wasn't tested at all. Test cases added.

ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
[](bool v) { return v; }))
return rewriter.notifyMatchFailure(readOp,
"out-if-bounds index to collapse");
Contributor

Note, it's not really the index that's out-of-bounds, but the corresponding memory access. So, the index could be in-bounds, but we might be reading "more" than there is available to read (starting at that index). For example:

vector.transfer_read %mem[5] : memref<7xi8>, vector<7xi8>
Suggested change
"out-if-bounds index to collapse");
"out-of-bounds index to collapse");

Collaborator Author

Fixed.


#s3 = strided<[?, ?, 16, 1]>

func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
Contributor

[nit] Avoid "magic" suffixes like _0.

Suggested change
func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
func.func @negative_test_discont_mem_due_to_strides(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {

Collaborator Author

Done.


#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>

func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
Contributor

[nit] Same as above.

Suggested change
func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
func.func @negative_test_discontig_mem_due_to_maps(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {

Collaborator Author

Test removed, no need to test here all the possible ways a memref could be discontinuous.

Comment on lines 203 to 199
func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8

%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>

return %A : vector<[4]x4xi8>
}
Contributor

What makes this a negative test? It says "strided vec", but I'm not sure what you mean?

Collaborator Author

That's garbage, deleted.

Comment on lines 233 to 255
func.func @negative_test_vector_mask(
%i : index, %j : index,
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8

%A = vector.mask %mask {
vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
} : vector<[4]x8xi1> -> vector<[4]x8xi8>

return %A : vector<[4]x8xi8>
}

// -----

// CHECK-LABEL: @negative_test_mask_operand
// CHECK-NOT: memref.collapse

func.func @negative_test_mask_operand(
%i : index, %j : index,
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8

%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>

return %A : vector<[4]x8xi8>
}
Contributor

In the past, I would differentiate these as:

  • "masked" (vector.mask {vector. transfer_read}), vs
  • "with_mask" (vector.transfer_read %mask)

Would you mind following similar convention?
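
For reference, a condensed sketch of the two forms this naming distinguishes (taken from the test cases quoted above; SSA names are illustrative):

```mlir
// "masked": the transfer_read is wrapped in a vector.mask region.
%masked = vector.mask %mask {
  vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]}
      : memref<?x?x?x8xi8>, vector<[4]x8xi8>
} : vector<[4]x8xi1> -> vector<[4]x8xi8>

// "with_mask": the mask is passed directly as an operand of transfer_read.
%with_mask = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask
    {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
```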

Collaborator Author

Done.

Contributor

This is mixing fixed-width and scalable vectors. Let's avoid that until we understand better how to mix VLA + VLS programming.

Collaborator Author

Test tweaked a bit.

@momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch from 5496f97 to e422213 on June 20, 2025 11:54
@momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from 358f5da to 1740d45 on June 23, 2025 09:10
Base automatically changed from users/momchil-velikov/memref-contig-slice to main June 23, 2025 13:12
This patch adds a transform of the `transfer_read` operation to change
the vector type to one that can be mapped to an LLVM type. This is done
by collapsing trailing dimensions so we obtain a vector type with a
single scalable dimension in the rightmost position.
@momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch from e422213 to 9d76736 on June 23, 2025 13:41
Contributor

@banach-space left a comment

LGTM, thanks for addressing my comments!

@newling
Contributor

newling commented Jun 24, 2025

Hi @momchil-velikov I think you added me as a reviewer for this PR? If you would definitely like me to take a look please let me know, otherwise I'll pass because it looks like it is ARM specific

@momchil-velikov
Collaborator Author

Hi @momchil-velikov I think you added me as a reviewer for this PR? If you would definitely like me to take a look please let me know, otherwise I'll pass because it looks like it is ARM specific

No need, thank you!

@momchil-velikov merged commit 3b251cd into main Jun 25, 2025
7 checks passed
@momchil-velikov deleted the users/momchil-velikov/legalize-scalable-transfer-read branch June 25, 2025 08:54
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %[[C0]]], %[[PAD]] {in_bounds = [true]}
// CHECK-SAME: : memref<?x?x?xi8>, vector<[32]xi8>
// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
Contributor

For my own education: what does this mean in the scalable world?
Seems like it encodes (32 * vscale) -> (4*vscale) x 8 but I am not "seeing" it.
Is there something I can read to refresh my mental model here?

Contributor

@banach-space Jun 25, 2025

OK, so this is specifically to target Arm's i8mm (i.e. dot product instructions) for SVE. Here are two links:

You may recall that this extension requires the input data to be in a specific format. Nothing particularly fancy; e.g. packing into tiles of vector<2x8xi8> for NEON would do (vector<4x8xi8> would be unrolled to vector<2x8xi8>). At the hardware level though, we would load vector<16xi8> rather than 2D vectors (there are no 2D load instructions).

For SVE, we simply make the N dimension "scalable", which gives us vector<[2]x8xi8>. Again, since we cannot load 2D vectors, we "interpret" that as vector<[16]xi8>.

In these cases, vector.shape_cast acts as a no-op and helps us get from the higher-level abstraction to the hardware-level representation.

Is there something I can read to refresh my mental model here?

There's quite a lot of fine details here and I want to avoid going on too much of a tangent, so just ask for more clarification if this is still unclear. Personally, I like this white paper a lot:

Thanks!

EDIT
One important design-point to keep in mind:

  • LLVM supports "fixed" arrays of scalable vectors (so e.g. vector<2x[8]xi8> is effectively supported), but,
  • it does not support "scalable" arrays of vectors (so e.g. vector<[2]x8xi8> is not supported and we need to somehow decompose/lower/re-interpret this at the MLIR level); see the sketch below.
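
A tiny sketch of that distinction in MLIR types (function names are illustrative only):

```mlir
// Fine for LLVM lowering: the scalable dimension is trailing
// ("fixed array of scalable vectors").
func.func private @fits_llvm(vector<2x[8]xi8>)

// Valid MLIR, but not directly representable in LLVM: the scalable dimension
// is not trailing ("scalable array of fixed vectors"); this is the shape the
// new pattern legalizes by collapsing the trailing dimensions.
func.func private @needs_legalization(vector<[2]x8xi8>)
```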

anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
[MLIR] Legalize certain vector.transfer_read ops of scalable vectors (llvm#143146)

This patch adds a transform of `transfer_read` operation to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
trailing scalable dimension.