@@ -288,6 +288,7 @@ static bool isUniformShape(Value *V) {
288
288
}
289
289
290
290
switch (I->getOpcode ()) {
291
+ case Instruction::PHI:
291
292
case Instruction::FNeg:
292
293
return true ;
293
294
default :
@@ -1136,7 +1137,27 @@ class LowerMatrixIntrinsics {
1136
1137
1137
1138
Changed |= !FusedInsts.empty ();
1138
1139
1139
- // Fourth, lower remaining instructions with shape information.
1140
+ // Fourth, pre-process all the PHINodes. The incoming values will be
1141
+ // assigned later in VisitPHI.
1142
+ for (Instruction *Inst : MatrixInsts) {
1143
+ auto *PHI = dyn_cast<PHINode>(Inst);
1144
+ if (!PHI)
1145
+ continue ;
1146
+
1147
+ const ShapeInfo &SI = ShapeMap.at (Inst);
1148
+ auto *EltTy = cast<FixedVectorType>(PHI->getType ())->getElementType ();
1149
+ MatrixTy PhiM (SI.NumRows , SI.NumColumns , EltTy);
1150
+
1151
+ IRBuilder<> Builder (Inst);
1152
+ for (unsigned VI = 0 , VE = PhiM.getNumVectors (); VI != VE; ++VI)
1153
+ PhiM.setVector (VI, Builder.CreatePHI (PhiM.getVectorTy (),
1154
+ PHI->getNumIncomingValues (),
1155
+ PHI->getName ()));
1156
+ assert (!Inst2ColumnMatrix.contains (PHI) && " map already contains phi?" );
1157
+ Inst2ColumnMatrix[PHI] = PhiM;
1158
+ }
1159
+
1160
+ // Fifth, lower remaining instructions with shape information.
1140
1161
for (Instruction *Inst : MatrixInsts) {
1141
1162
if (FusedInsts.count (Inst))
1142
1163
continue ;
@@ -1161,6 +1182,8 @@ class LowerMatrixIntrinsics {
1161
1182
Result = VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
1162
1183
else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1163
1184
Result = VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1185
+ else if (auto *PHI = dyn_cast<PHINode>(Inst))
1186
+ Result = VisitPHI (PHI, SI, Builder);
1164
1187
else
1165
1188
continue ;
1166
1189
@@ -1458,7 +1481,8 @@ class LowerMatrixIntrinsics {
1458
1481
IRBuilder<> &Builder) {
1459
1482
auto inserted = Inst2ColumnMatrix.insert (std::make_pair (Inst, Matrix));
1460
1483
(void )inserted;
1461
- assert (inserted.second && " multiple matrix lowering mapping" );
1484
+ assert ((inserted.second || isa<PHINode>(Inst)) &&
1485
+ " multiple matrix lowering mapping" );
1462
1486
1463
1487
ToRemove.push_back (Inst);
1464
1488
Value *Flattened = nullptr ;
@@ -2239,6 +2263,35 @@ class LowerMatrixIntrinsics {
2239
2263
Builder);
2240
2264
}
2241
2265
2266
/// Lower a PHI node that carries matrix shape information.
///
/// Fetches the matrix associated with \p Inst, then for every
/// (incoming value, incoming block) pair of \p Inst fetches the matrix of the
/// incoming value and adds each of its vectors as an incoming value of the
/// corresponding per-vector PHI.
///
/// \param Inst    The original PHI being lowered.
/// \param SI      Shape used for both the PHI and all of its incoming values.
/// \param Builder Builder used by getMatrix(); on return its insertion point
///                is the first insertion point of \p Inst's parent block.
/// \returns The matrix obtained for \p Inst, whose vectors are PHI nodes.
+ MatrixTy VisitPHI (PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2267
+ auto BlockIP = Inst->getParent ()->getFirstInsertionPt ();
2268
+ Builder.SetInsertPoint (BlockIP);
2269
// NOTE(review): presumably this returns the placeholder PHIs registered in
// Inst2ColumnMatrix by the PHI pre-processing loop; the cast<PHINode> in the
// inner loop below relies on every vector being a PHINode - confirm against
// getMatrix().
+ MatrixTy PhiM = getMatrix (Inst, SI, Builder);
2270
+
2271
+ for (auto [IncomingV, IncomingB] :
2272
+ llvm::zip_equal (Inst->incoming_values (), Inst->blocks ())) {
2273
+ // getMatrix() may insert some instructions to help with reshaping. The
2274
+ // default place for those is at the top of this block, after the PHIs.
2275
+ // When the incoming value is an instruction with a valid insertion point
// after its definition, insert right after that definition instead.
2276
+ Builder.SetInsertPoint (BlockIP);
2277
+ if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
2278
+ if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef ())
2279
+ Builder.SetInsertPoint (*MaybeIP);
2280
+
2281
+ MatrixTy OpM = getMatrix (IncomingV, SI, Builder);
2282
+
// Wire vector VI of the incoming matrix into the VI-th per-vector PHI for
// the edge coming from IncomingB.
2283
+ for (unsigned VI = 0 , VE = PhiM.getNumVectors (); VI != VE; ++VI) {
2284
+ PHINode *NewPHI = cast<PHINode>(PhiM.getVector (VI));
2285
+ NewPHI->addIncoming (OpM.getVector (VI), IncomingB);
2286
+ }
2287
+ }
2288
+
2289
+ // finalizeLowering() may also insert instructions in some cases. The safe
2290
+ // place for those is at the end of the initial block of PHIs.
2291
+ Builder.SetInsertPoint (BlockIP);
2292
+ return PhiM;
2293
+ }
2294
+
2242
2295
// / Lower binary operators.
2243
2296
MatrixTy VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI,
2244
2297
IRBuilder<> &Builder) {
0 commit comments