[mlir][vector][memref] Add alignment attribute to memory access ops #144344

Open

wants to merge 3 commits into main
Conversation

Contributor

@tyb0807 tyb0807 commented Jun 16, 2025

Alignment information is important to allow LLVM backends such as AMDGPU to select wide memory accesses (e.g., dwordx4 or b128). Since this information is not always inferable, it is better to inform LLVM backends about it explicitly.

This patch introduces an alignment attribute on MemRef/Vector memory access ops. Propagation of these attributes to the LLVM/SPIR-V lowerings will be implemented in a separate follow-up PR.
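For illustration, the intended surface syntax (mirroring the lit tests added in this patch; the attribute is optional and measured in bytes):

```mlir
// Sketch of the new optional `alignment` attribute. A 16-byte alignment
// can let backends such as AMDGPU pick wide accesses (e.g., dwordx4/b128).
func.func @aligned_access(%m: memref<4xi32>) {
  %c0 = arith.constant 0 : index
  %v = vector.load %m[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
  vector.store %v, %m[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
  return
}
```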

@llvmbot
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: None (tyb0807)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144344.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+58-3)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+48-2)
  • (added) mlir/test/Dialect/MemRef/load-store-alignment.mlir (+27)
  • (added) mlir/test/Dialect/Vector/load-store-alignment.mlir (+27)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (modified) mlir/unittests/Dialect/MemRef/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp (+88)
  • (added) mlir/unittests/Dialect/Vector/CMakeLists.txt (+7)
  • (added) mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp (+95)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 77e3074661abf..160b04e452c5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1227,7 +1227,45 @@ def LoadOp : MemRef_Op<"load",
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, memref, indices, false, alignment);
+    }]>,
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, memref, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1924,13 +1962,30 @@ def MemRef_StoreOp : MemRef_Op<"store",
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
 
   let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
       $_state.addOperands(valueToStore);
       $_state.addOperands(memref);
-    }]>];
+    }]>
+  ];
 
   let extraClassDeclaration = [{
       Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..3cd71491bcc04 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1739,7 +1739,34 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1825,9 +1852,28 @@ def Vector_StoreOp : Vector_Op<"store"> {
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment
   );
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
diff --git a/mlir/test/Dialect/MemRef/load-store-alignment.mlir b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f5a5461e0ac0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: memref.load {{.*}} {alignment = 16 : i32}
+// CHECK: memref.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = memref.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  memref.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  memref.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f54d989dd190
--- /dev/null
+++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: vector.load {{.*}} {alignment = 16 : i32}
+// CHECK: vector.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = vector.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  vector.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index aea247547473d..34c9fb7317443 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -18,3 +18,4 @@ add_subdirectory(SPIRV)
 add_subdirectory(SMT)
 add_subdirectory(Transform)
 add_subdirectory(Utils)
+add_subdirectory(Vector)
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..87d33854fadcd 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  LoadStoreAlignment.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..f0b8e93c2d0e1
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
@@ -0,0 +1,88 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}
diff --git a/mlir/unittests/Dialect/Vector/CMakeLists.txt b/mlir/unittests/Dialect/Vector/CMakeLists.txt
new file mode 100644
index 0000000000000..b23d9c2df3870
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRVectorTests
+  LoadStoreAlignment.cpp
+)
+mlir_target_link_libraries(MLIRVectorTests
+  PRIVATE
+  MLIRVectorDialect
+  )
diff --git a/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..745dd8632fe4d
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
@@ -0,0 +1,95 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}

@tyb0807 tyb0807 requested review from kuhar and ftynse June 16, 2025 12:57
@krzysz00 krzysz00 changed the title Add attribute to MemRef/Vector memory access ops Add alignment attribute to MemRef/Vector memory access ops Jun 16, 2025
Contributor

@krzysz00 krzysz00 left a comment

Needs more documentation and lit tests, doesn't need the unit tests

I'd be OK with updating the LLVM (and SPIR-V, if applicable) lowerings in this patch or in a stacked-on followup

@kuhar kuhar changed the title Add alignment attribute to MemRef/Vector memory access ops [mlir][vector][memref] Add alignment attribute to MemRef/Vector memory access ops Jun 16, 2025
@kuhar kuhar changed the title [mlir][vector][memref] Add alignment attribute to MemRef/Vector memory access ops [mlir][vector][memref] Add alignment attribute to memory access ops Jun 16, 2025
Member

@kuhar kuhar left a comment

Thanks for adding this, this will be very useful to have in the IREE codegen.

This also needs llvm/spirv lowering changes and lit tests. We don't need unit tests. See https://mlir.llvm.org/getting_started/TestingGuide/#test-categories

@tyb0807
Contributor Author

tyb0807 commented Jun 16, 2025

Thanks for the review. Actually, I already have all this covered in lit tests. I just wanted to make sure the new builders work as intended. I guess I can just remove the unit tests?

@tyb0807 tyb0807 requested review from krzysz00 and kuhar June 16, 2025 23:54
Member

@matthias-springer matthias-springer left a comment

Is it possible to build an analysis based on memref.assume_alignment instead of adding an alignment attribute to every load/store operation? That's the approach that Triton took (AxisInfo.cpp).
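For comparison, a minimal sketch of the analysis-based alternative (hypothetical IR; the exact `memref.assume_alignment` syntax has varied across MLIR versions):

```mlir
// Hypothetical: alignment attached once to the SSA value via
// memref.assume_alignment and later recovered by a dataflow analysis,
// instead of annotating each individual access.
func.func @via_assume(%m: memref<4xi32>) {
  memref.assume_alignment %m, 16 : memref<4xi32>
  %c0 = arith.constant 0 : index
  %v = memref.load %m[%c0] : memref<4xi32>  // alignment inferred, not annotated
  return
}
```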

@tyb0807
Contributor Author

tyb0807 commented Jun 17, 2025

Indeed, but I'm not sure we can always infer the alignment of a load/store op solely from the indexing math. In cases where this is not possible, we would need a (less automatic) way to specify this constraint, right?

@matthias-springer
Member

Can you show an example where that would not work?

@banach-space
Contributor

Do you ever expect to need something like this (different alignment for different Ops):

func.func @test_load_store_alignment(%memref: memref<4xi32>) {
  %c0 = arith.constant 0 : index
  %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
  vector.store %val, %memref[%c0] { alignment = 32 } : memref<4xi32>, vector<4xi32>
  return
}

I am just wondering, why do we need to "decorate" every Op with this attribute? And what logic is meant to take care of it? Why couldn't the alignment be a parameter that's passed to e.g. a conversion pass?

@matthias-springer
Member

Alignment is a property of the memref SSA value. But we don't encode it in the memref type. We have memref.assume_alignment as a way to attach alignment information to an SSA value. The alignment can then be queried by a dataflow analysis.

There are two alternatives to this approach:

  1. Make the alignment information part of the memref type.
  2. Add attributes to each load/store op. (That's what this PR is doing.)

I'd like to make sure that we have a consistent story for dealing with alignment. Having both memref.assume_alignment and attributes on various ops seems a bit odd...

@kuhar
Member

kuhar commented Jun 17, 2025

Is it possible to build an analysis based on memref.assume_alignment instead of adding an alignment attribute to every load/store operation?

Alignment is a property of the memref SSA value. But we don't encode it in the memref type.

I don't think this is the case. You can have a memref of ?xi8 that doesn't have any inherent static alignment and the alignment is really a property at each load/store op. You may end up with a memref of bytes as you lower and merge allocations etc. This is also the case with lower level IRs like llvm or spirv, e.g.: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Memory_Operands.
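A sketch of that scenario (hypothetical IR): after lowering and allocation merging, the buffer type carries no useful alignment, so the fact that a particular access is 16-byte aligned can only live on the op itself:

```mlir
// Hypothetical post-merging IR: a flat byte buffer. memref<?xi8> implies
// no inherent static alignment; the 16-byte guarantee is per access.
func.func @merged(%buf: memref<?xi8>) {
  %c0 = arith.constant 0 : index
  %v = vector.load %buf[%c0] { alignment = 16 : i32 } : memref<?xi8>, vector<16xi8>
  return
}
```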

@matthias-springer
Member

Good point, looks like we may want to have it on the load/store ops.

@ftynse
Member

ftynse commented Jun 18, 2025

You can have a memref of ?xi8 that doesn't have any inherent static alignment and the alignment is really a property at each load/store op.

At that point, maybe you should use !ptr.ptr instead and not a memref. Not necessarily opposing this change, but I don't want to blindly replicate notions from lower-level abstractions like LLVM IR and SPIR-V to a higher-level abstraction.

A stronger argument may be that alignment goes both ways and we can have overaligned and underaligned accesses compared to the natural/preferred alignment of the element type, and those should be reflected somewhere, which is not necessarily a property of the type. Underaligned accesses are more interesting because those may be an optimization hint (aligned accesses are faster) or plainly forbidden by the architecture.

Note that the attribute approach is not precluding a dataflow analysis. We can have an analysis that propagates alignment information to individual operations, e.g., by looking at the structure of subscripts and attributes on previous operations accessing the same value. Attributes can be seen as a way to preserve analysis results.

Good point, looks like we may want to have it on the load/store ops.

Should we also remove memref.assume_alignment? This operation is rather confusing because nothing precludes one from using it repeatedly on the same value and the fact that it is side-effecting (so DCE doesn't remove it) without actually having side effects has been pointed out.

@@ -1217,6 +1217,11 @@ def LoadOp : MemRef_Op<"load",
be reused in the cache. For details, refer to the
[https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).

An optional `alignment` attribute allows specifying the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
Member

Let's have a verifier that checks for it being a power of 2.

Contributor Author

How about a new IntPowerOf2 AttrConstraint?

Member

Is it really worth the complexity? A simple three-line check in the verifier should do the trick.
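For concreteness, the kind of input such a verifier would newly reject (hypothetical example; the patch as posted only enforces positivity via IntPositive):

```mlir
// 3 is positive, so it satisfies the current IntPositive constraint,
// but a power-of-2 check in the verifier would reject it.
%v = memref.load %m[%c0] { alignment = 3 : i32 } : memref<4xi32>
```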

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Jun 18, 2025
@kuhar
Member

kuhar commented Jun 18, 2025

At that point, maybe you should use !ptr.ptr instead and not a memref.

We want to preserve high-level information so ptr.ptr would be too late and won't work for spirv which relies on typed pointers and doesn't lower via ptr.ptr.

@kuhar
Member

kuhar commented Jun 18, 2025

Should we also remove memref.assume_alignment? This operation is rather confusing because nothing precludes one from using it repeatedly on the same value and the fact that it is side-effecting (so DCE doesn't remove it) without actually having side effects has been pointed out.

I don't have any use for it myself, but I think this still may be useful to denote the alignment of the base pointer if you want to rely on some other analysis or LLVM to deduce alignment based on GEP offsets from the base pointer. I'd think that if you have multiple assume_alignment ops over the same memref, they should effectively form a conjunction of all the assumptions, e.g. and(x div 2, x div 4) => x div 4, or and(x div 2, x div 3) => x div 6 -- I think this is unsurprising if we only allow pow2 values.
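Under that reading, repeated assumptions would combine, e.g. (hypothetical IR; exact op syntax may differ):

```mlir
// Hypothetical: two assumptions on the same memref. Taken as a
// conjunction, the effective alignment is their least common multiple
// (4 here; for power-of-2 values that is simply the larger one).
memref.assume_alignment %m, 2 : memref<?xi8>
memref.assume_alignment %m, 4 : memref<?xi8>
```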

@krzysz00
Contributor

Two notes:

  1. Currently, the lowerings of vector.load and vector.store to LLVM always assume that the pointer is aligned to the natural alignment of the vector element, not the vector. This is behavior users often want to override, either in general or on a case-by-case basis, and having an alignment attribute lets people do that
  2. I think this'll be a good moment to clarify what alignment means for sub-byte types, since we have EmulateNarrowTypes. I'd argue that if the type in the memref has width < 8 bits, then alignment should be in units of the number of elements, not the number of bytes. That'll allow the narrow type emulator to use faster lowerings when you, for example, do a vector.store %v, %m[...] {alignment = 2 : index} : vector<2xi4>, memref<...xi4>

@ftynse
Member

ftynse commented Jun 20, 2025

We want to preserve high-level information so ptr.ptr would be too late and won't work for spirv which relies on typed pointers and doesn't lower via ptr.ptr.

What high-level information does memref<?xi8> have? It's an array of bytes... We need a better interop between pointers and memrefs, in particular taking views of a pointer as a memref in a way that doesn't interfere with aliasing analyses, but I don't see why we should keep abusing memrefs when a pointer is actually needed.

I don't have any use for it myself, but I think this still may be useful to denote the alignment of the base pointer if you want to rely on some other analysis or llvm to deduce alignment based on gep offsets from the base pointers. I'd think that if you have multiple assume_alignment over the same memref, this should effectively form a conjunction of all the assumptions, e.g. and(x div 2, x div 4) => x div 4, or and(x div 2, x div 3) => x div 6 -- I think this is unsurprising if we only allow pow2 values.

We can certainly take the LCM of two values, that's not the issue here. The issue is that one needs a dataflow analysis to see which assumptions have been stated about which value, e.g., assume_alignment may be placed later than a load in a block, but that block may be a body of the loop. And there may be another assumption before and outside that block...

@ftynse
Member

ftynse commented Jun 20, 2025

Currently, the lowerings of vector.load and vector.store to LLVM always assume that the pointer is aligned to the natural alignment of the vector element, not the vector. This is behavior users often want to override, either in general or on a case-by-case basis, and having an alignment attribute lets people do that

There's a lot of logic duplication as well as underdefined cases in this area that I'd love to see cleaned up. For example, we could cast a scalar memref to memref<... x vector<>>, at which point loads happen with the alignment of the vector. Except we don't really have alignment well-specified for multidimensional vectors. And vector.load/store don't preclude non-contiguous leading dimensions...

Again, my concern here is we keep adding alternative ways to do things thus increasing overall complexity of the system. These new ways may be more flexible or general than existing mechanisms, but we must clean those up and ensure some overall coherence.

I think this'll be a good moment to clarify what alignment means for sub-byte types, since we have EmulateNarrowTypes. I'd argue that if the type in the memref has width < 8 bits, then alignment should be in units of the number of elements, not the number of bytes. That'll allow the narrow type emulator to use faster lowerings when you, for example, do a vector.store %v, %m[[...], {alignment = 2 : index} : vector<2xi4>, memref<...xi4>

Please bring this to discourse.

IMO, we could say alignment is always specified in the number of bits, not bytes, and be done with it. Having contextual rules to interpret an attribute is a ridiculous amount of unnecessary complexity.

@@ -0,0 +1,27 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
Member

Please move the "positive" tests into test/Dialect/*/ops.mlir and negative tests into test/Dialect/*/invalid.mlir. Both exist in both dialects. We don't want to have a test file per operation.

@kuhar
Member

kuhar commented Jun 20, 2025

What high-level information does memref<?xi8> have? It's an array of bytes...
why we should keep abusing memrefs when a pointer is actually needed

Going back to the example of merging allocations, we want to do that at the memref level, which is common to both llvm and spirv. At that point you do have the original memref types and can set the alignment information. Doing the same transformation at the level of spirv, or close to llvm (with ptr.ptr), seems too late to me, and there's no better place to store alignment info than on the load/store ops.
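As a hedged sketch of that allocation-merging scenario (buffer names, offsets, and shapes here are made up): at the memref level the original types are still visible, so a merger that packs buffers at known offsets can record what it knows directly on the accessing ops.

```mlir
// Two logical f32 buffers packed into one i8 slab at 16-byte-aligned
// offsets chosen by the merger (%c0 and %c512 are assumed index constants).
%slab  = memref.alloc() : memref<1024xi8>
%view0 = memref.view %slab[%c0][]   : memref<1024xi8> to memref<64xf32>
%view1 = memref.view %slab[%c512][] : memref<1024xi8> to memref<64xf32>
// The merger knows each view starts 16-byte aligned, so it can annotate
// the accesses even though the views are plain memrefs.
%v = vector.load %view0[%c0] { alignment = 16 } : memref<64xf32>, vector<4xf32>
```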

@banach-space
Contributor

Currently, the lowerings of vector.load and vector.store to LLVM always assume that the pointer is aligned to the natural alignment of the vector element, not the vector. This is behavior users often want to override, either in general or on a case-by-case basis, and having an alignment attribute lets people do that

Recently (see #137389), ConvertVectorToLLVM was extended to support exactly that kind of override via the --convert-vector-to-llvm='use-vector-alignment=1' flag. Isn’t this patch duplicating that logic?

Let’s ask a specific question: could --convert-vector-to-llvm='use-vector-alignment=1' be replaced by what’s being added here? If so, great - let’s unify. If not, let’s clarify how they interact.

(Ping @sstamenova for context.)

Again, my concern here is we keep adding alternative ways to do things thus increasing overall complexity of the system. These new ways may be more flexible or general than existing mechanisms, but we must clean those up and ensure some overall coherence.

+1

I think this'll be a good moment to clarify what alignment means for sub-byte types, since we have EmulateNarrowTypes. I'd argue that if the type in the memref has width < 8 bits, then alignment should be in units of the number of elements, not the number of bytes. That'll allow the narrow type emulator to use faster lowerings when you, for example, do a vector.store %v, %m[...] {alignment = 2 : index} : memref<...xi4>, vector<2xi4>

Please bring this to discourse.

+1

IMHO, narrow-type emulation is a poor example here. While many patterns in that pass deal with sub-byte types, some also emulate byte-sized elements (see discussion in #131529). That logic has grown organically and lacks cohesion - it’s not an ideal foundation for defining alignment semantics.


To be clear, I’m not looking to block this patch - but we now have multiple mechanisms to control alignment:

  • The new attributes on vector.load / vector.store (this patch)
  • The --convert-vector-to-llvm='use-vector-alignment=1' flag
  • memref.assume_alignment, which introduces its own complications (see #144809, #144825)

At the very least, this patch should clarify the relationship between the new attribute and the vector alignment flag in the lowering pass. Otherwise, we risk further fragmentation of behaviour and assumptions around alignment.

@kuhar
Member

kuhar commented Jun 23, 2025

Recently (see #137389), ConvertVectorToLLVM was extended to support exactly that kind of override via the --convert-vector-to-llvm='use-vector-alignment=1' flag. Isn’t this patch duplicating that logic?

Thanks for linking this, I was not aware of this PR. My assessment is that, in isolation, use-vector-alignment=1 is closer to being a hack that in some cases can have the same effect as alignment attributes. Alignment is not inherently a property of the memref type itself, but rather a property of load/store ops, and usually this information is known by the frontend / high-level transformations, so setting it during conversion is too late. For example, in IREE we have a pass that pads shared memory allocations to reduce bank conflicts. This expands allocations by adding extra columns but does not change the types of memory accesses. Using use-vector-alignment=1 there would lead to miscompiles.
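To illustrate the miscompile risk, here is a sketch with made-up shapes: after padding a shared-memory tile from 64x64 to 64x66 f32 columns, each row is 66 * 4 = 264 bytes long, which is not a multiple of 16, so row starts are no longer all 16-byte aligned. A blanket use-vector-alignment=1 (asserting vector<4xf32> alignment, i.e. 16 bytes, on every access) would be unsound here, whereas a per-op attribute can simply be left off these loads.

```mlir
// Tile padded by two f32 columns to reduce bank conflicts (address space 3
// standing in for workgroup/shared memory).
%tile = memref.alloc() : memref<64x66xf32, 3>
// Row stride is 264 bytes, so asserting 16-byte alignment for every
// row-start access would be wrong; no alignment attribute is emitted.
%row_start = vector.load %tile[%r, %c0] : memref<64x66xf32, 3>, vector<4xf32>
```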

The alignment attribute is general enough to subsume use-vector-alignment -- we can write a simple pass to populate alignment attributes based on the vector types. So in this sense, we may be able to reduce feature duplication in the future.
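The subsumption argument can be seen in before/after form. This is a hedged sketch of what such a populate pass might produce; the pass itself does not exist yet.

```mlir
// Before: no alignment stated; the LLVM lowering falls back to the
// element's natural alignment (4 bytes for f32).
%v0 = vector.load %m[%i] : memref<128xf32>, vector<4xf32>

// After a hypothetical populate-natural-alignment pass, which derives
// 16 from the vector type (4 elements * 4 bytes):
%v1 = vector.load %m[%i] { alignment = 16 } : memref<128xf32>, vector<4xf32>
```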

@krzysz00
Contributor

Yeah, I remember when landing --convert-vector-to-llvm='use-vector-alignment=1' I had the sense that it was a temporary solution to work around the fact that we couldn't put alignment on vector loads/stores.

I'd argue for removing it in favor of some flavor of vector-declare-natural-alignment transform over the vector dialect to reduce that same redundancy.

memref.assume_alignment is a rather weird op that, as far as I can tell, exists to allow backends that try to reason about pointer alignment not to stumble over the mysterious pointer-out-of-nowhere that can exist inside a memref's base.

Comment on lines 13 to 27
// -----

func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
  %c0 = arith.constant 0 : index
  // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
  return
}

// -----

func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
  %c0 = arith.constant 0 : index
  // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
  vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
  return
}
Contributor

This sort of "invalid usage" tests belong in invalid.mlir.

Comment on lines +3 to +11
// CHECK-LABEL: func @test_load_store_alignment
// CHECK: vector.load {{.*}} {alignment = 16 : i64}
// CHECK: vector.store {{.*}} {alignment = 16 : i64}
func.func @test_load_store_alignment(%memref: memref<4xi32>) {
%c0 = arith.constant 0 : index
%val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
return
}
@banach-space banach-space requested a review from dcaballe June 24, 2025 12:46
@banach-space
Contributor

The alignment attribute is general enough to subsume use-vector-alignment -- we can write a simple pass to populate alignment attributes based on the vector types. So in this sense, we may be able to reduce feature duplication in the future.

I'd argue for removing it in favor of some flavor of vector-declare-natural-alignment transform over the vector dialect to reduce that same redundancy.

+1 Let's make sure that there is a clear TODO to that end.

Btw, there are more ops that access memory:

  • vector.gather + vector.scatter
  • vector.transfer_read + vector.transfer_write
  • vector.compressstore + vector.expandload

What about these?

@krzysz00
Contributor

Yeah, we probably should get them all - and probably hit them all with nontemporal too while we're here

@electriclilies
Contributor

Hi! Author of #137389 here :)

Yes, it was a workaround -- we were getting perf issues from vectors being aligned to scalar alignments. I think an alignment attribute is a better long-term solution, and I'm all for getting rid of --convert-vector-to-llvm='use-vector-alignment=1' as long as we're still able to set vector alignments properly. Also, we do have use cases for setting the alignment higher up in the stack, for example during memory planning.

One concern I have with using memref ops for this is that they exist in so many different places in the stack. I could see this becoming messy -- say I have a memory planner that picks some alignments, then everything gets lowered to LLVM IR, and then LLVM passes also try to set alignment. It's not clear to me whether lower-level passes should be allowed to "override" alignment set by higher-level passes, or vice versa. And this gets complicated as we interleave default LLVM passes with custom ones. It might be worthwhile to expose an interface that lets users override how default passes set alignment without making an upstream change. Just something to think about as you start implementing logic to set the alignments.


func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
%c0 = arith.constant 0 : index
// expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
Contributor

dup: 'vector.load' 'vector.load'

Contributor

That and it should be in test/Dialect/Vector/invalid.mlir or the like

Contributor

Probably

Contributor

@krzysz00 krzysz00 left a comment
LGTM aside from the minor test relocation nits



@banach-space
Contributor

Yeah, we probably should get them all

Let's do this here, to avoid a split state where some ops support it while others don't.
