diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md index b3bde055f04f0..2225329ff830b 100644 --- a/mlir/docs/DefiningDialects/Operations.md +++ b/mlir/docs/DefiningDialects/Operations.md @@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported: * `IntPositive`: Specifying an integer attribute whose value is positive * `IntNonNegative`: Specifying an integer attribute whose value is non-negative +* `IntPowerOf2`: Specifying an integer attribute whose value is a power of + two > 0 * `ArrayMinCount`: Specifying an array attribute to have at least `N` elements * `ArrayMaxCount`: Specifying an array attribute to have at most `N` diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 77e3074661abf..dad4b7bff40cd 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1217,6 +1217,11 @@ def LoadOp : MemRef_Op<"load", be reused in the cache. For details, refer to the [https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction). + An optional `alignment` attribute allows to specify the byte alignment of the + load operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. Example: ```mlir @@ -1227,7 +1232,39 @@ def LoadOp : MemRef_Op<"load", let arguments = (ins Arg:$memref, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); + + let builders = [ + OpBuilder<(ins "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "Type":$resultType, + "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultType, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "TypeRange":$resultTypes, + "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultTypes, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; + let results = (outs AnyType:$result); let extraClassDeclaration = [{ @@ -1913,6 +1950,11 @@ def MemRef_StoreOp : MemRef_Op<"store", be reused in the cache. For details, refer to the [https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction). + An optional `alignment` attribute allows to specify the byte alignment of the + store operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. Example: ```mlir @@ -1924,13 +1966,25 @@ def MemRef_StoreOp : MemRef_Op<"store", Arg:$memref, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); let builders = [ + OpBuilder<(ins "Value":$valueToStore, + "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, valueToStore, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{ $_state.addOperands(valueToStore); $_state.addOperands(memref); - }]>]; + }]> + ]; let extraClassDeclaration = [{ Value getValueToStore() { return getOperand(0); } diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8353314ed958b..f9eadb431be51 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1734,12 +1734,42 @@ def Vector_LoadOp : Vector_Op<"load"> { ```mlir %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32> ``` + + An optional `alignment` attribute allows to specify the byte alignment of the + load operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. }]; let arguments = (ins Arg:$base, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); + + let builders = [ + OpBuilder<(ins "VectorType":$resultType, + "Value":$base, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultType, base, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "TypeRange":$resultTypes, + "Value":$base, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultTypes, base, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; + let results = (outs AnyVectorOfAnyRank:$result); let extraClassDeclaration = [{ @@ -1818,6 +1848,12 @@ def Vector_StoreOp : Vector_Op<"store"> { ```mlir vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32> ``` + + An optional `alignment` attribute allows to specify the byte alignment of the + store operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. }]; let arguments = (ins @@ -1825,8 +1861,21 @@ def Vector_StoreOp : Vector_Op<"store"> { Arg:$base, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal - ); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); + + let builders = [ + OpBuilder<(ins "Value":$valueToStore, + "Value":$base, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, valueToStore, base, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -1850,7 +1899,9 @@ def Vector_MaskedLoadOp : Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, - AnyVectorOfNonZeroRank:$pass_thru)>, + AnyVectorOfNonZeroRank:$pass_thru, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>, Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "loads elements from memory into a vector as defined by a mask vector"; @@ -1871,6 +1922,12 @@ def Vector_MaskedLoadOp : comes from the pass-through vector regardless of the index, and the index is allowed to be out-of-bounds. + An optional `alignment` attribute allows to specify the byte alignment of the + load operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. + The masked load can be used directly where applicable, or can be used during progressively lowering to bring other memory operations closer to hardware ISA support for a masked load. The semantics of the operation @@ -1887,6 +1944,18 @@ def Vector_MaskedLoadOp : : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` }]; + let builders = [ + OpBuilder<(ins "Type":$resultType, + "Value":$base, + "ValueRange":$indices, + "Value":$mask, + "Value":$pass_thru, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultType, base, indices, mask, pass_thru, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return ::llvm::cast(getBase().getType()); @@ -1913,7 +1982,9 @@ def Vector_MaskedStoreOp : Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, - AnyVectorOfNonZeroRank:$valueToStore)> { + AnyVectorOfNonZeroRank:$valueToStore, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> { let summary = "stores elements from a vector into memory as defined by a mask vector"; @@ -1933,6 +2004,12 @@ def Vector_MaskedStoreOp : is stored regardless of the index, and the index is allowed to be out-of-bounds. + An optional `alignment` attribute allows to specify the byte alignment of the + store operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. + The masked store can be used directly where applicable, or can be used during progressively lowering to bring other memory operations closer to hardware ISA support for a masked store. The semantics of the operation @@ -1949,6 +2026,27 @@ def Vector_MaskedStoreOp : : memref, vector<16xi1>, vector<16xf32> ``` }]; + let builders = [ + OpBuilder<(ins "TypeRange":$resultTypes, + "Value":$base, + "ValueRange":$indices, + "Value":$mask, + "Value":$valueToStore, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultTypes, base, indices, mask, valueToStore, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "Value":$base, + "ValueRange":$indices, + "Value":$mask, + "Value":$valueToStore, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, base, indices, mask, valueToStore, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return ::llvm::cast(getBase().getType()); diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index e91a13fea5c7f..18da85a580710 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -796,6 +796,10 @@ def IntPositive : AttrConstraint< CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">, "whose value is positive">; +def IntPowerOf2 : AttrConstraint< + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">, + "whose value is a power of two > 0">; + class ArrayMaxCount : AttrConstraint< CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>, "with at most " # n # " elements">; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 8ccf1bfc292d5..267a711a71914 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -841,8 +841,8 @@ struct LoadOpLowering : public LoadStoreOpLowering { adaptor.getMemref(), adaptor.getIndices(), kNoWrapFlags); rewriter.replaceOpWithNewOp( - loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, - false, loadOp.getNontemporal()); + loadOp, typeConverter->convertType(type.getElementType()), dataPtr, + loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal()); return success(); } }; @@ -864,7 +864,8 @@ struct StoreOpLowering : public LoadStoreOpLowering { getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(), kNoWrapFlags); rewriter.replaceOpWithNewOp(op, adaptor.getValue(), dataPtr, - 0, false, op.getNontemporal()); + op.getAlignment().value_or(0), + false, op.getNontemporal()); return success(); } }; diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir index 040a27e160557..874acbc9b6c3c 100644 --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -148,6 +148,15 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { // ----- +// CHECK-LABEL: func @aligned_load( +func.func @aligned_load(%static : memref<10x42xf32>, %i : index, %j : index) { +// CHECK: llvm.load %{{.*}} {alignment = 16 : i64} : !llvm.ptr -> f32 + %0 = memref.load %static[%i, %j] { alignment = 16 } : memref<10x42xf32> + return +} + +// ----- + // CHECK-LABEL: func @zero_d_store func.func @zero_d_store(%arg0: memref, %arg1: f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64)> @@ -177,6 +186,16 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va // ----- +// CHECK-LABEL: func @aligned_store +func.func @aligned_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) { +// CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : f32, !llvm.ptr + + memref.store %val, %static[%i, %j] { alignment = 16 } : memref<10x42xf32> + return +} + +// ----- + // CHECK-LABEL: func @static_memref_dim func.func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { // CHECK: llvm.mlir.constant(42 : index) : i64 diff --git a/mlir/test/Dialect/MemRef/load-store-alignment.mlir b/mlir/test/Dialect/MemRef/load-store-alignment.mlir new file mode 100644 index 0000000000000..afd314a9b498a --- /dev/null +++ b/mlir/test/Dialect/MemRef/load-store-alignment.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: func @test_load_store_alignment +// CHECK: memref.load {{.*}} {alignment = 16 : i64} +// CHECK: memref.store {{.*}} {alignment = 16 : i64} +func.func @test_load_store_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + %val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32> + memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32> + return +} + +// ----- + +func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) { + // expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32> + return +} + +// ----- + +func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: i32) { + // expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + memref.store %val, %memref[%c0] { alignment = 1 } : memref<4xi32> + return +} diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir new file mode 100644 index 0000000000000..e7beb8a3a9bd0 --- /dev/null +++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: func @test_masked_load_store_alignment +// CHECK: vector.maskedload {{.*}} {alignment = 16 : i64} +// CHECK: vector.maskedstore {{.*}} {alignment = 16 : i64} +func.func @test_masked_load_store_alignment(%memref: memref<4xi32>, %mask: vector<4xi1>, %passthru: vector<4xi32>) { + %c0 = arith.constant 0 : index + %val = vector.maskedload %memref[%c0], %mask, %passthru { alignment = 16 } : memref<4xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32> + vector.maskedstore %memref[%c0], %mask, %val { alignment = 16 } : memref<4xi32>, vector<4xi1>, vector<4xi32> + return +} + +// ----- + +// CHECK-LABEL: func @test_load_store_alignment +// CHECK: vector.load {{.*}} {alignment = 16 : i64} +// CHECK: vector.store {{.*}} {alignment = 16 : i64} +func.func @test_load_store_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32> + vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32> + return +} + +// ----- + +func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32> + return +} + +// ----- + +func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) { + %c0 = arith.constant 0 : index + // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32> + return +}