Skip to content

[Matrix] Propagate shape information through PHI insts #141681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jun 18, 2025

Conversation

jroelofs
Copy link
Contributor

@jroelofs jroelofs commented May 27, 2025

... and split them as we lower them, avoiding several shuffles in the process.

... and split them as we lower themm, avoiding several shuffles in the process.
@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

... and split them as we lower themm, avoiding several shuffles in the process.


Patch is 42.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141681.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+92-1)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll (+216)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll (+58-65)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 56d4be513ea6f..c06d08688ab1c 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
@@ -230,6 +231,7 @@ static bool isUniformShape(Value *V) {
     return true;
 
   switch (I->getOpcode()) {
+  case Instruction::PHI:
   case Instruction::FAdd:
   case Instruction::FSub:
   case Instruction::FMul: // Scalar multiply.
@@ -360,6 +362,33 @@ class LowerMatrixIntrinsics {
         addVector(PoisonValue::get(FixedVectorType::get(
             EltTy, isColumnMajor() ? NumRows : NumColumns)));
     }
+    MatrixTy(ConstantData *Constant, const ShapeInfo &SI)
+        : IsColumnMajor(SI.IsColumnMajor) {
+      Type *EltTy = cast<VectorType>(Constant->getType())->getElementType();
+      Type *RowTy = VectorType::get(EltTy, ElementCount::getFixed(SI.NumRows));
+
+      for (unsigned J = 0, D = SI.getNumVectors(); J < D; ++J) {
+        if (auto *CDV = dyn_cast<ConstantDataVector>(Constant)) {
+          unsigned Width = SI.getStride();
+          size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8;
+          StringRef Data = CDV->getRawDataValues().substr(
+              J * Width * EltSize, Width * EltSize);
+          addVector(ConstantDataVector::getRaw(Data, Width,
+                                               CDV->getElementType()));
+        } else if (isa<PoisonValue>(Constant))
+          addVector(PoisonValue::get(RowTy));
+        else if (isa<UndefValue>(Constant))
+          addVector(UndefValue::get(RowTy));
+        else if (isa<ConstantAggregateZero>(Constant))
+          addVector(ConstantAggregateZero::get(RowTy));
+        else {
+#ifndef NDEBUG
+          Constant->dump();
+          report_fatal_error("unhandled ConstantData type");
+#endif
+        }
+      }
+    }
 
     Value *getVector(unsigned i) const { return Vectors[i]; }
     Value *getColumn(unsigned i) const {
@@ -564,6 +593,27 @@ class LowerMatrixIntrinsics {
       MatrixVal = M.embedInVector(Builder);
     }
 
+    // If it's a PHI, split it now. We'll take care of fixing up the operands
+    // later once we're in VisitPHI.
+    if (auto *PHI = dyn_cast<PHINode>(MatrixVal)) {
+      auto *EltTy = cast<VectorType>(PHI->getType())->getElementType();
+      MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy};
+
+      IRBuilder<>::InsertPointGuard IPG(Builder);
+      Builder.SetInsertPoint(PHI);
+      for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
+        PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
+                                             PHI->getNumIncomingValues(),
+                                             PHI->getName()));
+
+      Inst2ColumnMatrix[PHI] = PhiM;
+      return PhiM;
+    }
+
+    // If it's a constant, materialize the split version of it with this shape.
+    if (auto *IncomingConst = dyn_cast<ConstantData>(MatrixVal))
+      return MatrixTy(IncomingConst, SI);
+
     // Otherwise split MatrixVal.
     SmallVector<Value *, 16> SplitVecs;
     for (unsigned MaskStart = 0;
@@ -1077,6 +1127,11 @@ class LowerMatrixIntrinsics {
         Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
     }
 
+    // Fifth, lower all the PHI's with shape information.
+    for (Instruction *Inst : MatrixInsts)
+      if (auto *PHI = dyn_cast<PHINode>(Inst))
+        Changed |= VisitPHI(PHI);
+
     if (ORE) {
       RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
       RemarkGen.emitRemarks();
@@ -1349,7 +1404,8 @@ class LowerMatrixIntrinsics {
                         IRBuilder<> &Builder) {
     auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
     (void)inserted;
-    assert(inserted.second && "multiple matrix lowering mapping");
+    assert((inserted.second || isa<PHINode>(Inst)) &&
+           "multiple matrix lowering mapping");
 
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
@@ -2133,6 +2189,41 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  bool VisitPHI(PHINode *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    IRBuilder<> Builder(Inst);
+
+    MatrixTy PhiM = getMatrix(Inst, I->second, Builder);
+
+    for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues();
+         IncomingI != IncomingE; ++IncomingI) {
+      Value *IncomingV = Inst->getIncomingValue(IncomingI);
+      BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI);
+
+      // getMatrix() may insert some instructions. The safe place to insert them
+      // is at the end of the parent block, where the register allocator would
+      // have inserted the copies that materialize the PHI.
+      if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
+        Builder.SetInsertPoint(IncomingInst->getParent()->getTerminator());
+
+      MatrixTy OpM = getMatrix(IncomingV, I->second, Builder);
+
+      for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
+        PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
+        NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
+      }
+    }
+
+    // finalizeLowering() may also insert instructions in some cases. The safe
+    // place for those is at the end of the initial block of PHIs.
+    Builder.SetInsertPoint(*Inst->getInsertionPointAfterDef());
+    finalizeLowering(Inst, PhiM, Builder);
+    return true;
+  }
+
   /// Lower binary operators, if shape information is available.
   bool VisitBinaryOperator(BinaryOperator *Inst) {
     auto I = ShapeMap.find(Inst);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll
new file mode 100644
index 0000000000000..d49b4d1112062
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll
@@ -0,0 +1,216 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -matrix-allow-contract=false -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @matrix_phi(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN1]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI9:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI10:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI11:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP12]], align 8
+; CHECK-NEXT:    [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP13]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %mat = load <9 x double>, ptr %in1
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [%mat, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
+
+define void @matrix_phi_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi_zeroinitializer(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI4:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI5:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI6:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [zeroinitializer, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
+
+define void @matrix_phi_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi_undef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI4:%.*]] = phi <3 x double> [ undef, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI5:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI6:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [undef, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
+
+define void @matrix_phi_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) {
+; CHECK-LABEL: @matrix_phi_poison(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[PHI4:%.*]] = phi <3 x double> [ poison, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI5:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[PHI6:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[DEC]] = sub i32 [[CTR]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
+; CHECK-NEXT:    store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %phi = phi <9 x double> [poison, %entry], [%sum, %loop]
+  %ctr = phi i32 [%count, %entry], [%dec, %loop]
+
+  %in2v = load <9 x double>, ptr %in2
+
+  ; Give in2 the shape: 3 x 3
+  %in2t  = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
+  %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)
+
+  %sum = fadd <9 x double> %phi, %in2tt
+
+  %dec = sub i32 %ctr, 1
+  %cmp = icmp eq i32 %dec, 0
+  br i1 %cmp, label %exit, label %loop
+
+exit:
+  store <9 x double> %sum, ptr %out
+  ret void
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll
index 2af2c979f2065..6ed8e46d62892 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll
@@ -28,9 +28,6 @@ define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B,
 ; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <3 x double> [[TMP13]], double [[TMP14]], i64 1
 ; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 2
 ; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <3 x double> [[TMP15]], double [[TMP16]], i64 2
-; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <3 x double> [[TMP5]], <3 x double> [[TMP11]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <3 x double> [[TMP17]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <6 x double> [[TMP18]], <6 x double> [[TMP19]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
 ; CHECK-NEXT:    br label [[IF_END:%.*]]
 ; CHECK:       if.else:
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <9 x double> [[B:%.*]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
@@ -54,183 +51,179 @@ define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B,
 ; CHECK-NEXT:    [[TMP36:%.*]] = insertelement <3 x double> [[TMP34]], double [[TMP35]], i64 1
 ; CHECK-NEXT:    [[TMP37:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 2
 ; CHECK-NEXT:    [[TMP38:%.*]] = insertelement <3 x double> [[TMP36]], double [[TMP37]], i64 2
-; CHECK-NEXT:    [[TMP39:%.*]] = shufflevector <3 x double> [[TMP26]], <3 x double> [[TMP32]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP40:%.*]] = shufflevector <3 x double> [[TMP38]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP41:%.*]] = shufflevector <6 x double> [[TMP39]], <6 x double> [[TMP40]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
 ; CHECK-NEXT:    br label [[IF_END]]
 ; CHECK:       if.end:
-; CHECK-NEXT:    [[MERGE:%.*]] = phi <9 x double> [ [[TMP20]], [[IF_THEN]] ], [ [[TMP41]], [[IF_ELSE]] ]
-; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <9 x double> [[C:%.*]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <9 x double> [[C]], <9 x double> poison, <3 x i32> <i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <9 x double> [[C]], <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[MERGE9:%.*]] = phi <3 x double> [ [[TMP5]], [[IF_THEN]] ], [ [[TMP26]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[MERGE10:%.*]] = phi <3 x double> [ [[TMP11]], [[IF_THEN]] ], [ [[TMP32]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[MERGE11:%.*]] = phi <3 x double> [ [[TMP17]], [[IF_THEN]] ], [ [[TMP38]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE:%.*]], <9 x double> poison, <3 x i32> <i32 0, i32 1, i32 2>
 ; CHECK-NEXT:    [[SPLIT10:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> <i32 3, i32 4, i32 5>
 ; CHECK-NEXT:    [[SPLIT11:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 0
+; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[T...
[truncated]

Copy link

github-actions bot commented May 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link

github-actions bot commented May 27, 2025

⚠️ undef deprecator found issues in your code. ⚠️

You can test this locally with the following command:
git diff -U0 --pickaxe-regex -S '([^a-zA-Z0-9#_-]undef[^a-zA-Z0-9_-]|UndefValue::get)' 'HEAD~1' HEAD llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll

The following files introduce new uses of undef:

  • llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll

Undef is now deprecated and should only be used in the rare cases where no replacement is possible. For example, a load of uninitialized memory yields undef. You should use poison values for placeholders instead.

In tests, avoid using undef and having tests that trigger undefined behavior. If you need an operand with some unimportant value, you can add a new argument to the function and use that instead.

For example, this is considered a bad practice:

define void @fn() {
  ...
  br i1 undef, ...
}

Please use the following instead:

define void @fn(i1 %cond) {
  ...
  br i1 %cond, ...
}

Please refer to the Undefined Behavior Manual for more information.

@jroelofs jroelofs changed the title [Matrix] Propagate shape information through PHI instructions [Matrix] Propagate shape information through PHI insts May 28, 2025
@jroelofs jroelofs requested a review from fhahn June 9, 2025 22:30
@jroelofs
Copy link
Contributor Author

jroelofs commented Jun 9, 2025

⚠️ undef deprecator found issues in your code. ⚠️

I am going to ignore this one. It should be appropriate for a lowering pass to handle undef, and for tests to check that behavior. When undef is actually deprecated, these should be easy to remove.

@jroelofs jroelofs requested a review from fpetrogalli June 11, 2025 15:12
@jroelofs
Copy link
Contributor Author

ping

@jroelofs jroelofs requested a review from fhahn June 17, 2025 18:37
Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@jroelofs jroelofs merged commit 0fa373c into llvm:main Jun 18, 2025
4 of 7 checks passed
fhahn added a commit that referenced this pull request Jun 19, 2025
We need to skip instructions in FusedInsts, as they may have been
deleted. Fixes a heap-use-after-free after #141681.
// Fourth, pre-process all the PHINode's. The incoming values will be
// assigned later in VisitPHI.
for (Instruction *Inst : MatrixInsts) {
auto *PHI = dyn_cast<PHINode>(Inst);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were missing an early continue here if Inst is in FusedInsts, which prevents heap-use-after-free if Inst has been removed before. W/o that, the tests fail with ASan. Should be fixed in 0816bb3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants