Merged

Commits (43)
777a403 add utils (chencha3, May 12, 2025)
af01c99 add skeleton (chencha3, May 12, 2025)
e8b43fb add filter (chencha3, May 13, 2025)
3f73fda clean up (chencha3, May 13, 2025)
ab448a3 add scf type conversion util (chencha3, May 13, 2025)
7b5e8f1 partial working (chencha3, May 13, 2025)
e2eb9e6 refactor pack and unpack (chencha3, May 15, 2025)
6ec3604 cleanup layout attr (chencha3, May 15, 2025)
bc69a8d check in elemwise support (chencha3, May 16, 2025)
4fc7540 check in unit test (chencha3, May 16, 2025)
132f15e fix format (chencha3, May 16, 2025)
aa4ba9c roll back pass name (chencha3, May 16, 2025)
061b6e0 add 1d and 2d elemwise test (chencha3, May 16, 2025)
387ac93 refactor (chencha3, May 16, 2025)
ebd78ae fix naming issue (chencha3, May 16, 2025)
bbf4796 fix format (chencha3, May 16, 2025)
3807eea fix overflow (chencha3, May 19, 2025)
c6695d9 add comments (chencha3, May 19, 2025)
50e33ff add dbg log (chencha3, May 20, 2025)
ae22f27 fix format (chencha3, May 20, 2025)
9776850 cleanup (chencha3, May 20, 2025)
6cffa44 refactor (chencha3, May 20, 2025)
d1584fc Merge branch 'main' into xegpu_blocking_pass (chencha3, May 21, 2025)
e023c1a add a corner unit test (chencha3, May 22, 2025)
ee912c2 Merge branch 'main' into xegpu_blocking_pass (chencha3, May 22, 2025)
3967810 fix comments (chencha3, May 23, 2025)
aebc327 remove unnecessary reference for lambda (chencha3, May 27, 2025)
90e7563 rename (chencha3, May 27, 2025)
f5bfc2f address comments (chencha3, May 27, 2025)
598fbce fix format (chencha3, May 27, 2025)
ff11a05 add comments (chencha3, May 27, 2025)
9f7f715 add comments (chencha3, May 27, 2025)
b164d7b address comments (chencha3, May 27, 2025)
554f4b4 refactor (chencha3, May 27, 2025)
d9f2e81 refactor getTileShape with template (chencha3, May 27, 2025)
18e49f6 add qualifiers (chencha3, May 27, 2025)
1f218f4 add qualifiers (chencha3, May 27, 2025)
f869b13 refactor setLayoutAttrs (chencha3, May 27, 2025)
de75855 cleanup unnecessary reference symbols (chencha3, May 27, 2025)
beacf8a update naming (chencha3, May 28, 2025)
c4c7abd refactor (chencha3, May 28, 2025)
70e84c4 refine comments (chencha3, Jun 2, 2025)
7dd05fa Update mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (chencha3, Jun 2, 2025)
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -295,11 +295,15 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
    }

    LayoutAttr dropSgLayoutAndData() {
      if (!getInstData() && !getLaneLayout())
        return nullptr;
      return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
                             getLaneLayout(), getLaneData(), getOrder());
    }

    LayoutAttr dropInstData() {
      if (!getSgLayout() && !getLaneLayout())
        return nullptr;
      return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                             getLaneLayout(), getLaneData(), getOrder());
    }
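The two early returns added above make the drop helpers collapse to a null attribute when nothing meaningful would remain. A minimal sketch of that behavior (the wrapper function below is invented for illustration; only LayoutAttr and its methods come from the patch):

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

// Returns true when stripping the sg-level fields leaves nothing behind,
// e.g. for a layout such as #xegpu.layout<sg_layout = [4, 4], sg_data = [16, 16]>
// that carries neither inst_data nor lane_layout.
static bool collapsesToNull(mlir::xegpu::LayoutAttr layout) {
  return !layout.dropSgLayoutAndData();
}
```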
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -38,4 +38,16 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
  ];
}

def XeGPUBlocking : Pass<"xegpu-blocking"> {
  let summary = "Block XeGPU ops into instruction-level ops";
  let description = [{
    The pass unrolls XeGPU ops that work on large shapes into equivalent ops
    that work on smaller shapes, as specified by the inst_data field of the
    layout attribute, so that each resulting op can be dispatched to a single
    hardware instruction.
  }];
  let dependentDialects = [
    "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
  ];
}

#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
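A minimal sketch of how the new pass might be scheduled from C++. The factory name xegpu::createXeGPUBlocking() is assumed to be the one TableGen generates for this definition; from the command line the pass is reachable as -xegpu-blocking:

```cpp
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Append the blocking pass to an existing pipeline. After it runs, XeGPU and
// vector ops whose layouts carry an inst_data tile shape are unrolled to that
// shape.
static void addXeGPUBlocking(mlir::PassManager &pm) {
  pm.addPass(mlir::xegpu::createXeGPUBlocking());
}
```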
51 changes: 51 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -13,6 +13,11 @@
namespace mlir {

class VectorType;
class OpOperand;
class OpResult;
class OpBuilder;
class ValueRange;

namespace xegpu {
class LayoutAttr;
class TensorDescType;
@@ -50,6 +55,52 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                               LayoutAttr layout);

/// Returns the attribute name used to attach a LayoutAttr to the given
/// OpOperand.
std::string getLayoutName(OpOperand &opr);

/// Returns the attribute name used to attach a LayoutAttr to the given
/// OpResult.
std::string getLayoutName(OpResult res);

/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
/// values, the LayoutAttr is extracted from the TensorDescType itself. For
/// other values, it is obtained from the attributes of the defining operation.
/// Returns nullptr if no LayoutAttr is found.
LayoutAttr getLayoutAttr(Value value);

/// Retrieves the LayoutAttr associated with a given OpOperand. It first
/// checks the operand_layout_{id} attribute of the owner operation; if not
/// found, it falls back to the operand's value and its defining op.
LayoutAttr getLayoutAttr(OpOperand &opr);

/// Sets the LayoutAttr for a given OpOperand by attaching it to the owner op.
void setLayoutAttr(OpOperand &opr, LayoutAttr layout);

/// Sets the LayoutAttr for the given OpResult by attaching it to the defining
/// op.
void setLayoutAttr(OpResult result, LayoutAttr layout);

/// Sets the LayoutAttr for each OpOperand and OpResult of the given operation.
/// If the operation contains regions, the attributes are also applied
/// recursively to the nested operations.
void setLayoutAttrs(Operation *op,
                    function_ref<LayoutAttr(Value)> getLayoutImpl);

/// Extracts a set of small vectors of the given shape from a value using
/// vector.extract_strided_slice.
SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder,
                                                    Location loc, Value value,
                                                    ArrayRef<int64_t> shape);

/// Creates a vector of the given shape from a set of values using
/// vector.insert_strided_slice.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                      ValueRange values,
                                      ArrayRef<int64_t> shape);

/// Performs type conversion for SCF structural ops, e.g., scf.for. Since
/// VectorType values cannot carry the layout attribute, they are first
/// converted to RankedTensorType and then converted back to VectorType in a
/// second round.
void doSCFStructuralTypeConversionWithTensorType(Operation *op);

} // namespace xegpu

} // namespace mlir
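As a rough illustration of how the two vector helpers above are meant to compose (the wrapper function and the 16x32 / 8x16 shapes are invented for this sketch; only the two xegpu:: utilities come from the header):

```cpp
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Split a vector<16x32xf16> value into 8x16 tiles via
// vector.extract_strided_slice, then stitch the tiles back together via
// vector.insert_strided_slice; the rebuilt value has the original shape.
static Value roundTripThroughTiles(OpBuilder &builder, Location loc,
                                   Value bigVec) {
  SmallVector<Value> tiles =
      xegpu::extractVectorsWithShapeFromValue(builder, loc, bigVec, {8, 16});
  return xegpu::createVectorWithShapeFromValues(builder, loc, tiles, {16, 32});
}
```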
1 change: 1 addition & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
  XeGPUBlocking.cpp
  XeGPUFoldAliasOps.cpp
  XeGPUSubgroupDistribute.cpp
  XeGPUUnroll.cpp
265 changes: 265 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -0,0 +1,265 @@
//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;

namespace {

// Resolves an UnrealizedConversionCastOp (left over from the SCF structural
// type conversion) by rebuilding the cast with vector pack/unpack sequences.
void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  // Trivial 1:1 cast: forward the input and drop the cast.
  if (inputs.size() == 1 && outputs.size() == 1) {
    castOp->replaceAllUsesWith(inputs);
    castOp->erase();
    return;
  }

  VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  if (inputTy && outputTy) {
    OpBuilder builder(castOp);
    // unpack: N small vectors -> 1 large vector.
    if (inputs.size() > 1 && outputs.size() == 1) {
      ArrayRef<int64_t> shape = outputTy.getShape();
      Value result = xegpu::createVectorWithShapeFromValues(
          builder, castOp.getLoc(), inputs, shape);
      castOp->replaceAllUsesWith(ValueRange(result));
      castOp->erase();
    }

    // pack: 1 large vector -> N small vectors.
    if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
      ArrayRef<int64_t> tileShape = outputTy.getShape();
      SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
          builder, castOp.getLoc(), inputs[0], tileShape);
      castOp->replaceAllUsesWith(results);
      castOp->erase();
    }
  }
}

/// Unroll XeGPU ops to their instruction-level representation.
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for a given value. If the value has a layout
  // attribute and it is an SG layout, return the inst_data as the tile shape
  // if inst_data is available; otherwise, return the original shape of the
  // value. If the value does not have an SG layout, return std::nullopt.
  std::optional<SmallVector<int64_t>>
  getTileShape(TypedValue<ShapedType> value) const;

  std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;

  std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;

  // Get the tile shape for a given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all
  // operands and results have tile shapes identical to their original types.
  // Otherwise, return true.
  bool needsUnroll(Operation *op) const;
};
} // namespace

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(TypedValue<ShapedType> value) const {
  assert(value && "value must be non-null");
  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);
Contributor:
For these two similar getTileShape functions, I think we can use a separate function for the common code and call it from both.

Contributor Author (chencha3):
Fixed, thanks.
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
    return llvm::to_vector(value.getType().getShape());
  }
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

    if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
      return llvm::to_vector(type.getShape());
  }
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(OpResult result) const {
  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

    if (auto type = dyn_cast<ShapedType>(result.getType()))
      return llvm::to_vector(type.getShape());
  }
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));

    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // semantic check for A and B
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // semantic check for C
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }

    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}
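
// Worked example for the xegpu.dpas case above (shapes are hypothetical): with
// an A tile of {8, 16} ({M, K}) and a B tile of {16, 16} ({K, N}), the checks
// require K to match (16 == 16) and, when a C operand is present, its tile to
// equal {M, N} = {8, 16}; the combined tile shape handed back to the unroller
// is then {M, K, N} = {8, 16, 16}.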

bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  if (isa<LoopLikeOpInterface>(op))
    return false;

  for (auto &opr : op->getOpOperands()) {
    std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
    auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
    if (!shapedType)
      continue;

    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
      return true;
  }

  for (auto result : op->getOpResults()) {
    std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
    auto shapedType = dyn_cast<ShapedType>(result.getType());
    if (!shapedType)
      continue;

    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
      return true;
  }
  return false;
}

void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *mod = getOperation();

  // Preserve the LayoutAttr of each operand and result in the owner op's
  // attribute dictionary. This ensures that the LayoutAttr remains accessible
  // even if the defining operation is replaced.
  xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });

  // Perform type conversion for SCF control flow ops.
  xegpu::doSCFStructuralTypeConversionWithTensorType(mod);

  xegpu::UnrollOptions options;
  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
    return needsUnroll(op) ? success() : failure();
  });

  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
      newTy = xegpu::TensorDescType::get(
          ctx, tileShape, elemTy, tdescTy.getEncoding(),
          tdescTy.getLayoutAttr().dropInstData());
    else
      newTy = type.clone(tileShape, elemTy);

    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });

  RewritePatternSet patterns(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);

  populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(mod, std::move(patterns));

  mod->walk([&](Operation *op) {
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);

    // The temporary layout attributes attached at the start of the pass are
    // no longer needed; remove them from operands...
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }

    // ...and from results, re-attaching the layout with inst_data dropped
    // where the result still needs one.
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setLayoutAttr(result, layout.dropInstData());
      }
    }
  });
}
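As a small standalone illustration of the ratio computation the unrolled-types callback relies on (the 16x32 shape and the 8x16 tile are hypothetical; computeShapeRatio and computeProduct are the utilities from mlir/Dialect/Utils/IndexingUtils.h):

```cpp
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/Support/raw_ostream.h"

#include <cstdint>
#include <optional>

int main() {
  using namespace mlir;
  int64_t shape[] = {16, 32}; // original vector / tensor_desc shape
  int64_t tile[] = {8, 16};   // inst_data tile shape
  // ratio is {2, 2}: two tiles along each dimension.
  std::optional<SmallVector<int64_t>> ratio = computeShapeRatio(shape, tile);
  // The callback would therefore return computeProduct(*ratio) == 4 copies of
  // the tile type, e.g. four vector<8x16xf16> for one vector<16x32xf16>.
  llvm::outs() << "number of tiles: " << computeProduct(*ratio) << "\n";
  return 0;
}
```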