diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 7a705336bf11c..e730998f153b0 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" #include +#include #include #define DEBUG_TYPE "memref-to-spirv-pattern" @@ -465,7 +466,13 @@ struct MemoryRequirements { /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if /// any. static FailureOr -calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { +calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, + uint64_t preferredAlignment) { + + if (std::numeric_limits::max() < preferredAlignment) { + return failure(); + } + MLIRContext *ctx = accessedPtr.getContext(); auto memoryAccess = spirv::MemoryAccess::None; @@ -474,7 +481,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { } auto ptrType = cast(accessedPtr.getType()); - if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) { + bool mayOmitAlignment = + !preferredAlignment && + ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer; + if (mayOmitAlignment) { if (memoryAccess == spirv::MemoryAccess::None) { return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}}; } @@ -483,6 +493,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { } // PhysicalStorageBuffers require the `Aligned` attribute. + // Other storage types may show an `Aligned` attribute. auto pointeeType = dyn_cast(ptrType.getPointeeType()); if (!pointeeType) return failure(); @@ -494,7 +505,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); - auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes); + auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes; + auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue); return MemoryRequirements{memAccessAttr, alignment}; } @@ -508,16 +520,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { llvm::is_one_of::value, "Must be called on either memref::LoadOp or memref::StoreOp"); - Operation *memrefAccessOp = loadOrStoreOp.getOperation(); - auto memrefMemAccess = memrefAccessOp->getAttrOfType( - spirv::attributeName()); - auto memrefAlignment = - memrefAccessOp->getAttrOfType("alignment"); - if (memrefMemAccess && memrefAlignment) - return MemoryRequirements{memrefMemAccess, memrefAlignment}; - return calculateMemoryRequirements(accessedPtr, - loadOrStoreOp.getNontemporal()); + loadOrStoreOp.getNontemporal(), + loadOrStoreOp.getAlignment().value_or(0)); } LogicalResult diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index d0ddac8cd801c..7c765f70136bb 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -85,6 +85,51 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i : return %0: i1 } +// CHECK-LABEL: func @load_aligned +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] + // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned", 32] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 + %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + +// CHECK-LABEL: func @load_aligned_nontemporal +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] + // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned|Nontemporal", 32] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 + %0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + +// CHECK-LABEL: func @load_aligned_psb +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, PhysicalStorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] + // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" %[[ADDR]] ["Aligned", 32] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 + %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + // CHECK-LABEL: func @store_i1 // CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class>, // CHECK-SAME: %[[IDX:.+]]: index