Skip to content

[Matrix] Propagate shape information through PHI insts #141681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b1b79aa
[Matrix] Propagate shape information through PHI instructions
jroelofs May 27, 2025
71b99d3
move formerly unsupported test to new home
jroelofs May 27, 2025
905c1e9
clang-format
jroelofs May 27, 2025
9ee44f0
add test for ConstantDataVector lowering
jroelofs May 27, 2025
169960d
move report_fatal_error outside of NDEBUG block
jroelofs May 27, 2025
b64a134
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs May 28, 2025
18951cd
fix bad merge
jroelofs May 31, 2025
f8aea05
use col major load intrinsics
jroelofs Jun 2, 2025
e56b225
add tests for phi's consuming phi's, and phi's with more than two inputs
jroelofs Jun 2, 2025
ffbc73f
handle phi's more like other ops. instcombine will clean up after us
jroelofs Jun 2, 2025
15fd60b
handle phi's with shape mismatch
jroelofs Jun 2, 2025
88bd8cb
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs Jun 2, 2025
655eb88
simplify getMatrix shim
jroelofs Jun 2, 2025
e262f76
test the other order of shape mismatch
jroelofs Jun 2, 2025
2c86c2f
clang-format
jroelofs Jun 9, 2025
86d3545
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 10, 2025
501414f
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 10, 2025
2e5b2d4
clang-format
jroelofs Jun 11, 2025
f048181
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 11, 2025
7511d17
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 12, 2025
2821467
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 12, 2025
4ba4e66
[Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch
jroelofs Jun 12, 2025
4cbc839
review feedback: parens for initializer
jroelofs Jun 16, 2025
6f8ec49
review feedback: rename to GetMatrix
jroelofs Jun 16, 2025
104c126
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 16, 2025
67ead37
drop code for splitting constants, add test for it
jroelofs Jun 16, 2025
c9b3992
split phi's in two phases
jroelofs Jun 16, 2025
da17d10
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 16, 2025
dd55682
clang-format
jroelofs Jun 16, 2025
08be3b4
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs Jun 17, 2025
5a991f7
rm constant.ll
jroelofs Jun 17, 2025
e9d0e62
test that shows reshape shuffles are inserted in the correct spot
jroelofs Jun 18, 2025
a760840
florian's suggestion is a little simpler: we already know it's a phi
jroelofs Jun 18, 2025
a62ef50
also use getInsertionPointAtDef() for non-inst phi operand reshape pl…
jroelofs Jun 18, 2025
47dd2e2
inline the lambda shim, to simplify
jroelofs Jun 18, 2025
0debc08
rm unnecessary include
jroelofs Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ static bool isUniformShape(Value *V) {
}

switch (I->getOpcode()) {
case Instruction::PHI:
case Instruction::FNeg:
return true;
default:
Expand Down Expand Up @@ -1136,7 +1137,27 @@ class LowerMatrixIntrinsics {

Changed |= !FusedInsts.empty();

// Fourth, lower remaining instructions with shape information.
// Fourth, pre-process all the PHINode's. The incoming values will be
// assigned later in VisitPHI.
for (Instruction *Inst : MatrixInsts) {
auto *PHI = dyn_cast<PHINode>(Inst);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were missing an early `continue` here for the case where Inst is in FusedInsts. Without it, this loop can touch an Inst that has already been removed, which is a heap-use-after-free; the tests fail under ASan without the check. Should be fixed in 0816bb3

if (!PHI)
continue;

const ShapeInfo &SI = ShapeMap.at(Inst);
auto *EltTy = cast<FixedVectorType>(PHI->getType())->getElementType();
MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);

IRBuilder<> Builder(Inst);
for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
PHI->getNumIncomingValues(),
PHI->getName()));
assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");
Inst2ColumnMatrix[PHI] = PhiM;
}

// Fifth, lower remaining instructions with shape information.
for (Instruction *Inst : MatrixInsts) {
if (FusedInsts.count(Inst))
continue;
Expand All @@ -1161,6 +1182,8 @@ class LowerMatrixIntrinsics {
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
else if (auto *PHI = dyn_cast<PHINode>(Inst))
Result = VisitPHI(PHI, SI, Builder);
else
continue;

Expand Down Expand Up @@ -1458,7 +1481,8 @@ class LowerMatrixIntrinsics {
IRBuilder<> &Builder) {
auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
(void)inserted;
assert(inserted.second && "multiple matrix lowering mapping");
assert((inserted.second || isa<PHINode>(Inst)) &&
"multiple matrix lowering mapping");

ToRemove.push_back(Inst);
Value *Flattened = nullptr;
Expand Down Expand Up @@ -2239,6 +2263,35 @@ class LowerMatrixIntrinsics {
Builder);
}

/// Lower the matrix-typed PHI \p Inst by filling in the incoming values of the
/// PHIs that make up its lowered matrix.
///
/// The matrix for \p Inst is fetched with getMatrix(); each of its vectors is
/// cast to a PHINode. For every (incoming value, incoming block) pair of
/// \p Inst, the matrix of the incoming value is fetched using shape \p SI and
/// its vectors are appended, vector by vector, as incoming values of those
/// PHIs for that block.
///
/// \returns the matrix fetched for \p Inst. On return \p Builder is positioned
/// at the first insertion point of the block containing \p Inst.
MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
  // First non-PHI position in the PHI's own block; used as the fallback
  // insertion point throughout.
  auto TopOfBlock = Inst->getParent()->getFirstInsertionPt();
  Builder.SetInsertPoint(TopOfBlock);
  MatrixTy Lowered = getMatrix(Inst, SI, Builder);
  const unsigned NumVecs = Lowered.getNumVectors();

  for (unsigned Idx = 0, NumIncoming = Inst->getNumIncomingValues();
       Idx != NumIncoming; ++Idx) {
    Value *Incoming = Inst->getIncomingValue(Idx);
    BasicBlock *Pred = Inst->getIncomingBlock(Idx);

    // getMatrix() may emit extra instructions to reshape the operand. By
    // default they go right after the PHIs of this block; when the operand is
    // an instruction with a valid insertion point after its definition, they
    // are emitted there instead.
    Builder.SetInsertPoint(TopOfBlock);
    if (auto *Def = dyn_cast<Instruction>(Incoming)) {
      auto AfterDef = Def->getInsertionPointAfterDef();
      if (AfterDef)
        Builder.SetInsertPoint(*AfterDef);
    }

    MatrixTy IncomingM = getMatrix(Incoming, SI, Builder);

    // Hook each vector of the operand into the matching lowered PHI.
    for (unsigned Vec = 0; Vec != NumVecs; ++Vec)
      cast<PHINode>(Lowered.getVector(Vec))
          ->addIncoming(IncomingM.getVector(Vec), Pred);
  }

  // The caller may emit further instructions after this returns; leave the
  // builder just past the block's leading PHIs so they land in a valid spot.
  Builder.SetInsertPoint(TopOfBlock);
  return Lowered;
}

/// Lower binary operators.
MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
IRBuilder<> &Builder) {
Expand Down
Loading
Loading