[MLIR][XeGPU] Remove the transpose attribute from Gather/Scatter ops and Cleanup the documents #145389


Merged

Conversation

chencha3
Contributor

@chencha3 chencha3 commented Jun 23, 2025

This PR removes the transpose attribute from the definitions of LoadGatherOp and StoreScatterOp. The attribute is meaningful in the SIMD lowering pipeline, but not in the SIMT lowering pipeline.

This PR also removes the layout attribute from the SIMT examples to match the changes in #144592.
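
For illustration, a minimal before/after sketch adapted from the ops.mlir test changes in this patch; with the attribute gone, the result shape follows the TensorDesc shape directly instead of being transposed:

```mlir
// Before this PR: the unit `transpose` attribute flipped the result shape
// relative to the TensorDesc (4x2 tdesc -> 2x4 result).
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
  : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>

// After this PR: no transpose attribute; dim-0 of the result corresponds to
// work-items and dim-1 to the chunk loaded by each work-item.
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
  : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf32>
```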

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Chao Chen (chencha3)

Changes

As the title suggests, this PR removes the transpose attribute from the definitions of LoadGatherOp and StoreScatterOp. The attribute is meaningful in the SIMD lowering pipeline, but not in the SIMT lowering pipeline.


Patch is 45.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145389.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+12-22)
  • (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+2)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+3-24)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+39-29)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+5-10)
  • (modified) mlir/test/Dialect/XeGPU/ops.mlir (+14-22)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+18-27)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+23-23)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+28-28)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e6c7efc47593f..ffc08e9b90b56 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -609,12 +609,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let description = [{ It (aka. load) load data per each work-item. The output
     describes the data being loaded at the subgroup level, so its size is
     consistent with the number of work-items in a subgroup. When the chunk size
-    is larger than 2, the output vector is a 2D vector, with dim-1 correspoding
-    to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
-    Specially, there is a transpose effect on the result (as compared to the TensorDesc)
-    due to the hardware implementation. Therefore, a transpose attribute is introduced
-    on purpose, making sure users are aware of this implicit transformation.
-
+    is larger than 2, the output vector is a 2D vector, with dim-0 correspoding
+    to work-items, and dim-1 corresponding to the chunk size loaded by each work-item.
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
@@ -634,8 +630,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   Example 2:
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
@@ -643,20 +638,18 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   ```
   Example 3 (SIMT mode):
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
             !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-            vector<16xi1> -> vector<8x1xf32>
+            vector<16xi1> -> vector<8xf32>
   ```
 
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        XeGPU_MaskType: $mask,
-                       OptionalAttr<UnitAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -714,19 +707,17 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
 
   Example 2:
   ```mlir
-    xegpu.store %0, %1, %2 {transpose,
-                                 l1_hint = #xegpu.cache_hint<uncached>,
-                                 l2_hint = #xegpu.cache_hint<write_back>,
-                                 l3_hint = #xegpu.cache_hint<write_through>}
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
           : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
   Example 3 (SIMT mode):
   ```mlir
-    xegpu.store %0, %1, %2 {transpose,
-                                 l1_hint = #xegpu.cache_hint<uncached>,
-                                 l2_hint = #xegpu.cache_hint<write_back>,
-                                 l3_hint = #xegpu.cache_hint<write_through>}
-          : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
             !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
   ```
 
@@ -736,7 +727,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc: $TensorDesc,
     XeGPU_MaskType: $mask,
-    OptionalAttr<UnitAttr>: $transpose,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 772cf73649646..09311e6017d0c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -35,6 +35,8 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
+constexpr unsigned packedSizeInBitsForGatherScatter =
+    32; // Minimum packing size per register for Gather and Scatter ops.
 } // namespace targetinfo
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0afc502c026f7..f0fb03d4f1139 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,13 +20,6 @@
 namespace mlir {
 namespace xegpu {
 
-static void transpose(llvm::ArrayRef<int64_t> trans,
-                      SmallVector<int64_t> &shape) {
-  SmallVector<int64_t> old = shape;
-  for (size_t i = 0; i < trans.size(); i++)
-    shape[i] = old[trans[i]];
-}
-
 template <typename T>
 static std::string makeString(T array, bool breakline = false) {
   std::string buf;
@@ -76,7 +69,7 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 
 static LogicalResult
 isValidGatherScatterParams(Type maskTy, VectorType valueTy,
-                           TensorDescType tdescTy, UnitAttr transposeAttr,
+                           TensorDescType tdescTy,
                            function_ref<InFlightDiagnostic()> emitError) {
 
   if (!tdescTy.isScattered())
@@ -102,17 +95,9 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
     if (tdescTy.getLayoutAttr())
       return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
-    if (transposeAttr)
-      return emitError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
-  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
-    if (!transposeAttr)
-      return emitError() << "rank-2 tensor has to be transposed.";
-    transpose({1, 0}, tdescShape);
-  }
-
   if (tdescShape != valueShape)
     return emitError() << "Value shape " << makeString(valueShape)
                        << " is neither a valid distribution for SIMT nor "
@@ -310,13 +295,9 @@ LogicalResult LoadNdOp::verify() {
 
   if (getTranspose()) {
     auto trans = getTranspose().value();
-
     // Make sure the transpose value is valid.
-    bool valid = llvm::all_of(
-        trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });
-
-    if (valid)
-      transpose(trans, tdescShape);
+    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
+      tdescShape = applyPermutation(tdescShape, trans);
     else
       mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
   }
@@ -536,7 +517,6 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    getTransposeAttr(),
                                     [&]() { return emitOpError(); });
 }
 
@@ -558,7 +538,6 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    getTransposeAttr(),
                                     [&]() { return emitOpError(); });
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index cc22d2bbd8c39..60ccd823775a5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -213,6 +213,35 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
                     LaneData({1, packingFactor}));
 }
 
+/// Helper to get the default layout for a vector type.
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+  // Expecting a 1D or 2D vector.
+  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
+         "Expected 1D or 2D TensorDesc.");
+  // Expecting int or float element type.
+  assert(tdescTy.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+  // If the rank is 1, then return default layout for 1D vector.
+  if (tdescTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(1);
+  // Packing factor is determined by the element type bitwidth.
+  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+
+  if (tdescTy.isScattered()) {
+    int packingFactor =
+        xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+                      LaneData({1, packingFactor}));
+  }
+
+  int packingFactor =
+      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
+          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
+          : 1;
+  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                    LaneData({1, packingFactor}));
+}
+
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
 /// is set according to the following criteria:
 /// * For A operand, the data must be packed in minimum
@@ -379,8 +408,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   // Here we assign the default layout to the tensor descriptor operand of
   // prefetch.
   auto tdescTy = prefetch.getTensorDescType();
-  auto prefetchLayout = getDefaultSIMTLayoutInfo(
-      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
   // Propagate the layout to the source tensor descriptor.
   propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
 }
@@ -516,24 +544,14 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
+  // The layout is strictly determined by the tensor descriptor type.
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
 
-  LayoutInfo tensorDescLayout = valueLayout;
-  if (load.getTranspose()) {
-    // LoadGatherOp has the transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  }
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+
   // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
@@ -567,21 +585,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo valueLayout =
-      getDefaultSIMTLayoutInfo(storeScatter.getValueType());
-  LayoutInfo storeScatterLayout = valueLayout;
-  if (storeScatter.getTranspose()) {
-    // StoreScatteOp allows transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    storeScatter.emitWarning("Transpose effect is not expected for "
-                             "StoreScatterOp at LayoutInfoPropagation stage.");
-    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  }
+  LayoutInfo layout =
+      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+
   // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
   // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 0457f8128b908..be39ee1f0b53f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -507,8 +507,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
         for (int64_t i = 0; i < numNewChunks; ++i)
           convertedMasks.push_back(mask);
       }
-      // This is to handle the transpose effect when chunkSize > 1.
-      std::swap((*targetShape)[0], (*targetShape)[1]);
       newValueTy = valueTy.cloneWith(*targetShape, elemTy);
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
@@ -519,8 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
       auto newOp = rewriter.create<xegpu::LoadGatherOp>(
-          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -598,9 +596,6 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1.
-      std::swap((*targetShape)[0], (*targetShape)[1]);
-
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
       convertedMasks =
@@ -616,9 +611,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
       Value v = convertedValues[i];
       Value t = convertedTdescs[i];
       Value m = op.getMask() ? convertedMasks[i] : nullptr;
-      rewriter.create<xegpu::StoreScatterOp>(
-          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+      rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m, op.getL1HintAttr(),
+                                             op.getL2HintAttr(),
+                                             op.getL3HintAttr());
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 054c4d12fdb28..5ceb548221758 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -199,8 +199,8 @@ gpu.func @simt_load_nd_7(%src: memref<24x32xf16>) {
 gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
   gpu.return
 }
 
@@ -235,8 +235,6 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-
-
 // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -248,7 +246,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
-
 // CHECK: func @simt_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
@@ -318,8 +315,8 @@ gpu.func @subgroup_load(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<4x2xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf32>
   gpu.return
 }
 
@@ -370,8 +367,8 @@ gpu.func @subgroup_load_3(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8x4xf16>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8x4xf16>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<4x8xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.te...
[truncated]

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir


@chencha3 chencha3 changed the title from "[MLIR][XeGPU] Remove the transpose attribute from LoadGatherOp and StoreScatterOp" to "[MLIR][XeGPU] Cleanup the definition and documents" Jun 23, 2025
@Garra1980

Why not mention removing the transpose attribute in the PR title?

@charithaintc
Contributor

Why not mention removing the transpose attribute in the PR title?

Second this. Please use a commit message reflecting the specific change here.

@chencha3 chencha3 changed the title from "[MLIR][XeGPU] Cleanup the definition and documents" to "[MLIR][XeGPU] Remove the transpose attribute from Gather/Scatter ops and Cleanup the documents" Jun 24, 2025
@chencha3
Contributor Author

Why not mention removing the transpose attribute in the PR title?

Second this. Please use a commit message reflecting the specific change here.

Updated.

on purpose, making sure users are aware of this implicit transformation.

is larger than 2, the output vector is a 2D vector, with dim-0 correspoding
to work-items, and dim-1 corresponding to the chunk size loaded by each work-item.
Contributor

Are we assuming here that there will always be a transpose operation after the load?

I wonder how a user can understand the semantics of this op. What if the user does not want the transpose and wants to use the op in isolation (which is perfectly legal)?

Contributor Author

There is no transpose. The semantics are that each row corresponds to a lane. In the SIMD lowering pipeline, the transpose is added when we lower the load_gather to the corresponding intrinsic. For SIMT lowering, there is no transpose at all.
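
For comparison, the SIMT-mode form after this change, copied from Example 3 in the updated op documentation above: each lane's result is simply the 1D chunk it owns, with no transpose involved.

```mlir
%2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                        l2_hint = #xegpu.cache_hint<uncached>,
                        l3_hint = #xegpu.cache_hint<uncached>}
      : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
          !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
        vector<16xi1> -> vector<8xf32>
```

At the subgroup (SIMD) level the same load produces a 2D result whose rows correspond to lanes, which is where the implicit transpose used to appear in the SIMD lowering.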

Contributor

@charithaintc charithaintc Jun 24, 2025

I thought about it again.

It seems like xegpu.load (with chunk size > 1) is now just a logical operation, meaning it does not have a matching HW instruction. Logically we can use it without an accompanying transpose operation; that is true.

In practice, it will always come with an accompanying transpose. It will mostly be useful for the A*BT case. In that case we always need an explicit vector.transpose after the xegpu.load. During lowering, the load + transpose are optimized away in both SIMD and SIMT paths. Essentially we say "we have a HW instruction that can do both of these together, so the transpose here is a nop". No need to do any shuffling for the transpose.

For the A*B case, I think doing multiple loads may be cheaper than doing a load gather and then an in-register transpose. Not sure about this case.

A*BT case

func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
  %c0 = arith.constant 0 : index
  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
  %cst_0 = arith.constant dense<true> : vector<16xi1>
  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
  %3 = xegpu.load %2, %cst_0  : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16> // layout  = [16, 1][1, 2]
  %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16> // this is a NOP // layout = [1, 16][2, 1]
  %4 = xegpu.dpas %1, %6 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  return
}

A*B case.

func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
  %c0 = arith.constant 0 : index
  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
  %cst_0 = arith.constant dense<true> : vector<16xi1>
  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
  %3 = xegpu.load %2, %cst_0  : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16> 
  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  return
}

The A*BT case is clear to me, but I am not sure what we do with the A*B case here. Maybe I am still missing something. @Jianhui-Li can you also clarify these examples? I know that A*B is not a real use case, but I am still confused about how layout propagation works here.

Contributor

Never mind. It is clear now after discussing with @Jianhui-Li. The A*B case will need a convert_layout because the load is not giving us the layout needed for DPAS B.

Contributor

For the A*B case, you use load w/ chunk_size for B, which assumes the [16, 1][1, 2] layout. The propagation needs to insert an xegpu.convert_layout to convert it to [1, 16][2, 1] before it is fed to DPAS.

Contributor

So from a lowering perspective we expect two cases.

  1. xegpu.load + vector.transpose: the regular case; just lower to the load-with-chunk-size intrinsic.
  2. xegpu.load + convert_layout: lower to the load-with-chunk-size intrinsic and do cross-lane shuffles.

Contributor

Yes.
One thing to note in the lowering: in your code example, the user specifies xegpu.load w/ chunk_size, which will be lowered to XeVM.load w/ vector size by default (each lane loads contiguous data).
If the user overrides the layout of xegpu.load w/ chunk_size, say forcing it to take the [1, 16][2, 1] layout, it will need to be lowered to multiple regular XeVM.loads, since the data loaded by each lane is then not contiguous.

Contributor

If the user overrides the layout of xegpu.load w/ chunk_size, say forcing it to take the [1, 16][2, 1] layout, it will need to be lowered to multiple regular XeVM.loads, since the data loaded by each lane is then not contiguous.

Is the user allowed to do this? I would also like it if we keep it relaxed. But I can see that in this PR we have hard-coded the scattered load layout to [16, 1][1, 2]. Check here:
https://github.com/llvm/llvm-project/pull/145389/files#diff-fcc9cdbf8bb4e5d37e661524b877082aee9b7badb0317f980c1881da564a926dR230-R237

Contributor

@Jianhui-Li Jianhui-Li Jun 25, 2025

I meant that the propagation pass will be improved to allow the user to set a layout that overrides the default decision.

l3_hint = #xegpu.cache_hint<write_through>}
: vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
!xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
Contributor

I think it is useful to keep the SIMT-mode example; you kept the SIMT example for load.

Contributor Author

My bad, I just brought it back.


// for gather and scatter ops, Low-precision types are packed in 32-bit units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
int packingFactor =
Contributor

What is the packingFactor value for block load?

Contributor Author

This is because the chunk size is limited to 32-bit data types on PVC, so lower-precision data will be loaded as 32-bit and then cast to the lower precision. This limits the availability of chunk size.
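
For concreteness, a small sketch of the layout this produces, based on the CHECK lines from propagate-layout.mlir quoted elsewhere in this review: with 16-bit f16 elements the packing factor is 32 / 16 = 2, so the scattered tensor_desc is assigned lane_data = [1, 2].

```mlir
// f16 (16 bits) packs two elements per 32-bit unit, hence lane_data = [1, 2];
// a 32-bit element type would give a packing factor of 1, i.e. lane_data = [1, 1].
!xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
                   #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
```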

Contributor

If this variable is only used for the "chunkSize" case, then consider changing the name to "packingFactorForChunk".

Contributor Author

Renamed it to chunkAlignmentFactor.

%cst_0 = arith.constant dense<true> : vector<16xi1>
%2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
%3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
%4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
Contributor

Can we keep this example with the xegpu.load feeding the DPAS B operand? Any specific reason for removing it?

Contributor Author

This is because the current layout assignment for the load is incorrect (it cannot pass the tensor_desc verifier); there is a conflict between the load and the dpas.

Contributor

OK, got it. But I think the user might want to force this layout as well. I feel like we should not check this in the verifier; [1, 16][2, 1] can also be supported in our lowering, we just need to emit multiple loads.


if (tdescTy.isScattered()) {
int packingFactor =
xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
Contributor

What about the bitwidth > 32 case? Maybe the packingFactor computation can come before this if logic and be reused in both cases.

Contributor Author

Fixed.

// CHECK: [[m:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: [[desc:%.+]] = xegpu.create_tdesc [[arg0]], [[idx]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
// CHECK: xegpu.load [[desc]], [[m]] : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x8xf16>
func.func @load_gather_with_chunksize(%arg0: memref<256xf16>) -> vector<16x8xf16> {
Contributor

@Jianhui-Li Jianhui-Li Jun 24, 2025

This function looks like a new case.
It is better to keep the existing test case above, and just modify it slightly to reflect the effect of "transpose attribute removal".
original test excerpt:
%2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
%3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
%4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>

Modified test with the transpose attribute removed from the load and the transpose explicitly represented:
// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
// CHECK-NEXT: %{{.*}} = vector.transpose %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :

%2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
%3' = xegpu.load %2, %cst_0 : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
%3 = vector.transpose %3', [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>

Then at the WI distribution stage, vector.transpose will be distributed to a "null" op, since each lane owns the same data fragment before and after the transpose.

Contributor Author

The current propagation pass cannot handle the modified test case: the layout <[1, 16], [2, 1]> assigned to the scattered tensor_desc is invalid. The correct one should be <[16, 1], [1, 2]>, and a convert_layout should be inserted between the load and the dpas.

Contributor

With the transpose, we don't need to insert a convert_layout for this case. The propagation pass will be improved to insert a convert_layout when needed.

Contributor Author

Yes, I updated it based on the suggestion.

Comment on lines +230 to +237
if (tdescTy.isScattered()) {
int packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
: 1;
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
LaneData({1, packingFactor}));
}
Contributor

Any reason for hard-coding the layout here? I feel like this should still take the consumer's layout; making the layouts consistent should be done at the conflict-resolution stage.

We can remove any bad examples (with layout conflicts) if that is the reason for doing this.

Contributor Author

The semantics of the scattered tensor_desc explicitly define the layout it can handle. When its user requests a different layout, it creates a conflict. I think this is how the conflict is detected and then handled; please correct me if I am wrong.

@chencha3 chencha3 requested a review from Jianhui-Li June 25, 2025 20:41
@chencha3 chencha3 merged commit 36fbc6a into llvm:main Jun 26, 2025
7 checks passed
@chencha3 chencha3 deleted the xegpu_remove_transpose_attr_for_gather_scatter branch June 26, 2025 00:44
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025