diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index e6c7efc47593f..daab65ec893b8 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -80,9 +80,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface information e.g., memref, the strides information has to be explicitly passed via the "strides" and "const_strides" argument. - In SIMT mode, tensor descriptor is augmented with `LayoutAttr` which describes the - mapping of the tensor descriptor to the work items. - Example 1 (suppose the tensor shape inferred by the compiler is 8x16): ```mlir %0 = memref.alloc() : memref<1024x1024xf32> @@ -106,15 +103,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface %c1 = arith.constant 1 : index %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32> ``` - - Example 4 (SIMT mode): - ```mlir - %0 = memref.alloc() : memref<1024x1024xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 8 : index - %1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32> - -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - ``` }]; let arguments = (ins @@ -301,9 +289,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ fp32 or fp64. It implies that vnni and transpose cannot exit at the same time. - In SIMT mode, LoadNdOp expects the tensor descriptor to be augmented with `LayoutAttr` - which describes the mapping of the tensor to the work items. In this case, result - vector represents the data to be loaded by each work-item. + In SIMT mode, result vector represents the data to be loaded by each work-item. Example 1: ```mlir @@ -317,8 +303,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ ```mlir xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<8x16xf32, - #xegpu.layout> -> vector<8x1xf32> + : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> ``` @@ -359,9 +344,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ of cache, L1, L2 and L3. If hardware does not have a correspoding cache, Corresponding cache hint attribute will be masked. - In SIMT mode, StoreNdOp expects the tensor descriptor to be augmented with `LayoutAttr` - which describes the mapping of the tensor to the work items. In this case, input - vector represents the data to be stored by each work-item. + In SIMT mode, the input vector represents the data to be stored by each work-item. Example 1: ```mlir @@ -375,8 +358,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} - : vector<8x1xf16>, !xegpu.tensor_desc<8x16xf16, - #xegpu.layout> + : vector<8xf16>, !xegpu.tensor_desc<8x16xf16> ``` @@ -410,15 +392,10 @@ def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset", The offsets are relative offset to the current position in the number of elements. It will result in a same type TensorDesc as the input. - Example 1: + Example: ``` %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32> ``` - Example 2 (SIMT mode): - ``` - %2 = xegpu.update_nd_offset %1, [0, 16]: - !xegpu.tensor_desc<8x16xf32, #xegpu.layout> - ``` }]; let arguments = (ins @@ -476,11 +453,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> { match the dimension of offsets. It may also has a second dimension corresponding to the chunk_size if the chunk size is larger than 1. 
- In SIMT mode, similar to `create_nd_tdesc` the resulting tensor descriptor is augmented - with `LayoutAttr` which describes the mapping of the tensor descriptor to the work items. - In this case, the first dimension of the tensor descriptor represents the work-items, and - the second dimension represents the chunk size. - Example 1: It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64] ```mlir %a = memref.alloc() : memref<1024xf32> @@ -505,15 +477,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> { %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex> -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr> ``` - - Example 4: SIMT mode - ```mlir - %0 = memref.alloc() : memref<1024xf32> - %off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex> - %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex> - -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr, - #xegpu.layout> - ``` }]; let arguments = (ins XeGPU_BaseAddrType: $source, @@ -609,54 +572,44 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ let description = [{ It (aka. load) load data per each work-item. The output describes the data being loaded at the subgroup level, so its size is consistent with the number of work-items in a subgroup. When the chunk size - is larger than 2, the output vector is a 2D vector, with dim-1 correspoding - to work-items, and dim-0 corresponding to the chunk size loaded by each work-item. - Specially, there is a transpose effect on the result (as compared to the TensorDesc) - due to the hardware implementation. Therefore, a transpose attribute is introduced - on purpose, making sure users are aware of this implicit transformation. - + is larger than 2, the output vector is a 2D vector, with dim-0 correspoding + to work-items, and dim-1 corresponding to the chunk size loaded by each work-item. The mask operand masks out memory access so that it is safe to pass out-of-boundary addresses/offsets as long as they are masked. It applies to slots of SIMD lanes. - In SIMT mode, LoadGatherOp expects the tensor descriptor to be augmented with `LayoutAttr` - which describes the mapping of the tensor to the work items. In this case, result vector - represents the data to be loaded by each work-item. Each work-item recieves a `chunk_size` - number of elements. + In SIMT mode, the result vector represents the data to be loaded by each work-item. + Each work-item recieves a `chunk_size` number of elements. 
Example 1: ```mlir - %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint, - l3_hint = #xegpu.cache_hint} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint, + l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> ``` Example 2: ```mlir - %2 = xegpu.load %1, %0 {transpose, - l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint, - l3_hint = #xegpu.cache_hint} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint, + l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, - vector<16xi1> -> vector<8x16xf32> + vector<16xi1> -> vector<16x8xf32> ``` Example 3 (SIMT mode): ```mlir - %2 = xegpu.load %1, %0 {transpose, - l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint, - l3_hint = #xegpu.cache_hint} - : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, - !xegpu.layout> - vector<16xi1> -> vector<8x1xf32> + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint, + l3_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + vector<16xi1> -> vector<8xf32> ``` }]; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, XeGPU_MaskType: $mask, - OptionalAttr: $transpose, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint); @@ -699,44 +652,38 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is introduced on purpose, making sure users are aware of this implicit transformation. - In SIMT mode, StoreScatterOp expects the tensor descriptor to be augmented with `LayoutAttr` - which describes the mapping of the tensor to the work items. In this case, input vector - represents the data to be stored by each work-item. Each work-item recieves a `chunk_size` - number of elements. + In SIMT mode, the input vector represents the data to be stored by each work-item. + Each work-item stores a `chunk_size` number of elements. 
Example 1: ```mlir - xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint, - l3_hint = #xegpu.cache_hint} + xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint, + l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1> ``` Example 2: ```mlir - xegpu.store %0, %1, %2 {transpose, - l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint, - l3_hint = #xegpu.cache_hint} - : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr>, vector<16xi1> + xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint, + l3_hint = #xegpu.cache_hint}> + : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr>, vector<16xi1> ``` + Example 3 (SIMT mode): ```mlir - xegpu.store %0, %1, %2 {transpose, - l1_hint = #xegpu.cache_hint, - l2_hint = #xegpu.cache_hint, - l3_hint = #xegpu.cache_hint} - : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr, - !xegpu.layout> vector<16xi1> + xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint, + l2_hint = #xegpu.cache_hint, + l3_hint = #xegpu.cache_hint}> + : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr> vector<16xi1> ``` - }]; let arguments = (ins XeGPU_ValueType: $value, XeGPU_TensorDesc: $TensorDesc, XeGPU_MaskType: $mask, - OptionalAttr: $transpose, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint); @@ -773,20 +720,13 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", update the offset per work-item, so its offsets contains values representing shifts for each work-item. - Example 1: + Example: ```mlir %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex> %2 = xegpu.update_offset %1, %off : !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr>, vector<4xindex> ``` - Example 2 (SIMT mode): - ```mlir - %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex> - %2 = xegpu.update_offset %1, %off : - !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr, - #xegpu.layout>, vector<4xindex> - ``` }]; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index 772cf73649646..09311e6017d0c 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -35,6 +35,8 @@ constexpr unsigned packedSizeInBitsForDefault = 16; // Minimum packing size per register for DPAS A. constexpr unsigned packedSizeInBitsForDpasB = 32; // Minimum packing size per register for DPAS B. +constexpr unsigned packedSizeInBitsForGatherScatter = + 32; // Minimum packing size per register for Gather and Scatter ops. 
} // namespace targetinfo } // namespace xegpu diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 6790c5e3af2c0..649e0d453015f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -309,11 +310,23 @@ LogicalResult TensorDescType::verify( llvm::ArrayRef shape, mlir::Type elementType, mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); - // Low-precision types are packed in 32-bit units. - int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth(); if (rank != 1 && rank != 2) return emitError() << "expected 1D or 2D tensor"; + auto blockAttr = mlir::dyn_cast_if_present(encoding); + if (blockAttr) { + MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace(); + if (rank == 2 && memorySpaceAttr && + memorySpaceAttr.getValue() == MemorySpace::SLM) + return emitError() << "SLM is not supported for 2D block tensor"; + } + + // for gather and scatter ops, Low-precision types are packed in 32-bit units. + unsigned bitWidth = elementType.getIntOrFloatBitWidth(); + int chunkAlignmentFactor = + bitWidth < targetinfo::packedSizeInBitsForGatherScatter + ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth + : 1; auto scatterAttr = mlir::dyn_cast_if_present(encoding); if (scatterAttr) { // Expected tensor ranks for scattered data: @@ -329,21 +342,13 @@ LogicalResult TensorDescType::verify( if (chunkSize > 1) { if (shape.back() != chunkSize) return emitError() << "expected tensor shape[1] to match chunk size"; - if (shape.back() % packingFactor != 0) - return emitError() - << "expected tensor shape[1] to be a multiple of packing factor " - << packingFactor; + if (shape.back() % chunkAlignmentFactor != 0) + return emitError() << "expected tensor shape[1] to be a multiple of " + "chunk alignment factor " + << chunkAlignmentFactor; } } - auto blockAttr = mlir::dyn_cast_if_present(encoding); - if (blockAttr) { - MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace(); - if (rank == 2 && memorySpaceAttr && - memorySpaceAttr.getValue() == MemorySpace::SLM) - return emitError() << "SLM is not supported for 2D block tensor"; - } - auto layoutAttr = llvm::dyn_cast_if_present(layout); if (layoutAttr) { if (rank != (size_t)layoutAttr.getRank()) @@ -360,7 +365,7 @@ LogicalResult TensorDescType::verify( if (rank > 1 && laneData[0] != 1) return emitError() << "cannot map over non-contiguous scattered row elements"; - if (laneData[rank - 1] != packingFactor) + if (laneData[rank - 1] != chunkAlignmentFactor) return emitError() << "work item data mapping must match the number of " "contiguous elements"; } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 0afc502c026f7..2793c7a35bc97 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -20,13 +20,6 @@ namespace mlir { namespace xegpu { -static void transpose(llvm::ArrayRef trans, - SmallVector &shape) { - SmallVector old = shape; - for (size_t i = 0; i < trans.size(); i++) - shape[i] = old[trans[i]]; -} - template static std::string makeString(T array, bool breakline = false) { std::string buf; @@ -76,7 +69,7 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) 
{ static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, - TensorDescType tdescTy, UnitAttr transposeAttr, + TensorDescType tdescTy, function_ref emitError) { if (!tdescTy.isScattered()) @@ -102,17 +95,9 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { if (tdescTy.getLayoutAttr()) return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (transposeAttr) - return emitError() << "doesn't need TransposeAttr for SIMT code"; return success(); } - if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) { - if (!transposeAttr) - return emitError() << "rank-2 tensor has to be transposed."; - transpose({1, 0}, tdescShape); - } - if (tdescShape != valueShape) return emitError() << "Value shape " << makeString(valueShape) << " is neither a valid distribution for SIMT nor " @@ -310,13 +295,9 @@ LogicalResult LoadNdOp::verify() { if (getTranspose()) { auto trans = getTranspose().value(); - - // Make sure the transpose value is valid. - bool valid = llvm::all_of( - trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); }); - - if (valid) - transpose(trans, tdescShape); + // Make sure the transpose value is valid, and apply it + if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); })) + tdescShape = applyPermutation(tdescShape, trans); else mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored."; } @@ -536,7 +517,6 @@ LogicalResult LoadGatherOp::verify() { return emitOpError("invalid l3_hint: ") << getL3HintAttr(); return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - getTransposeAttr(), [&]() { return emitOpError(); }); } @@ -558,7 +538,6 @@ LogicalResult StoreScatterOp::verify() { return emitOpError("invalid l3_hint: ") << getL3HintAttr(); return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - getTransposeAttr(), [&]() { return emitOpError(); }); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index cc22d2bbd8c39..3c77c83a863e6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -213,6 +213,37 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) { LaneData({1, packingFactor})); } +/// Helper to get the default layout for a vector type. +static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) { + // Expecting a 1D or 2D vector. + assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && + "Expected 1D or 2D TensorDesc."); + // Expecting int or float element type. + assert(tdescTy.getElementType().isIntOrFloat() && + "Expected int or float element type."); + // If the rank is 1, then return default layout for 1D vector. + if (tdescTy.getRank() == 1) + return getDefaultSIMTLayoutInfo(1); + // Packing factor is determined by the element type bitwidth. + unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); + + if (tdescTy.isScattered()) { + int packingFactor = + bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter + ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth + : 1; + return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}), + LaneData({1, packingFactor})); + } + + int packingFactor = + (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) + ? 
xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth + : 1; + return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}), + LaneData({1, packingFactor})); +} + /// Helper Function to get the expected layouts for DPAS operands. `lane_data` /// is set according to the following criteria: /// * For A operand, the data must be packed in minimum @@ -379,8 +410,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp( // Here we assign the default layout to the tensor descriptor operand of // prefetch. auto tdescTy = prefetch.getTensorDescType(); - auto prefetchLayout = getDefaultSIMTLayoutInfo( - VectorType::get(tdescTy.getShape(), tdescTy.getElementType())); + auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy); // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } @@ -516,24 +546,14 @@ void LayoutInfoPropagation::visitVectorBitcastOp( void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef operands, ArrayRef results) { - LayoutInfo valueLayout = results[0]->getValue(); - // Need the layout of the value to propagate to the tensor descriptor. - if (!valueLayout.isAssigned()) - return; + // The layout is strictly determined by the tensor descriptor type. + LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType()); - LayoutInfo tensorDescLayout = valueLayout; - if (load.getTranspose()) { - // LoadGatherOp has the transpose effect. However, at the stage of this - // analyis this effect is not expected and should be abstracted away. Emit - // a warning. - load.emitWarning("Transpose effect is not expected for LoadGatherOp at " - "LayoutInfoPropagation stage."); - tensorDescLayout = valueLayout.getTransposedLayout({1, 0}); - } // Mask operand should have 1D default layout. LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1); + // Propagate the new layout to the tensor descriptor operand. - propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); + propagateIfChanged(operands[0], operands[0]->meet(layout)); // Propagate the new layout to the mask operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); } @@ -567,21 +587,13 @@ void LayoutInfoPropagation::visitStoreScatterOp( "Expected the first dimension of 2D tensor descriptor to be equal to " "subgroup size."); - LayoutInfo valueLayout = - getDefaultSIMTLayoutInfo(storeScatter.getValueType()); - LayoutInfo storeScatterLayout = valueLayout; - if (storeScatter.getTranspose()) { - // StoreScatteOp allows transpose effect. However, at the stage of this - // analyis this effect is not expected and should be abstracted away. Emit - // a warning. - storeScatter.emitWarning("Transpose effect is not expected for " - "StoreScatterOp at LayoutInfoPropagation stage."); - storeScatterLayout = valueLayout.getTransposedLayout({1, 0}); - } + LayoutInfo layout = + getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType()); + // Propagate the value layout. - propagateIfChanged(operands[0], operands[0]->meet(valueLayout)); + propagateIfChanged(operands[0], operands[0]->meet(layout)); // Propagate the tensor descriptor layout. - propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout)); + propagateIfChanged(operands[1], operands[1]->meet(layout)); // Use default 1D layout for mask operand. 
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1); propagateIfChanged(operands[2], operands[2]->meet(maskLayout)); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 0457f8128b908..2c48a735bf956 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -498,17 +498,14 @@ struct UnrollLoadGatherOp : public UnrollPattern { if (originalChunkSize > 1) { targetMaskShape.pop_back(); convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape); - SmallVector convertedMasks1D = pack( - op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter); int64_t blockedChunkSize = targetShape->back(); int64_t numNewChunks = originalChunkSize / blockedChunkSize; - for (auto mask : convertedMasks1D) { - for (int64_t i = 0; i < numNewChunks; ++i) - convertedMasks.push_back(mask); - } - // This is to handle the transpose effect when chunkSize > 1. - std::swap((*targetShape)[0], (*targetShape)[1]); + // the mask is reused across the chunk_size dimension + for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape, + loc, rewriter)) + convertedMasks.append(numNewChunks, mask); + newValueTy = valueTy.cloneWith(*targetShape, elemTy); } else { convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape); @@ -519,8 +516,8 @@ struct UnrollLoadGatherOp : public UnrollPattern { SmallVector newOps; for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) { auto newOp = rewriter.create( - loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); newOps.push_back(newOp); } @@ -573,7 +570,7 @@ struct UnrollStoreScatterOp : public UnrollPattern { if (!targetShape) return failure(); - SmallVector targetIndiceShape(*targetShape); + SmallVector targetMaskShape(*targetShape); int64_t originalChunkSize = tdescTy.getChunkSize(); VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); @@ -587,24 +584,19 @@ struct UnrollStoreScatterOp : public UnrollPattern { SmallVector convertedMasks; if (originalChunkSize > 1) { + targetMaskShape.pop_back(); int64_t blockedChunkSize = targetShape->back(); int64_t numNewChunks = originalChunkSize / blockedChunkSize; - convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]); - SmallVector convertedMasks1D = pack( - op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter); - - for (auto mask : convertedMasks1D) { - for (int64_t i = 0; i < numNewChunks; ++i) { - convertedMasks.push_back(mask); - } - } - // This is to handle the transpose effect when chunkSize > 1. - std::swap((*targetShape)[0], (*targetShape)[1]); + convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape); + // the mask is reused across the chunk_size dimension + for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape, + loc, rewriter)) + convertedMasks.append(numNewChunks, mask); } else { - convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape); - convertedMasks = - pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter); + convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape); + convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, + loc, rewriter); } SmallVector convertedValTypes = @@ -616,9 +608,9 @@ struct UnrollStoreScatterOp : public UnrollPattern { Value v = convertedValues[i]; Value t = convertedTdescs[i]; Value m = op.getMask() ? 
convertedMasks[i] : nullptr; - rewriter.create( - loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + rewriter.create(loc, v, t, m, op.getL1HintAttr(), + op.getL2HintAttr(), + op.getL3HintAttr()); } rewriter.eraseOp(op); @@ -655,20 +647,15 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { SmallVector newOps; int64_t originalChunkSize = tdescTy.getChunkSize(); if (originalChunkSize > 1) { - SmallVector shape1D(targetShape->begin(), - targetShape->end() - 1); - convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D); - SmallVector convertedOffsetVec1D = - pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter); + auto targetOffsetShape = ArrayRef(*targetShape).drop_back(); + convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape); int64_t blockedChunkSize = targetShape->back(); int64_t numNewChunks = originalChunkSize / blockedChunkSize; - - for (auto offset : convertedOffsetVec1D) { - for (int64_t i = 0; i < numNewChunks; ++i) { - convertedOffsetVec.push_back(offset); - } - } + // the offset is reused across the chunk_size dimension + for (auto offset : pack(offsetVec, convertedOffsetTypes, + targetOffsetShape, loc, rewriter)) + convertedOffsetVec.append(numNewChunks, offset); } else { convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape); diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 0a37ae70b5d99..a2778cd94d963 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -239,7 +239,7 @@ func.func @create_tdesc_vc_5(%src: memref) { func.func @create_tdesc_vc_6(%src: memref) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{tensor shape[1] to be a multiple of packing factor 2}} + // expected-error@+1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}} -> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 054c4d12fdb28..aff8f63adc05b 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -235,8 +235,6 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) { gpu.return } - - // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) { // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16> @@ -248,7 +246,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) { gpu.return } - // CHECK: func @simt_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) { // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16> @@ -318,8 +315,8 @@ gpu.func @subgroup_load(%src: ui64) { %1 = arith.constant dense<1>: vector<4xi1> //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<2x4xf32> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, 
vector<4xi1> -> vector<2x4xf32> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<4x2xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<4x2xf32> gpu.return } @@ -370,8 +367,8 @@ gpu.func @subgroup_load_3(%src: ui64) { %1 = arith.constant dense<1>: vector<4xi1> //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr> %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<8x4xf16> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<8x4xf16> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<4x8xf16> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<4x8xf16> gpu.return } @@ -394,17 +391,15 @@ gpu.func @subgroup_store(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32> - %2 = arith.constant dense<2.9>: vector<2x4xf32> + //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<4x2xf32> + %2 = arith.constant dense<2.9>: vector<4x2xf32> //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> gpu.return } - - // CHECK: gpu.func @simt_store(%[[arg0:.*]]: ui64) { gpu.func @simt_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -426,17 +421,15 @@ gpu.func @subgroup_store_2(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> %1 = arith.constant dense<1>: vector<4xi1> - //CHECK: %[[cst2:.*]] = arith.constant {{.*}} 
: vector<2x4xf16> - %2 = arith.constant dense<2.9>: vector<2x4xf16> + //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<4x2xf16> + %2 = arith.constant dense<2.9>: vector<4x2xf16> //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr> %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr> - //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<4x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<4x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr>, vector<4xi1> gpu.return } - - // CHECK: gpu.func @simt_store_2(%[[arg0:.*]]: ui64) { gpu.func @simt_store_2(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -467,7 +460,6 @@ gpu.func @subgroup_store_3(%src: ui64) { gpu.return } - // CHECK: gpu.func @simt_store_3(%[[arg0:.*]]: ui64) { gpu.func @simt_store_3(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir index 429081079de1e..0214d84f2c16f 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir @@ -90,26 +90,27 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor } // ----- -// CHECK-LABEL: func.func @load_gather_with_transpose_effect( +// CHECK-LABEL: func.func @load_gather_with_chunksize( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout} // CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> -> // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> -// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_result_0 = #xegpu.layout} : +// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<16xi1> -> vector<16x16xf16> -func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) { +func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> %cst = arith.constant dense<[0, 16, 32, 48, 
64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> %cst_0 = arith.constant dense : vector<16xi1> %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr> - %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x16xf16> - %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load %2, %cst_0 : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x16xf16> + %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + %5 = xegpu.dpas %1, %4 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> return } @@ -127,24 +128,24 @@ func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> %cst_0 = arith.constant dense : vector<16xi1> %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32> return } // ----- -// CHECK-LABEL: func.func @store_scatter_with_transpose_effect( +// CHECK-LABEL: func.func @store_scatter_with_chunksize( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) { // CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> -> // CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> -// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, +// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, // CHECK-SAME: #xegpu.layout>, vector<16xi1> -func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) { - %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32> +func.func @store_scatter_with_chunksize(%arg0: memref<128xf32>) { + %cst = arith.constant dense<1.000000e+00> : vector<16x8xf32> %cst_0 = arith.constant dense : vector<16xi1> %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> - xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + xegpu.store %cst, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir index f977ba3c11bcf..ac5fe89a67f9a 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir 
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir @@ -368,35 +368,35 @@ gpu.module @test_kernel { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - + %delta = arith.constant dense<[ 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 64, 128, 128, 128, 128, 128, 128, 128, 128, - 128, 128, 128, 128, 128, 128, 128, 256 + 128, 128, 128, 128, 128, 128, 128, 256 ]> : vector<32xindex> %new_tdesc = xegpu.update_offset %tdesc, %delta - : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xindex> - + : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xindex> + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xi1> -> vector<32xf32> %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32> - xegpu.store %st_vec, %tdesc, %mask: - vector<32xf32>, - !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, + xegpu.store %st_vec, %tdesc, %mask: + vector<32xf32>, + !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xi1> - + gpu.return } - + } // ----- @@ -407,8 +407,8 @@ gpu.module @test_kernel { // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xindex> - // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<2x16xf32> - // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> + // CHECK-COUNT-4: xegpu.store {{.*}} : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) { @@ -416,32 +416,32 @@ gpu.module @test_kernel { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - + %delta = arith.constant dense<[ 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 64, 128, 128, 128, 128, 128, 128, 128, 128, - 128, 128, 128, 128, 128, 128, 128, 256 + 128, 128, 128, 128, 128, 128, 128, 256 ]> : vector<32xindex> %new_tdesc = xegpu.update_offset %tdesc, %delta - : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xindex> - + : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, 
vector<32xindex> + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> - %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> -> vector<4x32xf32> + %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> -> vector<32x4xf32> - %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32> - xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}>: - vector<4x32xf32>, - !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, + %st_vec = arith.addf %ld_vec, %ld_vec : vector<32x4xf32> + xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: + vector<32x4xf32>, + !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> - + gpu.return } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir index 41414d802f212..6999da5d222fe 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir @@ -169,7 +169,7 @@ gpu.module @test { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> @@ -199,15 +199,15 @@ gpu.module @test { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> - + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xi1> -> vector<32xf32> - - gpu.return %ld : vector<32xf32> + + gpu.return %ld : vector<32xf32> } //----- @@ -222,7 +222,7 @@ gpu.module @test { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> @@ -242,16 +242,16 @@ gpu.module @test { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> - + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> %st_vec = arith.constant dense<1023.0>: vector<32xf32> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, 
vector<32xi1> - + gpu.return } @@ -280,7 +280,7 @@ gpu.module @test { } // CHECK-LABEL: create_tdesc_step_chunk3 - // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex> // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> @@ -300,45 +300,45 @@ gpu.module @test { // CHECK-LABEL: load_chunk // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<2x16xf32> + // CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> - gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> { + gpu.func @load_chunk(%src: ui64) -> vector<32x4xf32> { %cst = arith.constant dense<[ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> - + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> - %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> -> vector<4x32xf32> - - gpu.return %ld : vector<4x32xf32> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> -> vector<32x4xf32> + + gpu.return %ld : vector<32x4xf32> } //----- // CHECK-LABEL: store_chunk // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + // CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.func @store_chunk(%src: ui64) { %cst = arith.constant dense<[ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> - + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> - %st_vec = arith.constant dense<1023.>: vector<4x32xf32> + %st_vec = arith.constant dense<1023.>: vector<32x4xf32> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, 
#xegpu.layout> - xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> - + xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<32x4xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> + gpu.return } @@ -352,11 +352,11 @@ gpu.module @test { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - + gpu.return } @@ -370,7 +370,7 @@ gpu.module @test { 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> %delta = arith.constant dense<32>: vector<32xindex> %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> @@ -379,6 +379,6 @@ gpu.module @test { : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xindex> gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - } + } }
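
To make the new contract concrete, here is a small illustrative sketch in the style of the updated `ops.mlir` tests (the function name, offsets, and cache hints are placeholders, not part of the patch): with the `transpose` attribute removed, the value operand/result of a chunked `xegpu.load` / `xegpu.store` keeps the tensor-descriptor ordering, i.e. dim-0 indexes the work-items and dim-1 the `chunk_size`.

```mlir
gpu.module @example {
  // Subgroup-level gather/scatter with chunk_size = 2 and no transpose attribute.
  gpu.func @gather_scatter_chunked(%src: ui64) {
    // One offset per work-item; each work-item accesses a chunk of 2 contiguous f32 elements.
    %off  = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
    %mask = arith.constant dense<1> : vector<4xi1>
    %tdesc = xegpu.create_tdesc %src, %off : ui64, vector<4xindex>
        -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
    // The loaded vector matches the descriptor shape: dim-0 = work-items, dim-1 = chunk.
    %v = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>}>
        : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
        -> vector<4x2xf32>
    // The stored vector uses the same ordering; no implicit transpose is applied.
    xegpu.store %v, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<write_back>}>
        : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
    gpu.return
  }
}
```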