Skip to content

Commit 0fa373c

Browse files
authored
[Matrix] Propagate shape information through PHI insts (#141681)
... and split them as we lower them, avoiding several shuffles in the process.
1 parent b5aaf9d commit 0fa373c

File tree

3 files changed

+844
-263
lines changed

3 files changed

+844
-263
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ static bool isUniformShape(Value *V) {
288288
}
289289

290290
switch (I->getOpcode()) {
291+
case Instruction::PHI:
291292
case Instruction::FNeg:
292293
return true;
293294
default:
@@ -1136,7 +1137,27 @@ class LowerMatrixIntrinsics {
11361137

11371138
Changed |= !FusedInsts.empty();
11381139

1139-
// Fourth, lower remaining instructions with shape information.
1140+
// Fourth, pre-process all the PHINode's. The incoming values will be
1141+
// assigned later in VisitPHI.
1142+
for (Instruction *Inst : MatrixInsts) {
1143+
auto *PHI = dyn_cast<PHINode>(Inst);
1144+
if (!PHI)
1145+
continue;
1146+
1147+
const ShapeInfo &SI = ShapeMap.at(Inst);
1148+
auto *EltTy = cast<FixedVectorType>(PHI->getType())->getElementType();
1149+
MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);
1150+
1151+
IRBuilder<> Builder(Inst);
1152+
for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
1153+
PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
1154+
PHI->getNumIncomingValues(),
1155+
PHI->getName()));
1156+
assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");
1157+
Inst2ColumnMatrix[PHI] = PhiM;
1158+
}
1159+
1160+
// Fifth, lower remaining instructions with shape information.
11401161
for (Instruction *Inst : MatrixInsts) {
11411162
if (FusedInsts.count(Inst))
11421163
continue;
@@ -1161,6 +1182,8 @@ class LowerMatrixIntrinsics {
11611182
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
11621183
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
11631184
Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1185+
else if (auto *PHI = dyn_cast<PHINode>(Inst))
1186+
Result = VisitPHI(PHI, SI, Builder);
11641187
else
11651188
continue;
11661189

@@ -1458,7 +1481,8 @@ class LowerMatrixIntrinsics {
14581481
IRBuilder<> &Builder) {
14591482
auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
14601483
(void)inserted;
1461-
assert(inserted.second && "multiple matrix lowering mapping");
1484+
assert((inserted.second || isa<PHINode>(Inst)) &&
1485+
"multiple matrix lowering mapping");
14621486

14631487
ToRemove.push_back(Inst);
14641488
Value *Flattened = nullptr;
@@ -2239,6 +2263,35 @@ class LowerMatrixIntrinsics {
22392263
Builder);
22402264
}
22412265

2266+
MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2267+
auto BlockIP = Inst->getParent()->getFirstInsertionPt();
2268+
Builder.SetInsertPoint(BlockIP);
2269+
MatrixTy PhiM = getMatrix(Inst, SI, Builder);
2270+
2271+
for (auto [IncomingV, IncomingB] :
2272+
llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) {
2273+
// getMatrix() may insert some instructions to help with reshaping. The
2274+
// safest place for those is at the top of the block after the rest of the
2275+
// PHI's. Even better, if we can put it in the incoming block.
2276+
Builder.SetInsertPoint(BlockIP);
2277+
if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
2278+
if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef())
2279+
Builder.SetInsertPoint(*MaybeIP);
2280+
2281+
MatrixTy OpM = getMatrix(IncomingV, SI, Builder);
2282+
2283+
for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
2284+
PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
2285+
NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
2286+
}
2287+
}
2288+
2289+
// finalizeLowering() may also insert instructions in some cases. The safe
2290+
// place for those is at the end of the initial block of PHIs.
2291+
Builder.SetInsertPoint(BlockIP);
2292+
return PhiM;
2293+
}
2294+
22422295
/// Lower binary operators.
22432296
MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
22442297
IRBuilder<> &Builder) {

0 commit comments

Comments
 (0)