[MLIR][XeGPU] Add unroll patterns and blocking pass for XeGPU [2/N] #140163
@@ -0,0 +1,265 @@
//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;

namespace {

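// Resolve an UnrealizedConversionCastOp that was inserted to bridge between an
// original vector value and its tile-sized pieces: a 1:1 cast is simply
// forwarded, an N:1 cast is materialized by recombining the pieces into a
// single vector, and a 1:N cast by extracting tile-shaped sub-vectors.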
void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  if (inputs.size() == 1 && outputs.size() == 1) {
    castOp->replaceAllUsesWith(inputs);
    castOp->erase();
    return; // The op is erased; do not touch its operands/results below.
  }

  VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  if (inputTy && outputTy) {
    OpBuilder builder(castOp);
    // unpack
    if (inputs.size() > 1 && outputs.size() == 1) {
      ArrayRef<int64_t> shape = outputTy.getShape();
      Value result = xegpu::createVectorWithShapeFromValues(
          builder, castOp.getLoc(), inputs, shape);
      castOp->replaceAllUsesWith(ValueRange(result));
      castOp->erase();
      return; // Same here: the op no longer exists.
    }

    // pack
    if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
      ArrayRef<int64_t> tileShape = outputTy.getShape();
      SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
          builder, castOp.getLoc(), inputs[0], tileShape);
      castOp->replaceAllUsesWith(results);
      castOp->erase();
    }
  }
}

/// Unroll XeGPU ops to their instruction-level representation.
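/// Tile shapes are taken from the inst_data field of each value's LayoutAttr;
/// for example, a 32x32 value with inst_data = [8, 16] is decomposed into
/// eight 8x16 pieces (illustrative sizes, not a fixed configuration).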
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for a given value. If the value has an SG-level
  // LayoutAttr, return its inst_data as the tile shape when inst_data is
  // present; otherwise return the value's original shape. If the value has
  // no SG-level layout, return std::nullopt.
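  // For example, a vector<32x32xf16> whose layout carries inst_data = [8, 16]
  // yields the tile shape [8, 16]; the same value without inst_data yields its
  // full shape [32, 32]. (Illustrative types and sizes only.)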
  std::optional<SmallVector<int64_t>>
  getTileShape(TypedValue<ShapedType> value) const;

  std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;

  std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;

  // Get the tile shape for a given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all
  // operands and results have tile shapes identical to their original types.
  // Otherwise, return true.
  bool needsUnroll(Operation *op) const;
};
} // namespace

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(TypedValue<ShapedType> value) const {
  assert(value && "value must be non-null");
  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);

Review comment: for these two similar getTileShape functions, I think we can use a separate function to hold the common code and call it from both.
Reply: Fixed, thanks.

  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
    return llvm::to_vector(value.getType().getShape());
  }
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

    if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
      return llvm::to_vector(type.getShape());
  }
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(OpResult result) const {
  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

    if (auto type = dyn_cast<ShapedType>(result.getType()))
      return llvm::to_vector(type.getShape());
  }
  return std::nullopt;
}
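As the review comment above notes, these overloads share the same layout-driven logic. A minimal sketch of one way the common code could be factored out is shown below; the helper name getTileShapeFromLayout is hypothetical and not part of this patch, and it only reuses APIs already used in the functions above.

// Hypothetical helper (illustration only, not part of this patch): derive the
// tile shape from a LayoutAttr and the value's shaped type.
static std::optional<SmallVector<int64_t>>
getTileShapeFromLayout(xegpu::LayoutAttr layout, ShapedType type) {
  // Only subgroup-level layouts participate in blocking.
  if (!layout || !layout.isSgLayout())
    return std::nullopt;
  // Prefer the explicit instruction tile shape if the layout carries one.
  if (auto instData = layout.getInstData())
    return llvm::to_vector_of<int64_t>(instData.asArrayRef());
  // Otherwise fall back to the value's full shape, when it is shaped at all.
  if (type)
    return llvm::to_vector(type.getShape());
  return std::nullopt;
}

Each overload would then reduce to fetching the layout and the shaped type and delegating to such a helper.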

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(op->getOpOperand(1));

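  // For DpasOp the blocked shape is returned as {m, k, n}. For example (sizes
  // illustrative only), A tiles of 8x16 and B tiles of 16x16 pass the k-dim
  // check (16 == 16) and yield {8, 16, 16}; any C tile must then be 8x16.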
  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));

    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // semantic check for A and B
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // semantic check for C
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }

    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}

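// Returns true when at least one operand or result has a tile shape that
// differs from its full shape, i.e. the op has to be decomposed into
// tile-sized pieces. Loop-like ops are never unrolled here.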
bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  if (isa<LoopLikeOpInterface>(op))
    return false;

  for (auto &opr : op->getOpOperands()) {
    std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
    auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
    if (!shapedType)
      continue;

    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
      return true;
  }

  for (auto result : op->getOpResults()) {
    std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
    auto shapedType = dyn_cast<ShapedType>(result.getType());
    if (!shapedType)
      continue;

    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
      return true;
  }
  return false;
}

void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *mod = getOperation();

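  // The pass proceeds in four steps: (1) record each value's LayoutAttr on its
  // user ops so the layout survives op replacement, (2) run the SCF structural
  // type conversion so loop-carried values can change type, (3) greedily apply
  // the XeGPU and vector unroll patterns with inst_data as the native tile
  // shape, and (4) clean up the temporary casts and layout attributes.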
  // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
  // This ensures that the LayoutAttr remains accessible even if the defining
  // operation is replaced.
  xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });

  // Perform type conversion for SCF control flow ops.
  xegpu::doSCFStructuralTypeConversionWithTensorType(mod);

  xegpu::UnrollOptions options;
  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
    return needsUnroll(op) ? success() : failure();
  });

  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
      newTy = xegpu::TensorDescType::get(
          ctx, tileShape, elemTy, tdescTy.getEncoding(),
          tdescTy.getLayoutAttr().dropInstData());
    else
      newTy = type.clone(tileShape, elemTy);

    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
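    // e.g., a 32x32 type blocked into 8x16 tiles gives ratio [4, 2], so eight
    // copies of the tile type are produced (illustrative sizes).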
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });

  RewritePatternSet patterns(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);

  populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(mod, std::move(patterns));

  mod->walk([&](Operation *op) {
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);

    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }

    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setLayoutAttr(result, layout.dropInstData());
      }
    }
  });
}
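For context, a minimal sketch of driving the pass from C++ is shown below. It assumes the usual tablegen-generated factory xegpu::createXeGPUBlockingPass() declared in mlir/Dialect/XeGPU/Transforms/Passes.h; the factory name and the runBlocking wrapper are assumptions for illustration, not part of this patch.

// Sketch only: run the blocking pass over a module with a standalone
// PassManager. createXeGPUBlockingPass() is the assumed tablegen-generated
// factory for the pass defined above.
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static mlir::LogicalResult runBlocking(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::xegpu::createXeGPUBlockingPass()); // assumed factory name
  return pm.run(module);
}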