From b1b79aad6639f82f4d4f70eed23f6c252dcdd2c6 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 27 May 2025 12:59:31 -0700 Subject: [PATCH 01/26] [Matrix] Propagate shape information through PHI instructions ... and split them as we lower themm, avoiding several shuffles in the process. --- .../Scalar/LowerMatrixIntrinsics.cpp | 93 +++++++- .../Transforms/LowerMatrixIntrinsics/phi.ll | 216 ++++++++++++++++++ .../propagate-backwards-unsupported.ll | 123 +++++----- 3 files changed, 366 insertions(+), 66 deletions(-) create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 56d4be513ea6f..c06d08688ab1c 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -30,6 +30,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" @@ -230,6 +231,7 @@ static bool isUniformShape(Value *V) { return true; switch (I->getOpcode()) { + case Instruction::PHI: case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: // Scalar multiply. @@ -360,6 +362,33 @@ class LowerMatrixIntrinsics { addVector(PoisonValue::get(FixedVectorType::get( EltTy, isColumnMajor() ? NumRows : NumColumns))); } + MatrixTy(ConstantData *Constant, const ShapeInfo &SI) + : IsColumnMajor(SI.IsColumnMajor) { + Type *EltTy = cast(Constant->getType())->getElementType(); + Type *RowTy = VectorType::get(EltTy, ElementCount::getFixed(SI.NumRows)); + + for (unsigned J = 0, D = SI.getNumVectors(); J < D; ++J) { + if (auto *CDV = dyn_cast(Constant)) { + unsigned Width = SI.getStride(); + size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8; + StringRef Data = CDV->getRawDataValues().substr( + J * Width * EltSize, Width * EltSize); + addVector(ConstantDataVector::getRaw(Data, Width, + CDV->getElementType())); + } else if (isa(Constant)) + addVector(PoisonValue::get(RowTy)); + else if (isa(Constant)) + addVector(UndefValue::get(RowTy)); + else if (isa(Constant)) + addVector(ConstantAggregateZero::get(RowTy)); + else { +#ifndef NDEBUG + Constant->dump(); + report_fatal_error("unhandled ConstantData type"); +#endif + } + } + } Value *getVector(unsigned i) const { return Vectors[i]; } Value *getColumn(unsigned i) const { @@ -564,6 +593,27 @@ class LowerMatrixIntrinsics { MatrixVal = M.embedInVector(Builder); } + // If it's a PHI, split it now. We'll take care of fixing up the operands + // later once we're in VisitPHI. + if (auto *PHI = dyn_cast(MatrixVal)) { + auto *EltTy = cast(PHI->getType())->getElementType(); + MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy}; + + IRBuilder<>::InsertPointGuard IPG(Builder); + Builder.SetInsertPoint(PHI); + for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) + PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), + PHI->getNumIncomingValues(), + PHI->getName())); + + Inst2ColumnMatrix[PHI] = PhiM; + return PhiM; + } + + // If it's a constant, materialize the split version of it with this shape. + if (auto *IncomingConst = dyn_cast(MatrixVal)) + return MatrixTy(IncomingConst, SI); + // Otherwise split MatrixVal. SmallVector SplitVecs; for (unsigned MaskStart = 0; @@ -1077,6 +1127,11 @@ class LowerMatrixIntrinsics { Changed |= VisitStore(cast(Inst), Op1, Op2, Builder); } + // Fifth, lower all the PHI's with shape information. + for (Instruction *Inst : MatrixInsts) + if (auto *PHI = dyn_cast(Inst)) + Changed |= VisitPHI(PHI); + if (ORE) { RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); RemarkGen.emitRemarks(); @@ -1349,7 +1404,8 @@ class LowerMatrixIntrinsics { IRBuilder<> &Builder) { auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); (void)inserted; - assert(inserted.second && "multiple matrix lowering mapping"); + assert((inserted.second || isa(Inst)) && + "multiple matrix lowering mapping"); ToRemove.push_back(Inst); Value *Flattened = nullptr; @@ -2133,6 +2189,41 @@ class LowerMatrixIntrinsics { return true; } + bool VisitPHI(PHINode *Inst) { + auto I = ShapeMap.find(Inst); + if (I == ShapeMap.end()) + return false; + + IRBuilder<> Builder(Inst); + + MatrixTy PhiM = getMatrix(Inst, I->second, Builder); + + for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues(); + IncomingI != IncomingE; ++IncomingI) { + Value *IncomingV = Inst->getIncomingValue(IncomingI); + BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI); + + // getMatrix() may insert some instructions. The safe place to insert them + // is at the end of the parent block, where the register allocator would + // have inserted the copies that materialize the PHI. + if (auto *IncomingInst = dyn_cast(IncomingV)) + Builder.SetInsertPoint(IncomingInst->getParent()->getTerminator()); + + MatrixTy OpM = getMatrix(IncomingV, I->second, Builder); + + for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { + PHINode *NewPHI = cast(PhiM.getVector(VI)); + NewPHI->addIncoming(OpM.getVector(VI), IncomingB); + } + } + + // finalizeLowering() may also insert instructions in some cases. The safe + // place for those is at the end of the initial block of PHIs. + Builder.SetInsertPoint(*Inst->getInsertionPointAfterDef()); + finalizeLowering(Inst, PhiM, Builder); + return true; + } + /// Lower binary operators, if shape information is available. bool VisitBinaryOperator(BinaryOperator *Inst) { auto I = ShapeMap.find(Inst); diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll new file mode 100644 index 0000000000000..d49b4d1112062 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -0,0 +1,216 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -matrix-allow-contract=false -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define void @matrix_phi(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN1]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI9:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI10:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI11:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3 +; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6 +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 16 +; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]] +; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]] +; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP12]], align 8 +; CHECK-NEXT: [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP13]], align 16 +; CHECK-NEXT: ret void +; +entry: + %mat = load <9 x double>, ptr %in1 + br label %loop + +loop: + %phi = phi <9 x double> [%mat, %entry], [%sum, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %in2v = load <9 x double>, ptr %in2 + + ; Give in2 the shape: 3 x 3 + %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) + %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + + %sum = fadd <9 x double> %phi, %in2tt + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <9 x double> %sum, ptr %out + ret void +} + +define void @matrix_phi_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_zeroinitializer( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %phi = phi <9 x double> [zeroinitializer, %entry], [%sum, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %in2v = load <9 x double>, ptr %in2 + + ; Give in2 the shape: 3 x 3 + %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) + %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + + %sum = fadd <9 x double> %phi, %in2tt + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <9 x double> %sum, ptr %out + ret void +} + +define void @matrix_phi_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_undef( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ undef, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %phi = phi <9 x double> [undef, %entry], [%sum, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %in2v = load <9 x double>, ptr %in2 + + ; Give in2 the shape: 3 x 3 + %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) + %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + + %sum = fadd <9 x double> %phi, %in2tt + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <9 x double> %sum, ptr %out + ret void +} + +define void @matrix_phi_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_poison( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ poison, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %phi = phi <9 x double> [poison, %entry], [%sum, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %in2v = load <9 x double>, ptr %in2 + + ; Give in2 the shape: 3 x 3 + %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) + %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + + %sum = fadd <9 x double> %phi, %in2tt + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <9 x double> %sum, ptr %out + ret void +} diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll index 2af2c979f2065..6ed8e46d62892 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll @@ -28,9 +28,6 @@ define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B, ; CHECK-NEXT: [[TMP15:%.*]] = insertelement <3 x double> [[TMP13]], double [[TMP14]], i64 1 ; CHECK-NEXT: [[TMP16:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 2 ; CHECK-NEXT: [[TMP17:%.*]] = insertelement <3 x double> [[TMP15]], double [[TMP16]], i64 2 -; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <3 x double> [[TMP5]], <3 x double> [[TMP11]], <6 x i32> -; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <3 x double> [[TMP17]], <3 x double> poison, <6 x i32> -; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <6 x double> [[TMP18]], <6 x double> [[TMP19]], <9 x i32> ; CHECK-NEXT: br label [[IF_END:%.*]] ; CHECK: if.else: ; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <9 x double> [[B:%.*]], <9 x double> poison, <3 x i32> @@ -54,183 +51,179 @@ define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B, ; CHECK-NEXT: [[TMP36:%.*]] = insertelement <3 x double> [[TMP34]], double [[TMP35]], i64 1 ; CHECK-NEXT: [[TMP37:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 2 ; CHECK-NEXT: [[TMP38:%.*]] = insertelement <3 x double> [[TMP36]], double [[TMP37]], i64 2 -; CHECK-NEXT: [[TMP39:%.*]] = shufflevector <3 x double> [[TMP26]], <3 x double> [[TMP32]], <6 x i32> -; CHECK-NEXT: [[TMP40:%.*]] = shufflevector <3 x double> [[TMP38]], <3 x double> poison, <6 x i32> -; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <6 x double> [[TMP39]], <6 x double> [[TMP40]], <9 x i32> ; CHECK-NEXT: br label [[IF_END]] ; CHECK: if.end: -; CHECK-NEXT: [[MERGE:%.*]] = phi <9 x double> [ [[TMP20]], [[IF_THEN]] ], [ [[TMP41]], [[IF_ELSE]] ] -; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <9 x double> [[C:%.*]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <9 x double> [[C]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <9 x double> [[C]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[MERGE9:%.*]] = phi <3 x double> [ [[TMP5]], [[IF_THEN]] ], [ [[TMP26]], [[IF_ELSE]] ] +; CHECK-NEXT: [[MERGE10:%.*]] = phi <3 x double> [ [[TMP11]], [[IF_THEN]] ], [ [[TMP32]], [[IF_ELSE]] ] +; CHECK-NEXT: [[MERGE11:%.*]] = phi <3 x double> [ [[TMP17]], [[IF_THEN]] ], [ [[TMP38]], [[IF_ELSE]] ] +; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE:%.*]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP42:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 0 +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP42:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> poison, double [[TMP42]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP43:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] -; CHECK-NEXT: [[BLOCK12:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP44:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 1 +; CHECK-NEXT: [[BLOCK12:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP44:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT13:%.*]] = insertelement <1 x double> poison, double [[TMP44]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT14:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT13]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP45:%.*]] = fmul <1 x double> [[BLOCK12]], [[SPLAT_SPLAT14]] ; CHECK-NEXT: [[TMP46:%.*]] = fadd <1 x double> [[TMP43]], [[TMP45]] -; CHECK-NEXT: [[BLOCK15:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP47:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 2 +; CHECK-NEXT: [[BLOCK15:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP47:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT16:%.*]] = insertelement <1 x double> poison, double [[TMP47]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT17:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT16]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP48:%.*]] = fmul <1 x double> [[BLOCK15]], [[SPLAT_SPLAT17]] ; CHECK-NEXT: [[TMP49:%.*]] = fadd <1 x double> [[TMP46]], [[TMP48]] ; CHECK-NEXT: [[TMP50:%.*]] = shufflevector <1 x double> [[TMP49]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP51:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP50]], <3 x i32> -; CHECK-NEXT: [[BLOCK18:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP52:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 0 +; CHECK-NEXT: [[BLOCK18:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP52:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT19:%.*]] = insertelement <1 x double> poison, double [[TMP52]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT20:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT19]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP53:%.*]] = fmul <1 x double> [[BLOCK18]], [[SPLAT_SPLAT20]] -; CHECK-NEXT: [[BLOCK21:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP54:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 1 +; CHECK-NEXT: [[BLOCK21:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP54:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT22:%.*]] = insertelement <1 x double> poison, double [[TMP54]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT23:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT22]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP55:%.*]] = fmul <1 x double> [[BLOCK21]], [[SPLAT_SPLAT23]] ; CHECK-NEXT: [[TMP56:%.*]] = fadd <1 x double> [[TMP53]], [[TMP55]] -; CHECK-NEXT: [[BLOCK24:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP57:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 2 +; CHECK-NEXT: [[BLOCK24:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP57:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT25:%.*]] = insertelement <1 x double> poison, double [[TMP57]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT26:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT25]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP58:%.*]] = fmul <1 x double> [[BLOCK24]], [[SPLAT_SPLAT26]] ; CHECK-NEXT: [[TMP59:%.*]] = fadd <1 x double> [[TMP56]], [[TMP58]] ; CHECK-NEXT: [[TMP60:%.*]] = shufflevector <1 x double> [[TMP59]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <3 x double> [[TMP51]], <3 x double> [[TMP60]], <3 x i32> -; CHECK-NEXT: [[BLOCK27:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP62:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 0 +; CHECK-NEXT: [[BLOCK27:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP62:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT28:%.*]] = insertelement <1 x double> poison, double [[TMP62]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT29:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT28]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP63:%.*]] = fmul <1 x double> [[BLOCK27]], [[SPLAT_SPLAT29]] -; CHECK-NEXT: [[BLOCK30:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP64:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 1 +; CHECK-NEXT: [[BLOCK30:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP64:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT31:%.*]] = insertelement <1 x double> poison, double [[TMP64]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT32:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT31]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP65:%.*]] = fmul <1 x double> [[BLOCK30]], [[SPLAT_SPLAT32]] ; CHECK-NEXT: [[TMP66:%.*]] = fadd <1 x double> [[TMP63]], [[TMP65]] -; CHECK-NEXT: [[BLOCK33:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP67:%.*]] = extractelement <3 x double> [[SPLIT9]], i64 2 +; CHECK-NEXT: [[BLOCK33:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP67:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT34:%.*]] = insertelement <1 x double> poison, double [[TMP67]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT35:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT34]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP68:%.*]] = fmul <1 x double> [[BLOCK33]], [[SPLAT_SPLAT35]] ; CHECK-NEXT: [[TMP69:%.*]] = fadd <1 x double> [[TMP66]], [[TMP68]] ; CHECK-NEXT: [[TMP70:%.*]] = shufflevector <1 x double> [[TMP69]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP71:%.*]] = shufflevector <3 x double> [[TMP61]], <3 x double> [[TMP70]], <3 x i32> -; CHECK-NEXT: [[BLOCK36:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP72:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 0 +; CHECK-NEXT: [[BLOCK36:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP72:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT37:%.*]] = insertelement <1 x double> poison, double [[TMP72]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT38:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT37]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP73:%.*]] = fmul <1 x double> [[BLOCK36]], [[SPLAT_SPLAT38]] -; CHECK-NEXT: [[BLOCK39:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP74:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 1 +; CHECK-NEXT: [[BLOCK39:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP74:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT40:%.*]] = insertelement <1 x double> poison, double [[TMP74]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT41:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT40]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP75:%.*]] = fmul <1 x double> [[BLOCK39]], [[SPLAT_SPLAT41]] ; CHECK-NEXT: [[TMP76:%.*]] = fadd <1 x double> [[TMP73]], [[TMP75]] -; CHECK-NEXT: [[BLOCK42:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP77:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 2 +; CHECK-NEXT: [[BLOCK42:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP77:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT43:%.*]] = insertelement <1 x double> poison, double [[TMP77]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT44:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT43]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP78:%.*]] = fmul <1 x double> [[BLOCK42]], [[SPLAT_SPLAT44]] ; CHECK-NEXT: [[TMP79:%.*]] = fadd <1 x double> [[TMP76]], [[TMP78]] ; CHECK-NEXT: [[TMP80:%.*]] = shufflevector <1 x double> [[TMP79]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP81:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP80]], <3 x i32> -; CHECK-NEXT: [[BLOCK45:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP82:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 0 +; CHECK-NEXT: [[BLOCK45:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP82:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT46:%.*]] = insertelement <1 x double> poison, double [[TMP82]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT47:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT46]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP83:%.*]] = fmul <1 x double> [[BLOCK45]], [[SPLAT_SPLAT47]] -; CHECK-NEXT: [[BLOCK48:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP84:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 1 +; CHECK-NEXT: [[BLOCK48:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP84:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT49:%.*]] = insertelement <1 x double> poison, double [[TMP84]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT50:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT49]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP85:%.*]] = fmul <1 x double> [[BLOCK48]], [[SPLAT_SPLAT50]] ; CHECK-NEXT: [[TMP86:%.*]] = fadd <1 x double> [[TMP83]], [[TMP85]] -; CHECK-NEXT: [[BLOCK51:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP87:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 2 +; CHECK-NEXT: [[BLOCK51:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP87:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT52:%.*]] = insertelement <1 x double> poison, double [[TMP87]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT53:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT52]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP88:%.*]] = fmul <1 x double> [[BLOCK51]], [[SPLAT_SPLAT53]] ; CHECK-NEXT: [[TMP89:%.*]] = fadd <1 x double> [[TMP86]], [[TMP88]] ; CHECK-NEXT: [[TMP90:%.*]] = shufflevector <1 x double> [[TMP89]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP91:%.*]] = shufflevector <3 x double> [[TMP81]], <3 x double> [[TMP90]], <3 x i32> -; CHECK-NEXT: [[BLOCK54:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP92:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 0 +; CHECK-NEXT: [[BLOCK54:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP92:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT55:%.*]] = insertelement <1 x double> poison, double [[TMP92]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT56:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT55]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP93:%.*]] = fmul <1 x double> [[BLOCK54]], [[SPLAT_SPLAT56]] -; CHECK-NEXT: [[BLOCK57:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP94:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 1 +; CHECK-NEXT: [[BLOCK57:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP94:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT58:%.*]] = insertelement <1 x double> poison, double [[TMP94]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT59:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT58]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP95:%.*]] = fmul <1 x double> [[BLOCK57]], [[SPLAT_SPLAT59]] ; CHECK-NEXT: [[TMP96:%.*]] = fadd <1 x double> [[TMP93]], [[TMP95]] -; CHECK-NEXT: [[BLOCK60:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP97:%.*]] = extractelement <3 x double> [[SPLIT10]], i64 2 +; CHECK-NEXT: [[BLOCK60:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP97:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT61:%.*]] = insertelement <1 x double> poison, double [[TMP97]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT62:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT61]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP98:%.*]] = fmul <1 x double> [[BLOCK60]], [[SPLAT_SPLAT62]] ; CHECK-NEXT: [[TMP99:%.*]] = fadd <1 x double> [[TMP96]], [[TMP98]] ; CHECK-NEXT: [[TMP100:%.*]] = shufflevector <1 x double> [[TMP99]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP101:%.*]] = shufflevector <3 x double> [[TMP91]], <3 x double> [[TMP100]], <3 x i32> -; CHECK-NEXT: [[BLOCK63:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP102:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 0 +; CHECK-NEXT: [[BLOCK63:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP102:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT64:%.*]] = insertelement <1 x double> poison, double [[TMP102]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT65:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT64]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP103:%.*]] = fmul <1 x double> [[BLOCK63]], [[SPLAT_SPLAT65]] -; CHECK-NEXT: [[BLOCK66:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP104:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 1 +; CHECK-NEXT: [[BLOCK66:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP104:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT67:%.*]] = insertelement <1 x double> poison, double [[TMP104]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT68:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT67]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP105:%.*]] = fmul <1 x double> [[BLOCK66]], [[SPLAT_SPLAT68]] ; CHECK-NEXT: [[TMP106:%.*]] = fadd <1 x double> [[TMP103]], [[TMP105]] -; CHECK-NEXT: [[BLOCK69:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP107:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 2 +; CHECK-NEXT: [[BLOCK69:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP107:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT70:%.*]] = insertelement <1 x double> poison, double [[TMP107]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT71:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT70]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP108:%.*]] = fmul <1 x double> [[BLOCK69]], [[SPLAT_SPLAT71]] ; CHECK-NEXT: [[TMP109:%.*]] = fadd <1 x double> [[TMP106]], [[TMP108]] ; CHECK-NEXT: [[TMP110:%.*]] = shufflevector <1 x double> [[TMP109]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP111:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP110]], <3 x i32> -; CHECK-NEXT: [[BLOCK72:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP112:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 0 +; CHECK-NEXT: [[BLOCK72:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP112:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT73:%.*]] = insertelement <1 x double> poison, double [[TMP112]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT74:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT73]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP113:%.*]] = fmul <1 x double> [[BLOCK72]], [[SPLAT_SPLAT74]] -; CHECK-NEXT: [[BLOCK75:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP114:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 1 +; CHECK-NEXT: [[BLOCK75:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP114:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT76:%.*]] = insertelement <1 x double> poison, double [[TMP114]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT77:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT76]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP115:%.*]] = fmul <1 x double> [[BLOCK75]], [[SPLAT_SPLAT77]] ; CHECK-NEXT: [[TMP116:%.*]] = fadd <1 x double> [[TMP113]], [[TMP115]] -; CHECK-NEXT: [[BLOCK78:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP117:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 2 +; CHECK-NEXT: [[BLOCK78:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP117:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT79:%.*]] = insertelement <1 x double> poison, double [[TMP117]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT80:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT79]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP118:%.*]] = fmul <1 x double> [[BLOCK78]], [[SPLAT_SPLAT80]] ; CHECK-NEXT: [[TMP119:%.*]] = fadd <1 x double> [[TMP116]], [[TMP118]] ; CHECK-NEXT: [[TMP120:%.*]] = shufflevector <1 x double> [[TMP119]], <1 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP121:%.*]] = shufflevector <3 x double> [[TMP111]], <3 x double> [[TMP120]], <3 x i32> -; CHECK-NEXT: [[BLOCK81:%.*]] = shufflevector <3 x double> [[SPLIT6]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP122:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 0 +; CHECK-NEXT: [[BLOCK81:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP122:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT82:%.*]] = insertelement <1 x double> poison, double [[TMP122]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT83:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT82]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP123:%.*]] = fmul <1 x double> [[BLOCK81]], [[SPLAT_SPLAT83]] -; CHECK-NEXT: [[BLOCK84:%.*]] = shufflevector <3 x double> [[SPLIT7]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP124:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 1 +; CHECK-NEXT: [[BLOCK84:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP124:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT85:%.*]] = insertelement <1 x double> poison, double [[TMP124]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT86:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT85]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP125:%.*]] = fmul <1 x double> [[BLOCK84]], [[SPLAT_SPLAT86]] ; CHECK-NEXT: [[TMP126:%.*]] = fadd <1 x double> [[TMP123]], [[TMP125]] -; CHECK-NEXT: [[BLOCK87:%.*]] = shufflevector <3 x double> [[SPLIT8]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP127:%.*]] = extractelement <3 x double> [[SPLIT11]], i64 2 +; CHECK-NEXT: [[BLOCK87:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP127:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 ; CHECK-NEXT: [[SPLAT_SPLATINSERT88:%.*]] = insertelement <1 x double> poison, double [[TMP127]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLAT89:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT88]], <1 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP128:%.*]] = fmul <1 x double> [[BLOCK87]], [[SPLAT_SPLAT89]] From 71b99d3b4da5f0f655dc2f3c46cc7d42214517a1 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 27 May 2025 15:40:31 -0700 Subject: [PATCH 02/26] move formerly unsupported test to new home --- .../Transforms/LowerMatrixIntrinsics/phi.ll | 266 +++++++++++++++++- .../propagate-backwards-unsupported.ll | 254 ----------------- 2 files changed, 258 insertions(+), 262 deletions(-) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index d49b4d1112062..e2720b86d87ca 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -1,8 +1,8 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt -matrix-allow-contract=false -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s -define void @matrix_phi(ptr %in1, ptr %in2, i32 %count, ptr %out) { -; CHECK-LABEL: @matrix_phi( +define void @matrix_phi_loop(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3 @@ -59,8 +59,8 @@ exit: ret void } -define void @matrix_phi_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr %out) { -; CHECK-LABEL: @matrix_phi_zeroinitializer( +define void @matrix_phi_loop_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop_zeroinitializer( ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: @@ -111,8 +111,8 @@ exit: ret void } -define void @matrix_phi_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) { -; CHECK-LABEL: @matrix_phi_undef( +define void @matrix_phi_loop_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop_undef( ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: @@ -163,8 +163,8 @@ exit: ret void } -define void @matrix_phi_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) { -; CHECK-LABEL: @matrix_phi_poison( +define void @matrix_phi_loop_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop_poison( ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: @@ -214,3 +214,253 @@ exit: store <9 x double> %sum, ptr %out ret void } + +define <9 x double> @matrix_phi_ifthenelse(i1 %cond, <9 x double> %A, <9 x double> %B, <9 x double> %C) { +; CHECK-LABEL: @matrix_phi_ifthenelse( +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 [[COND:%.*]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]] +; CHECK: if.then: +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <9 x double> [[A:%.*]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <9 x double> [[A]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <9 x double> [[A]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <3 x double> poison, double [[TMP0]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <3 x double> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <3 x double> [[TMP1]], double [[TMP2]], i64 1 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 0 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <3 x double> [[TMP3]], double [[TMP4]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <3 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <3 x double> poison, double [[TMP6]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x double> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <3 x double> [[TMP7]], double [[TMP8]], i64 1 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 1 +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <3 x double> [[TMP9]], double [[TMP10]], i64 2 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <3 x double> [[SPLIT3]], i64 2 +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <3 x double> poison, double [[TMP12]], i64 0 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <3 x double> [[SPLIT4]], i64 2 +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <3 x double> [[TMP13]], double [[TMP14]], i64 1 +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 2 +; CHECK-NEXT: [[TMP17:%.*]] = insertelement <3 x double> [[TMP15]], double [[TMP16]], i64 2 +; CHECK-NEXT: br label [[IF_END:%.*]] +; CHECK: if.else: +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <9 x double> [[B:%.*]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <9 x double> [[B]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <9 x double> [[B]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <3 x double> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP22:%.*]] = insertelement <3 x double> poison, double [[TMP21]], i64 0 +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <3 x double> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP24:%.*]] = insertelement <3 x double> [[TMP22]], double [[TMP23]], i64 1 +; CHECK-NEXT: [[TMP25:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[TMP26:%.*]] = insertelement <3 x double> [[TMP24]], double [[TMP25]], i64 2 +; CHECK-NEXT: [[TMP27:%.*]] = extractelement <3 x double> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP28:%.*]] = insertelement <3 x double> poison, double [[TMP27]], i64 0 +; CHECK-NEXT: [[TMP29:%.*]] = extractelement <3 x double> [[SPLIT1]], i64 1 +; CHECK-NEXT: [[TMP30:%.*]] = insertelement <3 x double> [[TMP28]], double [[TMP29]], i64 1 +; CHECK-NEXT: [[TMP31:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[TMP32:%.*]] = insertelement <3 x double> [[TMP30]], double [[TMP31]], i64 2 +; CHECK-NEXT: [[TMP33:%.*]] = extractelement <3 x double> [[SPLIT]], i64 2 +; CHECK-NEXT: [[TMP34:%.*]] = insertelement <3 x double> poison, double [[TMP33]], i64 0 +; CHECK-NEXT: [[TMP35:%.*]] = extractelement <3 x double> [[SPLIT1]], i64 2 +; CHECK-NEXT: [[TMP36:%.*]] = insertelement <3 x double> [[TMP34]], double [[TMP35]], i64 1 +; CHECK-NEXT: [[TMP37:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 2 +; CHECK-NEXT: [[TMP38:%.*]] = insertelement <3 x double> [[TMP36]], double [[TMP37]], i64 2 +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: [[MERGE9:%.*]] = phi <3 x double> [ [[TMP5]], [[IF_THEN]] ], [ [[TMP26]], [[IF_ELSE]] ] +; CHECK-NEXT: [[MERGE10:%.*]] = phi <3 x double> [ [[TMP11]], [[IF_THEN]] ], [ [[TMP32]], [[IF_ELSE]] ] +; CHECK-NEXT: [[MERGE11:%.*]] = phi <3 x double> [ [[TMP17]], [[IF_THEN]] ], [ [[TMP38]], [[IF_ELSE]] ] +; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE:%.*]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP42:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> poison, double [[TMP42]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP43:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK12:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP44:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT13:%.*]] = insertelement <1 x double> poison, double [[TMP44]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT14:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT13]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP45:%.*]] = fmul <1 x double> [[BLOCK12]], [[SPLAT_SPLAT14]] +; CHECK-NEXT: [[TMP46:%.*]] = fadd <1 x double> [[TMP43]], [[TMP45]] +; CHECK-NEXT: [[BLOCK15:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP47:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT16:%.*]] = insertelement <1 x double> poison, double [[TMP47]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT17:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT16]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP48:%.*]] = fmul <1 x double> [[BLOCK15]], [[SPLAT_SPLAT17]] +; CHECK-NEXT: [[TMP49:%.*]] = fadd <1 x double> [[TMP46]], [[TMP48]] +; CHECK-NEXT: [[TMP50:%.*]] = shufflevector <1 x double> [[TMP49]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP51:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP50]], <3 x i32> +; CHECK-NEXT: [[BLOCK18:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP52:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT19:%.*]] = insertelement <1 x double> poison, double [[TMP52]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT20:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT19]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP53:%.*]] = fmul <1 x double> [[BLOCK18]], [[SPLAT_SPLAT20]] +; CHECK-NEXT: [[BLOCK21:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP54:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT22:%.*]] = insertelement <1 x double> poison, double [[TMP54]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT23:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT22]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP55:%.*]] = fmul <1 x double> [[BLOCK21]], [[SPLAT_SPLAT23]] +; CHECK-NEXT: [[TMP56:%.*]] = fadd <1 x double> [[TMP53]], [[TMP55]] +; CHECK-NEXT: [[BLOCK24:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP57:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT25:%.*]] = insertelement <1 x double> poison, double [[TMP57]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT26:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT25]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP58:%.*]] = fmul <1 x double> [[BLOCK24]], [[SPLAT_SPLAT26]] +; CHECK-NEXT: [[TMP59:%.*]] = fadd <1 x double> [[TMP56]], [[TMP58]] +; CHECK-NEXT: [[TMP60:%.*]] = shufflevector <1 x double> [[TMP59]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <3 x double> [[TMP51]], <3 x double> [[TMP60]], <3 x i32> +; CHECK-NEXT: [[BLOCK27:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP62:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT28:%.*]] = insertelement <1 x double> poison, double [[TMP62]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT29:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT28]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP63:%.*]] = fmul <1 x double> [[BLOCK27]], [[SPLAT_SPLAT29]] +; CHECK-NEXT: [[BLOCK30:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP64:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT31:%.*]] = insertelement <1 x double> poison, double [[TMP64]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT32:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT31]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP65:%.*]] = fmul <1 x double> [[BLOCK30]], [[SPLAT_SPLAT32]] +; CHECK-NEXT: [[TMP66:%.*]] = fadd <1 x double> [[TMP63]], [[TMP65]] +; CHECK-NEXT: [[BLOCK33:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP67:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT34:%.*]] = insertelement <1 x double> poison, double [[TMP67]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT35:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT34]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP68:%.*]] = fmul <1 x double> [[BLOCK33]], [[SPLAT_SPLAT35]] +; CHECK-NEXT: [[TMP69:%.*]] = fadd <1 x double> [[TMP66]], [[TMP68]] +; CHECK-NEXT: [[TMP70:%.*]] = shufflevector <1 x double> [[TMP69]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP71:%.*]] = shufflevector <3 x double> [[TMP61]], <3 x double> [[TMP70]], <3 x i32> +; CHECK-NEXT: [[BLOCK36:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP72:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT37:%.*]] = insertelement <1 x double> poison, double [[TMP72]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT38:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT37]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP73:%.*]] = fmul <1 x double> [[BLOCK36]], [[SPLAT_SPLAT38]] +; CHECK-NEXT: [[BLOCK39:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP74:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT40:%.*]] = insertelement <1 x double> poison, double [[TMP74]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT41:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT40]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP75:%.*]] = fmul <1 x double> [[BLOCK39]], [[SPLAT_SPLAT41]] +; CHECK-NEXT: [[TMP76:%.*]] = fadd <1 x double> [[TMP73]], [[TMP75]] +; CHECK-NEXT: [[BLOCK42:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP77:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT43:%.*]] = insertelement <1 x double> poison, double [[TMP77]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT44:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT43]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP78:%.*]] = fmul <1 x double> [[BLOCK42]], [[SPLAT_SPLAT44]] +; CHECK-NEXT: [[TMP79:%.*]] = fadd <1 x double> [[TMP76]], [[TMP78]] +; CHECK-NEXT: [[TMP80:%.*]] = shufflevector <1 x double> [[TMP79]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP81:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP80]], <3 x i32> +; CHECK-NEXT: [[BLOCK45:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP82:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT46:%.*]] = insertelement <1 x double> poison, double [[TMP82]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT47:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT46]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP83:%.*]] = fmul <1 x double> [[BLOCK45]], [[SPLAT_SPLAT47]] +; CHECK-NEXT: [[BLOCK48:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP84:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT49:%.*]] = insertelement <1 x double> poison, double [[TMP84]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT50:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT49]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP85:%.*]] = fmul <1 x double> [[BLOCK48]], [[SPLAT_SPLAT50]] +; CHECK-NEXT: [[TMP86:%.*]] = fadd <1 x double> [[TMP83]], [[TMP85]] +; CHECK-NEXT: [[BLOCK51:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP87:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT52:%.*]] = insertelement <1 x double> poison, double [[TMP87]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT53:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT52]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP88:%.*]] = fmul <1 x double> [[BLOCK51]], [[SPLAT_SPLAT53]] +; CHECK-NEXT: [[TMP89:%.*]] = fadd <1 x double> [[TMP86]], [[TMP88]] +; CHECK-NEXT: [[TMP90:%.*]] = shufflevector <1 x double> [[TMP89]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP91:%.*]] = shufflevector <3 x double> [[TMP81]], <3 x double> [[TMP90]], <3 x i32> +; CHECK-NEXT: [[BLOCK54:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP92:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT55:%.*]] = insertelement <1 x double> poison, double [[TMP92]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT56:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT55]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP93:%.*]] = fmul <1 x double> [[BLOCK54]], [[SPLAT_SPLAT56]] +; CHECK-NEXT: [[BLOCK57:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP94:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT58:%.*]] = insertelement <1 x double> poison, double [[TMP94]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT59:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT58]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP95:%.*]] = fmul <1 x double> [[BLOCK57]], [[SPLAT_SPLAT59]] +; CHECK-NEXT: [[TMP96:%.*]] = fadd <1 x double> [[TMP93]], [[TMP95]] +; CHECK-NEXT: [[BLOCK60:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP97:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT61:%.*]] = insertelement <1 x double> poison, double [[TMP97]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT62:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT61]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP98:%.*]] = fmul <1 x double> [[BLOCK60]], [[SPLAT_SPLAT62]] +; CHECK-NEXT: [[TMP99:%.*]] = fadd <1 x double> [[TMP96]], [[TMP98]] +; CHECK-NEXT: [[TMP100:%.*]] = shufflevector <1 x double> [[TMP99]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP101:%.*]] = shufflevector <3 x double> [[TMP91]], <3 x double> [[TMP100]], <3 x i32> +; CHECK-NEXT: [[BLOCK63:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP102:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT64:%.*]] = insertelement <1 x double> poison, double [[TMP102]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT65:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT64]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP103:%.*]] = fmul <1 x double> [[BLOCK63]], [[SPLAT_SPLAT65]] +; CHECK-NEXT: [[BLOCK66:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP104:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT67:%.*]] = insertelement <1 x double> poison, double [[TMP104]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT68:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT67]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP105:%.*]] = fmul <1 x double> [[BLOCK66]], [[SPLAT_SPLAT68]] +; CHECK-NEXT: [[TMP106:%.*]] = fadd <1 x double> [[TMP103]], [[TMP105]] +; CHECK-NEXT: [[BLOCK69:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP107:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT70:%.*]] = insertelement <1 x double> poison, double [[TMP107]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT71:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT70]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP108:%.*]] = fmul <1 x double> [[BLOCK69]], [[SPLAT_SPLAT71]] +; CHECK-NEXT: [[TMP109:%.*]] = fadd <1 x double> [[TMP106]], [[TMP108]] +; CHECK-NEXT: [[TMP110:%.*]] = shufflevector <1 x double> [[TMP109]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP111:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP110]], <3 x i32> +; CHECK-NEXT: [[BLOCK72:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP112:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT73:%.*]] = insertelement <1 x double> poison, double [[TMP112]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT74:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT73]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP113:%.*]] = fmul <1 x double> [[BLOCK72]], [[SPLAT_SPLAT74]] +; CHECK-NEXT: [[BLOCK75:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP114:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT76:%.*]] = insertelement <1 x double> poison, double [[TMP114]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT77:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT76]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP115:%.*]] = fmul <1 x double> [[BLOCK75]], [[SPLAT_SPLAT77]] +; CHECK-NEXT: [[TMP116:%.*]] = fadd <1 x double> [[TMP113]], [[TMP115]] +; CHECK-NEXT: [[BLOCK78:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP117:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT79:%.*]] = insertelement <1 x double> poison, double [[TMP117]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT80:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT79]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP118:%.*]] = fmul <1 x double> [[BLOCK78]], [[SPLAT_SPLAT80]] +; CHECK-NEXT: [[TMP119:%.*]] = fadd <1 x double> [[TMP116]], [[TMP118]] +; CHECK-NEXT: [[TMP120:%.*]] = shufflevector <1 x double> [[TMP119]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP121:%.*]] = shufflevector <3 x double> [[TMP111]], <3 x double> [[TMP120]], <3 x i32> +; CHECK-NEXT: [[BLOCK81:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP122:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT82:%.*]] = insertelement <1 x double> poison, double [[TMP122]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT83:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT82]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP123:%.*]] = fmul <1 x double> [[BLOCK81]], [[SPLAT_SPLAT83]] +; CHECK-NEXT: [[BLOCK84:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP124:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT85:%.*]] = insertelement <1 x double> poison, double [[TMP124]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT86:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT85]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP125:%.*]] = fmul <1 x double> [[BLOCK84]], [[SPLAT_SPLAT86]] +; CHECK-NEXT: [[TMP126:%.*]] = fadd <1 x double> [[TMP123]], [[TMP125]] +; CHECK-NEXT: [[BLOCK87:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> +; CHECK-NEXT: [[TMP127:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 +; CHECK-NEXT: [[SPLAT_SPLATINSERT88:%.*]] = insertelement <1 x double> poison, double [[TMP127]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLAT89:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT88]], <1 x double> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP128:%.*]] = fmul <1 x double> [[BLOCK87]], [[SPLAT_SPLAT89]] +; CHECK-NEXT: [[TMP129:%.*]] = fadd <1 x double> [[TMP126]], [[TMP128]] +; CHECK-NEXT: [[TMP130:%.*]] = shufflevector <1 x double> [[TMP129]], <1 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP131:%.*]] = shufflevector <3 x double> [[TMP121]], <3 x double> [[TMP130]], <3 x i32> +; CHECK-NEXT: [[TMP132:%.*]] = shufflevector <3 x double> [[TMP71]], <3 x double> [[TMP101]], <6 x i32> +; CHECK-NEXT: [[TMP133:%.*]] = shufflevector <3 x double> [[TMP131]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP134:%.*]] = shufflevector <6 x double> [[TMP132]], <6 x double> [[TMP133]], <9 x i32> +; CHECK-NEXT: ret <9 x double> [[TMP134]] +; +entry: + br i1 %cond, label %if.then, label %if.else + +if.then: ; preds = %entry + %A.trans = tail call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %A, i32 3, i32 3) + br label %if.end + +if.else: ; preds = %entry + %B.trans = tail call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %B, i32 3, i32 3) + br label %if.end + +if.end: ; preds = %if.then, %if.else + %merge = phi <9 x double> [ %A.trans, %if.then], [ %B.trans, %if.else ] + %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %C, <9 x double> %merge, i32 3, i32 3, i32 3) + ret <9 x double> %res +} diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll index 6ed8e46d62892..f07e1762d404f 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backwards-unsupported.ll @@ -1,260 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s -; Check that we we use flattened vectors for PHI operands and extract the columns afterwards. -define <9 x double> @unsupported_phi(i1 %cond, <9 x double> %A, <9 x double> %B, <9 x double> %C) { -; CHECK-LABEL: @unsupported_phi( -; CHECK-NEXT: entry: -; CHECK-NEXT: br i1 [[COND:%.*]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]] -; CHECK: if.then: -; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <9 x double> [[A:%.*]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <9 x double> [[A]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <9 x double> [[A]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x double> [[SPLIT3]], i64 0 -; CHECK-NEXT: [[TMP1:%.*]] = insertelement <3 x double> poison, double [[TMP0]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <3 x double> [[SPLIT4]], i64 0 -; CHECK-NEXT: [[TMP3:%.*]] = insertelement <3 x double> [[TMP1]], double [[TMP2]], i64 1 -; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 0 -; CHECK-NEXT: [[TMP5:%.*]] = insertelement <3 x double> [[TMP3]], double [[TMP4]], i64 2 -; CHECK-NEXT: [[TMP6:%.*]] = extractelement <3 x double> [[SPLIT3]], i64 1 -; CHECK-NEXT: [[TMP7:%.*]] = insertelement <3 x double> poison, double [[TMP6]], i64 0 -; CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x double> [[SPLIT4]], i64 1 -; CHECK-NEXT: [[TMP9:%.*]] = insertelement <3 x double> [[TMP7]], double [[TMP8]], i64 1 -; CHECK-NEXT: [[TMP10:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 1 -; CHECK-NEXT: [[TMP11:%.*]] = insertelement <3 x double> [[TMP9]], double [[TMP10]], i64 2 -; CHECK-NEXT: [[TMP12:%.*]] = extractelement <3 x double> [[SPLIT3]], i64 2 -; CHECK-NEXT: [[TMP13:%.*]] = insertelement <3 x double> poison, double [[TMP12]], i64 0 -; CHECK-NEXT: [[TMP14:%.*]] = extractelement <3 x double> [[SPLIT4]], i64 2 -; CHECK-NEXT: [[TMP15:%.*]] = insertelement <3 x double> [[TMP13]], double [[TMP14]], i64 1 -; CHECK-NEXT: [[TMP16:%.*]] = extractelement <3 x double> [[SPLIT5]], i64 2 -; CHECK-NEXT: [[TMP17:%.*]] = insertelement <3 x double> [[TMP15]], double [[TMP16]], i64 2 -; CHECK-NEXT: br label [[IF_END:%.*]] -; CHECK: if.else: -; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <9 x double> [[B:%.*]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <9 x double> [[B]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <9 x double> [[B]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP21:%.*]] = extractelement <3 x double> [[SPLIT]], i64 0 -; CHECK-NEXT: [[TMP22:%.*]] = insertelement <3 x double> poison, double [[TMP21]], i64 0 -; CHECK-NEXT: [[TMP23:%.*]] = extractelement <3 x double> [[SPLIT1]], i64 0 -; CHECK-NEXT: [[TMP24:%.*]] = insertelement <3 x double> [[TMP22]], double [[TMP23]], i64 1 -; CHECK-NEXT: [[TMP25:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 0 -; CHECK-NEXT: [[TMP26:%.*]] = insertelement <3 x double> [[TMP24]], double [[TMP25]], i64 2 -; CHECK-NEXT: [[TMP27:%.*]] = extractelement <3 x double> [[SPLIT]], i64 1 -; CHECK-NEXT: [[TMP28:%.*]] = insertelement <3 x double> poison, double [[TMP27]], i64 0 -; CHECK-NEXT: [[TMP29:%.*]] = extractelement <3 x double> [[SPLIT1]], i64 1 -; CHECK-NEXT: [[TMP30:%.*]] = insertelement <3 x double> [[TMP28]], double [[TMP29]], i64 1 -; CHECK-NEXT: [[TMP31:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 1 -; CHECK-NEXT: [[TMP32:%.*]] = insertelement <3 x double> [[TMP30]], double [[TMP31]], i64 2 -; CHECK-NEXT: [[TMP33:%.*]] = extractelement <3 x double> [[SPLIT]], i64 2 -; CHECK-NEXT: [[TMP34:%.*]] = insertelement <3 x double> poison, double [[TMP33]], i64 0 -; CHECK-NEXT: [[TMP35:%.*]] = extractelement <3 x double> [[SPLIT1]], i64 2 -; CHECK-NEXT: [[TMP36:%.*]] = insertelement <3 x double> [[TMP34]], double [[TMP35]], i64 1 -; CHECK-NEXT: [[TMP37:%.*]] = extractelement <3 x double> [[SPLIT2]], i64 2 -; CHECK-NEXT: [[TMP38:%.*]] = insertelement <3 x double> [[TMP36]], double [[TMP37]], i64 2 -; CHECK-NEXT: br label [[IF_END]] -; CHECK: if.end: -; CHECK-NEXT: [[MERGE9:%.*]] = phi <3 x double> [ [[TMP5]], [[IF_THEN]] ], [ [[TMP26]], [[IF_ELSE]] ] -; CHECK-NEXT: [[MERGE10:%.*]] = phi <3 x double> [ [[TMP11]], [[IF_THEN]] ], [ [[TMP32]], [[IF_ELSE]] ] -; CHECK-NEXT: [[MERGE11:%.*]] = phi <3 x double> [ [[TMP17]], [[IF_THEN]] ], [ [[TMP38]], [[IF_ELSE]] ] -; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <9 x double> [[MERGE:%.*]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <9 x double> [[MERGE]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP42:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> poison, double [[TMP42]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP43:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] -; CHECK-NEXT: [[BLOCK12:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP44:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT13:%.*]] = insertelement <1 x double> poison, double [[TMP44]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT14:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT13]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP45:%.*]] = fmul <1 x double> [[BLOCK12]], [[SPLAT_SPLAT14]] -; CHECK-NEXT: [[TMP46:%.*]] = fadd <1 x double> [[TMP43]], [[TMP45]] -; CHECK-NEXT: [[BLOCK15:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP47:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT16:%.*]] = insertelement <1 x double> poison, double [[TMP47]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT17:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT16]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP48:%.*]] = fmul <1 x double> [[BLOCK15]], [[SPLAT_SPLAT17]] -; CHECK-NEXT: [[TMP49:%.*]] = fadd <1 x double> [[TMP46]], [[TMP48]] -; CHECK-NEXT: [[TMP50:%.*]] = shufflevector <1 x double> [[TMP49]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP51:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP50]], <3 x i32> -; CHECK-NEXT: [[BLOCK18:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP52:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT19:%.*]] = insertelement <1 x double> poison, double [[TMP52]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT20:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT19]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP53:%.*]] = fmul <1 x double> [[BLOCK18]], [[SPLAT_SPLAT20]] -; CHECK-NEXT: [[BLOCK21:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP54:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT22:%.*]] = insertelement <1 x double> poison, double [[TMP54]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT23:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT22]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP55:%.*]] = fmul <1 x double> [[BLOCK21]], [[SPLAT_SPLAT23]] -; CHECK-NEXT: [[TMP56:%.*]] = fadd <1 x double> [[TMP53]], [[TMP55]] -; CHECK-NEXT: [[BLOCK24:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP57:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT25:%.*]] = insertelement <1 x double> poison, double [[TMP57]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT26:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT25]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP58:%.*]] = fmul <1 x double> [[BLOCK24]], [[SPLAT_SPLAT26]] -; CHECK-NEXT: [[TMP59:%.*]] = fadd <1 x double> [[TMP56]], [[TMP58]] -; CHECK-NEXT: [[TMP60:%.*]] = shufflevector <1 x double> [[TMP59]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <3 x double> [[TMP51]], <3 x double> [[TMP60]], <3 x i32> -; CHECK-NEXT: [[BLOCK27:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP62:%.*]] = extractelement <3 x double> [[MERGE9]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT28:%.*]] = insertelement <1 x double> poison, double [[TMP62]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT29:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT28]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP63:%.*]] = fmul <1 x double> [[BLOCK27]], [[SPLAT_SPLAT29]] -; CHECK-NEXT: [[BLOCK30:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP64:%.*]] = extractelement <3 x double> [[MERGE9]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT31:%.*]] = insertelement <1 x double> poison, double [[TMP64]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT32:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT31]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP65:%.*]] = fmul <1 x double> [[BLOCK30]], [[SPLAT_SPLAT32]] -; CHECK-NEXT: [[TMP66:%.*]] = fadd <1 x double> [[TMP63]], [[TMP65]] -; CHECK-NEXT: [[BLOCK33:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP67:%.*]] = extractelement <3 x double> [[MERGE9]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT34:%.*]] = insertelement <1 x double> poison, double [[TMP67]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT35:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT34]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP68:%.*]] = fmul <1 x double> [[BLOCK33]], [[SPLAT_SPLAT35]] -; CHECK-NEXT: [[TMP69:%.*]] = fadd <1 x double> [[TMP66]], [[TMP68]] -; CHECK-NEXT: [[TMP70:%.*]] = shufflevector <1 x double> [[TMP69]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP71:%.*]] = shufflevector <3 x double> [[TMP61]], <3 x double> [[TMP70]], <3 x i32> -; CHECK-NEXT: [[BLOCK36:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP72:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT37:%.*]] = insertelement <1 x double> poison, double [[TMP72]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT38:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT37]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP73:%.*]] = fmul <1 x double> [[BLOCK36]], [[SPLAT_SPLAT38]] -; CHECK-NEXT: [[BLOCK39:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP74:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT40:%.*]] = insertelement <1 x double> poison, double [[TMP74]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT41:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT40]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP75:%.*]] = fmul <1 x double> [[BLOCK39]], [[SPLAT_SPLAT41]] -; CHECK-NEXT: [[TMP76:%.*]] = fadd <1 x double> [[TMP73]], [[TMP75]] -; CHECK-NEXT: [[BLOCK42:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP77:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT43:%.*]] = insertelement <1 x double> poison, double [[TMP77]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT44:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT43]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP78:%.*]] = fmul <1 x double> [[BLOCK42]], [[SPLAT_SPLAT44]] -; CHECK-NEXT: [[TMP79:%.*]] = fadd <1 x double> [[TMP76]], [[TMP78]] -; CHECK-NEXT: [[TMP80:%.*]] = shufflevector <1 x double> [[TMP79]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP81:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP80]], <3 x i32> -; CHECK-NEXT: [[BLOCK45:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP82:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT46:%.*]] = insertelement <1 x double> poison, double [[TMP82]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT47:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT46]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP83:%.*]] = fmul <1 x double> [[BLOCK45]], [[SPLAT_SPLAT47]] -; CHECK-NEXT: [[BLOCK48:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP84:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT49:%.*]] = insertelement <1 x double> poison, double [[TMP84]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT50:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT49]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP85:%.*]] = fmul <1 x double> [[BLOCK48]], [[SPLAT_SPLAT50]] -; CHECK-NEXT: [[TMP86:%.*]] = fadd <1 x double> [[TMP83]], [[TMP85]] -; CHECK-NEXT: [[BLOCK51:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP87:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT52:%.*]] = insertelement <1 x double> poison, double [[TMP87]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT53:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT52]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP88:%.*]] = fmul <1 x double> [[BLOCK51]], [[SPLAT_SPLAT53]] -; CHECK-NEXT: [[TMP89:%.*]] = fadd <1 x double> [[TMP86]], [[TMP88]] -; CHECK-NEXT: [[TMP90:%.*]] = shufflevector <1 x double> [[TMP89]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP91:%.*]] = shufflevector <3 x double> [[TMP81]], <3 x double> [[TMP90]], <3 x i32> -; CHECK-NEXT: [[BLOCK54:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP92:%.*]] = extractelement <3 x double> [[MERGE10]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT55:%.*]] = insertelement <1 x double> poison, double [[TMP92]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT56:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT55]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP93:%.*]] = fmul <1 x double> [[BLOCK54]], [[SPLAT_SPLAT56]] -; CHECK-NEXT: [[BLOCK57:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP94:%.*]] = extractelement <3 x double> [[MERGE10]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT58:%.*]] = insertelement <1 x double> poison, double [[TMP94]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT59:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT58]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP95:%.*]] = fmul <1 x double> [[BLOCK57]], [[SPLAT_SPLAT59]] -; CHECK-NEXT: [[TMP96:%.*]] = fadd <1 x double> [[TMP93]], [[TMP95]] -; CHECK-NEXT: [[BLOCK60:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP97:%.*]] = extractelement <3 x double> [[MERGE10]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT61:%.*]] = insertelement <1 x double> poison, double [[TMP97]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT62:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT61]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP98:%.*]] = fmul <1 x double> [[BLOCK60]], [[SPLAT_SPLAT62]] -; CHECK-NEXT: [[TMP99:%.*]] = fadd <1 x double> [[TMP96]], [[TMP98]] -; CHECK-NEXT: [[TMP100:%.*]] = shufflevector <1 x double> [[TMP99]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP101:%.*]] = shufflevector <3 x double> [[TMP91]], <3 x double> [[TMP100]], <3 x i32> -; CHECK-NEXT: [[BLOCK63:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP102:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT64:%.*]] = insertelement <1 x double> poison, double [[TMP102]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT65:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT64]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP103:%.*]] = fmul <1 x double> [[BLOCK63]], [[SPLAT_SPLAT65]] -; CHECK-NEXT: [[BLOCK66:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP104:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT67:%.*]] = insertelement <1 x double> poison, double [[TMP104]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT68:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT67]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP105:%.*]] = fmul <1 x double> [[BLOCK66]], [[SPLAT_SPLAT68]] -; CHECK-NEXT: [[TMP106:%.*]] = fadd <1 x double> [[TMP103]], [[TMP105]] -; CHECK-NEXT: [[BLOCK69:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP107:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT70:%.*]] = insertelement <1 x double> poison, double [[TMP107]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT71:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT70]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP108:%.*]] = fmul <1 x double> [[BLOCK69]], [[SPLAT_SPLAT71]] -; CHECK-NEXT: [[TMP109:%.*]] = fadd <1 x double> [[TMP106]], [[TMP108]] -; CHECK-NEXT: [[TMP110:%.*]] = shufflevector <1 x double> [[TMP109]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP111:%.*]] = shufflevector <3 x double> poison, <3 x double> [[TMP110]], <3 x i32> -; CHECK-NEXT: [[BLOCK72:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP112:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT73:%.*]] = insertelement <1 x double> poison, double [[TMP112]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT74:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT73]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP113:%.*]] = fmul <1 x double> [[BLOCK72]], [[SPLAT_SPLAT74]] -; CHECK-NEXT: [[BLOCK75:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP114:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT76:%.*]] = insertelement <1 x double> poison, double [[TMP114]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT77:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT76]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP115:%.*]] = fmul <1 x double> [[BLOCK75]], [[SPLAT_SPLAT77]] -; CHECK-NEXT: [[TMP116:%.*]] = fadd <1 x double> [[TMP113]], [[TMP115]] -; CHECK-NEXT: [[BLOCK78:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP117:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT79:%.*]] = insertelement <1 x double> poison, double [[TMP117]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT80:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT79]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP118:%.*]] = fmul <1 x double> [[BLOCK78]], [[SPLAT_SPLAT80]] -; CHECK-NEXT: [[TMP119:%.*]] = fadd <1 x double> [[TMP116]], [[TMP118]] -; CHECK-NEXT: [[TMP120:%.*]] = shufflevector <1 x double> [[TMP119]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP121:%.*]] = shufflevector <3 x double> [[TMP111]], <3 x double> [[TMP120]], <3 x i32> -; CHECK-NEXT: [[BLOCK81:%.*]] = shufflevector <3 x double> [[SPLIT9]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP122:%.*]] = extractelement <3 x double> [[MERGE11]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLATINSERT82:%.*]] = insertelement <1 x double> poison, double [[TMP122]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT83:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT82]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP123:%.*]] = fmul <1 x double> [[BLOCK81]], [[SPLAT_SPLAT83]] -; CHECK-NEXT: [[BLOCK84:%.*]] = shufflevector <3 x double> [[SPLIT10]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP124:%.*]] = extractelement <3 x double> [[MERGE11]], i64 1 -; CHECK-NEXT: [[SPLAT_SPLATINSERT85:%.*]] = insertelement <1 x double> poison, double [[TMP124]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT86:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT85]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP125:%.*]] = fmul <1 x double> [[BLOCK84]], [[SPLAT_SPLAT86]] -; CHECK-NEXT: [[TMP126:%.*]] = fadd <1 x double> [[TMP123]], [[TMP125]] -; CHECK-NEXT: [[BLOCK87:%.*]] = shufflevector <3 x double> [[SPLIT11]], <3 x double> poison, <1 x i32> -; CHECK-NEXT: [[TMP127:%.*]] = extractelement <3 x double> [[MERGE11]], i64 2 -; CHECK-NEXT: [[SPLAT_SPLATINSERT88:%.*]] = insertelement <1 x double> poison, double [[TMP127]], i64 0 -; CHECK-NEXT: [[SPLAT_SPLAT89:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT88]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP128:%.*]] = fmul <1 x double> [[BLOCK87]], [[SPLAT_SPLAT89]] -; CHECK-NEXT: [[TMP129:%.*]] = fadd <1 x double> [[TMP126]], [[TMP128]] -; CHECK-NEXT: [[TMP130:%.*]] = shufflevector <1 x double> [[TMP129]], <1 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP131:%.*]] = shufflevector <3 x double> [[TMP121]], <3 x double> [[TMP130]], <3 x i32> -; CHECK-NEXT: [[TMP132:%.*]] = shufflevector <3 x double> [[TMP71]], <3 x double> [[TMP101]], <6 x i32> -; CHECK-NEXT: [[TMP133:%.*]] = shufflevector <3 x double> [[TMP131]], <3 x double> poison, <6 x i32> -; CHECK-NEXT: [[TMP134:%.*]] = shufflevector <6 x double> [[TMP132]], <6 x double> [[TMP133]], <9 x i32> -; CHECK-NEXT: ret <9 x double> [[TMP134]] -; - - - -entry: - br i1 %cond, label %if.then, label %if.else - -if.then: ; preds = %entry - %A.trans = tail call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %A, i32 3, i32 3) - br label %if.end - -if.else: ; preds = %entry - %B.trans = tail call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %B, i32 3, i32 3) - br label %if.end - -if.end: ; preds = %if.then, %if.else - %merge = phi <9 x double> [ %A.trans, %if.then], [ %B.trans, %if.else ] - %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %C, <9 x double> %merge, i32 3, i32 3, i32 3) - ret <9 x double> %res -} - ; Make sure we use a flattened vector when calling @foo and the use its flat vector result properly. define <9 x double> @unsupported_call(i1 %cond, <9 x double> %A, <9 x double> %B) { ; CHECK-LABEL: @unsupported_call( From 905c1e93506605b841fc5beadf9c1dc94cb4ecd6 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 27 May 2025 15:41:56 -0700 Subject: [PATCH 03/26] clang-format --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index c06d08688ab1c..8db812eacbd48 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -371,10 +371,10 @@ class LowerMatrixIntrinsics { if (auto *CDV = dyn_cast(Constant)) { unsigned Width = SI.getStride(); size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8; - StringRef Data = CDV->getRawDataValues().substr( - J * Width * EltSize, Width * EltSize); - addVector(ConstantDataVector::getRaw(Data, Width, - CDV->getElementType())); + StringRef Data = CDV->getRawDataValues().substr(J * Width * EltSize, + Width * EltSize); + addVector( + ConstantDataVector::getRaw(Data, Width, CDV->getElementType())); } else if (isa(Constant)) addVector(PoisonValue::get(RowTy)); else if (isa(Constant)) From 9ee44f03681404f339c9dd8dcc0d4e5729333999 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 27 May 2025 15:56:50 -0700 Subject: [PATCH 04/26] add test for ConstantDataVector lowering --- .../Transforms/LowerMatrixIntrinsics/phi.ll | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index e2720b86d87ca..2d2125d7b444f 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -215,6 +215,58 @@ exit: ret void } +define void @matrix_phi_loop_cdv(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop_cdv( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ , [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ , [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ , [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: ret void +; +entry: + br label %loop + +loop: + %phi = phi <9 x double> [, %entry], [%sum, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %in2v = load <9 x double>, ptr %in2 + + ; Give in2 the shape: 3 x 3 + %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) + %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + + %sum = fadd <9 x double> %phi, %in2tt + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <9 x double> %sum, ptr %out + ret void +} + define <9 x double> @matrix_phi_ifthenelse(i1 %cond, <9 x double> %A, <9 x double> %B, <9 x double> %C) { ; CHECK-LABEL: @matrix_phi_ifthenelse( ; CHECK-NEXT: entry: From 169960dd3e7a9cb4e25924497c024285fed8a9f2 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 27 May 2025 15:58:04 -0700 Subject: [PATCH 05/26] move report_fatal_error outside of NDEBUG block --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 8db812eacbd48..f447baf717b87 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -384,8 +384,8 @@ class LowerMatrixIntrinsics { else { #ifndef NDEBUG Constant->dump(); - report_fatal_error("unhandled ConstantData type"); #endif + report_fatal_error("unhandled ConstantData type"); } } } From 18951cdfd6ea206ca2d20dd68e07a04a7c5b1264 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Sat, 31 May 2025 12:55:32 -0700 Subject: [PATCH 06/26] fix bad merge --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 99b233d73ecf6..a439ffc071f86 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -234,6 +234,7 @@ static bool isUniformShape(Value *V) { return true; switch (I->getOpcode()) { + case Instruction::PHI: case Instruction::FNeg: return true; default: From f8aea05852275e3775e6ba26634170138e78b401 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 2 Jun 2025 08:26:49 -0700 Subject: [PATCH 07/26] use col major load intrinsics --- .../Transforms/LowerMatrixIntrinsics/phi.ll | 90 ++++++------------- 1 file changed, 29 insertions(+), 61 deletions(-) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index 2d2125d7b444f..510cd670d8de8 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -4,22 +4,22 @@ define void @matrix_phi_loop(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-LABEL: @matrix_phi_loop( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 128 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 8 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN1]], i64 6 -; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: [[PHI9:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI10:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI11:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 8 ; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3 ; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8 ; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6 -; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 16 +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 8 ; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]] ; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]] ; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]] @@ -35,20 +35,16 @@ define void @matrix_phi_loop(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: ret void ; entry: - %mat = load <9 x double>, ptr %in1 + %in1v = call <9 x double> @llvm.matrix.column.major.load(ptr %in1, i64 3, i1 false, i32 3, i32 3) br label %loop loop: - %phi = phi <9 x double> [%mat, %entry], [%sum, %loop] + %phi = phi <9 x double> [%in1v, %entry], [%sum, %loop] %ctr = phi i32 [%count, %entry], [%dec, %loop] - %in2v = load <9 x double>, ptr %in2 + %in2v = call <9 x double> @llvm.matrix.column.major.load(ptr %in2, i64 3, i1 false, i32 3, i32 3) - ; Give in2 the shape: 3 x 3 - %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) - %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) - - %sum = fadd <9 x double> %phi, %in2tt + %sum = fadd <9 x double> %phi, %in2v %dec = sub i32 %ctr, 1 %cmp = icmp eq i32 %dec, 0 @@ -68,11 +64,11 @@ define void @matrix_phi_loop_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr ; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 8 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 -; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 ; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] ; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] ; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] @@ -94,13 +90,9 @@ loop: %phi = phi <9 x double> [zeroinitializer, %entry], [%sum, %loop] %ctr = phi i32 [%count, %entry], [%dec, %loop] - %in2v = load <9 x double>, ptr %in2 - - ; Give in2 the shape: 3 x 3 - %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) - %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in2, i64 3, i1 false, i32 3, i32 3) - %sum = fadd <9 x double> %phi, %in2tt + %sum = fadd <9 x double> %phi, %inv %dec = sub i32 %ctr, 1 %cmp = icmp eq i32 %dec, 0 @@ -111,32 +103,20 @@ exit: ret void } -define void @matrix_phi_loop_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) { +define void @matrix_phi_loop_undef(ptr %in, i32 %count, ptr %out) { ; CHECK-LABEL: @matrix_phi_loop_undef( ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ undef, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI:%.*]] = phi <9 x double> [ undef, [[ENTRY:%.*]] ], [ [[SUM:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 -; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 -; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 -; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 -; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 -; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] -; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] -; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[INV:%.*]] = load <9 x double>, ptr [[IN:%.*]], align 128 +; CHECK-NEXT: [[SUM]] = fadd <9 x double> [[PHI]], [[INV]] ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 -; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 -; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: store <9 x double> [[SUM]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: ret void ; entry: @@ -146,13 +126,9 @@ loop: %phi = phi <9 x double> [undef, %entry], [%sum, %loop] %ctr = phi i32 [%count, %entry], [%dec, %loop] - %in2v = load <9 x double>, ptr %in2 - - ; Give in2 the shape: 3 x 3 - %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) - %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + %inv = load <9 x double>, ptr %in - %sum = fadd <9 x double> %phi, %in2tt + %sum = fadd <9 x double> %phi, %inv %dec = sub i32 %ctr, 1 %cmp = icmp eq i32 %dec, 0 @@ -163,7 +139,7 @@ exit: ret void } -define void @matrix_phi_loop_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) { +define void @matrix_phi_loop_poison(ptr %in, i32 %count, ptr %out) { ; CHECK-LABEL: @matrix_phi_loop_poison( ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[LOOP:%.*]] @@ -172,11 +148,11 @@ define void @matrix_phi_loop_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 8 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 -; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 ; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] ; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] ; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] @@ -198,13 +174,9 @@ loop: %phi = phi <9 x double> [poison, %entry], [%sum, %loop] %ctr = phi i32 [%count, %entry], [%dec, %loop] - %in2v = load <9 x double>, ptr %in2 + %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 false, i32 3, i32 3) - ; Give in2 the shape: 3 x 3 - %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) - %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) - - %sum = fadd <9 x double> %phi, %in2tt + %sum = fadd <9 x double> %phi, %inv %dec = sub i32 %ctr, 1 %cmp = icmp eq i32 %dec, 0 @@ -215,7 +187,7 @@ exit: ret void } -define void @matrix_phi_loop_cdv(ptr %in1, ptr %in2, i32 %count, ptr %out) { +define void @matrix_phi_loop_cdv(ptr %in, i32 %count, ptr %out) { ; CHECK-LABEL: @matrix_phi_loop_cdv( ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[LOOP:%.*]] @@ -224,11 +196,11 @@ define void @matrix_phi_loop_cdv(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ , [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ , [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 8 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 -; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 ; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] ; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] ; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] @@ -250,13 +222,9 @@ loop: %phi = phi <9 x double> [, %entry], [%sum, %loop] %ctr = phi i32 [%count, %entry], [%dec, %loop] - %in2v = load <9 x double>, ptr %in2 - - ; Give in2 the shape: 3 x 3 - %in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3) - %in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3) + %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 false, i32 3, i32 3) - %sum = fadd <9 x double> %phi, %in2tt + %sum = fadd <9 x double> %phi, %inv %dec = sub i32 %ctr, 1 %cmp = icmp eq i32 %dec, 0 From e56b225095f142a022821559aa80c2da00c10d91 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 2 Jun 2025 10:57:34 -0700 Subject: [PATCH 08/26] add tests for phi's consuming phi's, and phi's with more than two inputs --- .../Transforms/LowerMatrixIntrinsics/phi.ll | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index 510cd670d8de8..eccf8ec90760a 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -235,6 +235,126 @@ exit: ret void } +define void @matrix_phi_loop_delay(ptr %in1, ptr %in2, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop_delay( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN1]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI14:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI15:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI16:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP0]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY]] ], [ [[TMP3:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP1]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP4:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP2]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP5:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3 +; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6 +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: [[TMP3]] = fadd <3 x double> [[PHI14]], [[COL_LOAD4]] +; CHECK-NEXT: [[TMP4]] = fadd <3 x double> [[PHI15]], [[COL_LOAD6]] +; CHECK-NEXT: [[TMP5]] = fadd <3 x double> [[PHI16]], [[COL_LOAD8]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP3]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP4]], ptr [[VEC_GEP12]], align 8 +; CHECK-NEXT: [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[TMP5]], ptr [[VEC_GEP13]], align 16 +; CHECK-NEXT: ret void +; +entry: + %in1v = call <9 x double> @llvm.matrix.column.major.load(ptr %in1, i64 3, i1 false, i32 3, i32 3) + br label %loop + +loop: + %phi2 = phi <9 x double> [%in1v, %entry], [%phi, %loop] + %phi = phi <9 x double> [%in1v, %entry], [%sum, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %in2v = call <9 x double> @llvm.matrix.column.major.load(ptr %in2, i64 3, i1 false, i32 3, i32 3) + + %sum = fadd <9 x double> %phi2, %in2v + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <9 x double> %sum, ptr %out + ret void +} + +define void @matrix_phi_three_preds(i1 %cond1, i1 %cond2, ptr %a, ptr %b, ptr %c, ptr %out) { +; CHECK-LABEL: @matrix_phi_three_preds( +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 [[COND1:%.*]], label [[BB1:%.*]], label [[BBA:%.*]] +; CHECK: bb1: +; CHECK-NEXT: br i1 [[COND2:%.*]], label [[BBB:%.*]], label [[BBC:%.*]] +; CHECK: bba: +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[A:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[A]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[A]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: br label [[EXIT:%.*]] +; CHECK: bbb: +; CHECK-NEXT: [[COL_LOAD9:%.*]] = load <3 x double>, ptr [[B:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP10:%.*]] = getelementptr double, ptr [[B]], i64 3 +; CHECK-NEXT: [[COL_LOAD11:%.*]] = load <3 x double>, ptr [[VEC_GEP10]], align 8 +; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[B]], i64 6 +; CHECK-NEXT: [[COL_LOAD13:%.*]] = load <3 x double>, ptr [[VEC_GEP12]], align 8 +; CHECK-NEXT: br label [[EXIT]] +; CHECK: bbc: +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[C:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[C]], i64 3 +; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[C]], i64 6 +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: [[PHI14:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[BBA]] ], [ [[COL_LOAD9]], [[BBB]] ], [ [[COL_LOAD4]], [[BBC]] ] +; CHECK-NEXT: [[PHI15:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[BBA]] ], [ [[COL_LOAD11]], [[BBB]] ], [ [[COL_LOAD6]], [[BBC]] ] +; CHECK-NEXT: [[PHI16:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[BBA]] ], [ [[COL_LOAD13]], [[BBB]] ], [ [[COL_LOAD8]], [[BBC]] ] +; CHECK-NEXT: store <3 x double> [[PHI14]], ptr [[OUT:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP17:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[PHI15]], ptr [[VEC_GEP17]], align 8 +; CHECK-NEXT: [[VEC_GEP18:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[PHI16]], ptr [[VEC_GEP18]], align 8 +; CHECK-NEXT: ret void +; +entry: + br i1 %cond1, label %bb1, label %bba + +bb1: + br i1 %cond2, label %bbb, label %bbc + +bba: + %va = call <9 x double> @llvm.matrix.column.major.load(ptr %a, i64 3, i1 false, i32 3, i32 3) + br label %exit + +bbb: + %vb = call <9 x double> @llvm.matrix.column.major.load(ptr %b, i64 3, i1 false, i32 3, i32 3) + br label %exit + +bbc: + %vc = call <9 x double> @llvm.matrix.column.major.load(ptr %c, i64 3, i1 false, i32 3, i32 3) + br label %exit + +exit: + %phi = phi <9 x double> [%va, %bba], [%vb, %bbb], [%vc, %bbc] + call void @llvm.matrix.column.major.store(<9 x double> %phi, ptr %out, i64 3, i1 false, i32 3, i32 3) + ret void +} + define <9 x double> @matrix_phi_ifthenelse(i1 %cond, <9 x double> %A, <9 x double> %B, <9 x double> %C) { ; CHECK-LABEL: @matrix_phi_ifthenelse( ; CHECK-NEXT: entry: From ffbc73f7b8c6dbd87df2e0cee013c19b201cf8a5 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 2 Jun 2025 11:10:11 -0700 Subject: [PATCH 09/26] handle phi's more like other ops. instcombine will clean up after us --- .../Scalar/LowerMatrixIntrinsics.cpp | 65 +++++++------- .../Transforms/LowerMatrixIntrinsics/phi.ll | 90 ++++++++++++------- 2 files changed, 95 insertions(+), 60 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index a439ffc071f86..8333fde287f67 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -590,23 +590,6 @@ class LowerMatrixIntrinsics { MatrixVal = M.embedInVector(Builder); } - // If it's a PHI, split it now. We'll take care of fixing up the operands - // later once we're in VisitPHI. - if (auto *PHI = dyn_cast(MatrixVal)) { - auto *EltTy = cast(PHI->getType())->getElementType(); - MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy}; - - IRBuilder<>::InsertPointGuard IPG(Builder); - Builder.SetInsertPoint(PHI); - for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) - PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), - PHI->getNumIncomingValues(), - PHI->getName())); - - Inst2ColumnMatrix[PHI] = PhiM; - return PhiM; - } - // If it's a constant, materialize the split version of it with this shape. if (auto *IncomingConst = dyn_cast(MatrixVal)) return MatrixTy(IncomingConst, SI); @@ -1122,12 +1105,9 @@ class LowerMatrixIntrinsics { Changed |= VisitLoad(cast(Inst), Op1, Builder); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) Changed |= VisitStore(cast(Inst), Op1, Op2, Builder); - } - - // Fifth, lower all the PHI's with shape information. - for (Instruction *Inst : MatrixInsts) - if (auto *PHI = dyn_cast(Inst)) + else if (auto *PHI = dyn_cast(Inst)) Changed |= VisitPHI(PHI); + } if (ORE) { RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); @@ -2191,22 +2171,47 @@ class LowerMatrixIntrinsics { if (I == ShapeMap.end()) return false; + const ShapeInfo &SI = I->second; + IRBuilder<> Builder(Inst); - MatrixTy PhiM = getMatrix(Inst, I->second, Builder); + auto getMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy { + auto *I = Inst2ColumnMatrix.find(MatrixVal); + if (I != Inst2ColumnMatrix.end()) + return I->second; - for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues(); - IncomingI != IncomingE; ++IncomingI) { - Value *IncomingV = Inst->getIncomingValue(IncomingI); - BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI); + if (auto *PHI = dyn_cast(MatrixVal)) { + auto *EltTy = cast(PHI->getType())->getElementType(); + MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy}; + + IRBuilder<>::InsertPointGuard IPG(Builder); + Builder.SetInsertPoint(PHI); + for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) + PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), + PHI->getNumIncomingValues(), + PHI->getName())); + + Inst2ColumnMatrix[PHI] = PhiM; + return PhiM; + } // getMatrix() may insert some instructions. The safe place to insert them // is at the end of the parent block, where the register allocator would // have inserted the copies that materialize the PHI. - if (auto *IncomingInst = dyn_cast(IncomingV)) - Builder.SetInsertPoint(IncomingInst->getParent()->getTerminator()); + if (auto *MatrixInst = dyn_cast(MatrixVal)) + Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator()); + + return this->getMatrix(MatrixVal, SI, Builder); + }; + + MatrixTy PhiM = getMatrix(Inst); + + for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues(); + IncomingI != IncomingE; ++IncomingI) { + Value *IncomingV = Inst->getIncomingValue(IncomingI); + BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI); - MatrixTy OpM = getMatrix(IncomingV, I->second, Builder); + MatrixTy OpM = getMatrix(IncomingV); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { PHINode *NewPHI = cast(PhiM.getVector(VI)); diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index eccf8ec90760a..e0abbadd6101a 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -20,18 +20,24 @@ define void @matrix_phi_loop(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8 ; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6 ; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 8 -; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]] -; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]] -; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]] +; CHECK-NEXT: [[TMP7:%.*]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]] +; CHECK-NEXT: [[TMP8:%.*]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]] +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP12]], align 8 +; CHECK-NEXT: store <3 x double> [[TMP7]], ptr [[VEC_GEP12]], align 8 ; CHECK-NEXT: [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP13]], align 16 +; CHECK-NEXT: store <3 x double> [[TMP8]], ptr [[VEC_GEP13]], align 16 ; CHECK-NEXT: ret void ; entry: @@ -69,18 +75,24 @@ define void @matrix_phi_loop_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 ; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 -; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] -; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] -; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP7:%.*]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP8:%.*]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: store <3 x double> [[TMP7]], ptr [[VEC_GEP7]], align 8 ; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: store <3 x double> [[TMP8]], ptr [[VEC_GEP8]], align 16 ; CHECK-NEXT: ret void ; entry: @@ -153,18 +165,24 @@ define void @matrix_phi_loop_poison(ptr %in, i32 %count, ptr %out) { ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 ; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 -; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] -; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] -; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP7:%.*]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP8:%.*]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: store <3 x double> [[TMP7]], ptr [[VEC_GEP7]], align 8 ; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: store <3 x double> [[TMP8]], ptr [[VEC_GEP8]], align 16 ; CHECK-NEXT: ret void ; entry: @@ -201,18 +219,24 @@ define void @matrix_phi_loop_cdv(ptr %in, i32 %count, ptr %out) { ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6 ; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 -; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] -; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] -; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]] +; CHECK-NEXT: [[TMP7:%.*]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]] +; CHECK-NEXT: [[TMP8:%.*]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]] +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: store <3 x double> [[TMP7]], ptr [[VEC_GEP7]], align 8 ; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16 +; CHECK-NEXT: store <3 x double> [[TMP8]], ptr [[VEC_GEP8]], align 16 ; CHECK-NEXT: ret void ; entry: @@ -257,18 +281,24 @@ define void @matrix_phi_loop_delay(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8 ; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6 ; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 8 -; CHECK-NEXT: [[TMP3]] = fadd <3 x double> [[PHI14]], [[COL_LOAD4]] -; CHECK-NEXT: [[TMP4]] = fadd <3 x double> [[PHI15]], [[COL_LOAD6]] -; CHECK-NEXT: [[TMP5]] = fadd <3 x double> [[PHI16]], [[COL_LOAD8]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd <3 x double> [[PHI14]], [[COL_LOAD4]] +; CHECK-NEXT: [[TMP7:%.*]] = fadd <3 x double> [[PHI15]], [[COL_LOAD6]] +; CHECK-NEXT: [[TMP8:%.*]] = fadd <3 x double> [[PHI16]], [[COL_LOAD8]] +; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <6 x double> [[TMP9]], <6 x double> [[TMP10]], <9 x i32> ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: [[TMP3]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP4]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP5]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP3]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP4]], ptr [[VEC_GEP12]], align 8 +; CHECK-NEXT: store <3 x double> [[TMP7]], ptr [[VEC_GEP12]], align 8 ; CHECK-NEXT: [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP5]], ptr [[VEC_GEP13]], align 16 +; CHECK-NEXT: store <3 x double> [[TMP8]], ptr [[VEC_GEP13]], align 16 ; CHECK-NEXT: ret void ; entry: From 15fd60b064a7b948952ca49395ffb2d31b96c98d Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 2 Jun 2025 11:20:42 -0700 Subject: [PATCH 10/26] handle phi's with shape mismatch --- .../Scalar/LowerMatrixIntrinsics.cpp | 26 ++++++---- .../Transforms/LowerMatrixIntrinsics/phi.ll | 47 +++++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 8333fde287f67..eda5e53782e1b 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2176,9 +2176,23 @@ class LowerMatrixIntrinsics { IRBuilder<> Builder(Inst); auto getMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy { - auto *I = Inst2ColumnMatrix.find(MatrixVal); - if (I != Inst2ColumnMatrix.end()) - return I->second; + // getMatrix() and embedInVector() may insert some instructions. The safe + // place to insert them is at the end of the parent block, where the + // register allocator would have inserted the copies that materialize the + // PHI. + if (auto *MatrixInst = dyn_cast(MatrixVal)) + Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator()); + + auto Found = Inst2ColumnMatrix.find(MatrixVal); + if (Found != Inst2ColumnMatrix.end()) { + MatrixTy &M = Found->second; + // Return the found matrix, if its shape matches the requested shape + // information + if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) + return M; + + MatrixVal = M.embedInVector(Builder); + } if (auto *PHI = dyn_cast(MatrixVal)) { auto *EltTy = cast(PHI->getType())->getElementType(); @@ -2195,12 +2209,6 @@ class LowerMatrixIntrinsics { return PhiM; } - // getMatrix() may insert some instructions. The safe place to insert them - // is at the end of the parent block, where the register allocator would - // have inserted the copies that materialize the PHI. - if (auto *MatrixInst = dyn_cast(MatrixVal)) - Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator()); - return this->getMatrix(MatrixVal, SI, Builder); }; diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index e0abbadd6101a..c53735a850e24 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -385,6 +385,53 @@ exit: ret void } +define void @matrix_phi_two_preds_shape_mismatch(i1 %cond1, ptr %a, ptr %b, ptr %out) { +; CHECK-LABEL: @matrix_phi_two_preds_shape_mismatch( +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 [[COND1:%.*]], label [[BBA:%.*]], label [[BBB:%.*]] +; CHECK: bba: +; CHECK-NEXT: [[COL_LOAD16:%.*]] = load <3 x double>, ptr [[A:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP17:%.*]] = getelementptr double, ptr [[A]], i64 3 +; CHECK-NEXT: [[COL_LOAD18:%.*]] = load <3 x double>, ptr [[VEC_GEP17]], align 8 +; CHECK-NEXT: [[VEC_GEP19:%.*]] = getelementptr double, ptr [[A]], i64 6 +; CHECK-NEXT: [[COL_LOAD20:%.*]] = load <3 x double>, ptr [[VEC_GEP19]], align 8 +; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD16]], <3 x double> [[COL_LOAD18]], <6 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD20]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <9 x double> [[TMP2]], <9 x double> poison, <9 x i32> +; CHECK-NEXT: br label [[EXIT:%.*]] +; CHECK: bbb: +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <9 x double>, ptr [[B:%.*]], align 8 +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: [[TMP11:%.*]] = phi <9 x double> [ [[SPLIT]], [[BBA]] ], [ [[COL_LOAD]], [[BBB]] ] +; CHECK-NEXT: [[SPLIT38:%.*]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT39:%.*]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT40:%.*]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: store <3 x double> [[SPLIT38]], ptr [[OUT:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP41:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[SPLIT39]], ptr [[VEC_GEP41]], align 8 +; CHECK-NEXT: [[VEC_GEP42:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[SPLIT40]], ptr [[VEC_GEP42]], align 8 +; CHECK-NEXT: ret void +; +entry: + br i1 %cond1, label %bba, label %bbb + +bba: + %va = call <9 x double> @llvm.matrix.column.major.load(ptr %a, i64 3, i1 false, i32 3, i32 3) + br label %exit + +bbb: + %vb = call <9 x double> @llvm.matrix.column.major.load(ptr %b, i64 9, i1 false, i32 9, i32 1) + br label %exit + +exit: + %phi = phi <9 x double> [%va, %bba], [%vb, %bbb] + call void @llvm.matrix.column.major.store(<9 x double> %phi, ptr %out, i64 3, i1 false, i32 3, i32 3) + ret void +} + define <9 x double> @matrix_phi_ifthenelse(i1 %cond, <9 x double> %A, <9 x double> %B, <9 x double> %C) { ; CHECK-LABEL: @matrix_phi_ifthenelse( ; CHECK-NEXT: entry: From 655eb8879334f8d8fa65498724be82f95cef0b3c Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 2 Jun 2025 11:41:02 -0700 Subject: [PATCH 11/26] simplify getMatrix shim --- .../Scalar/LowerMatrixIntrinsics.cpp | 52 ++++++++----------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index a697672989b3a..00ba8651ec077 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2164,47 +2164,39 @@ class LowerMatrixIntrinsics { bool VisitPHI(PHINode *Inst) { auto I = ShapeMap.find(Inst); - if (I == ShapeMap.end()) - return false; + assert(I != ShapeMap.end() && + "must only visit instructions with shape info"); const ShapeInfo &SI = I->second; IRBuilder<> Builder(Inst); + // Shim this->getMatrix to insert split phi's as needed. auto getMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy { - // getMatrix() and embedInVector() may insert some instructions. The safe - // place to insert them is at the end of the parent block, where the - // register allocator would have inserted the copies that materialize the - // PHI. - if (auto *MatrixInst = dyn_cast(MatrixVal)) - Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator()); - - auto Found = Inst2ColumnMatrix.find(MatrixVal); - if (Found != Inst2ColumnMatrix.end()) { - MatrixTy &M = Found->second; - // Return the found matrix, if its shape matches the requested shape - // information - if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) - return M; + IRBuilder<>::InsertPointGuard IPG(Builder); - MatrixVal = M.embedInVector(Builder); - } - - if (auto *PHI = dyn_cast(MatrixVal)) { - auto *EltTy = cast(PHI->getType())->getElementType(); - MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy}; + auto I = Inst2ColumnMatrix.find(MatrixVal); + if (I == Inst2ColumnMatrix.end()) { + if (auto *PHI = dyn_cast(MatrixVal)) { + auto *EltTy = cast(PHI->getType())->getElementType(); + MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy}; - IRBuilder<>::InsertPointGuard IPG(Builder); - Builder.SetInsertPoint(PHI); - for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) - PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), - PHI->getNumIncomingValues(), - PHI->getName())); + Builder.SetInsertPoint(PHI); + for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) + PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), + PHI->getNumIncomingValues(), + PHI->getName())); - Inst2ColumnMatrix[PHI] = PhiM; - return PhiM; + Inst2ColumnMatrix[PHI] = PhiM; + } } + // getMatrix() may insert some instructions for reshaping. The safe place + // to insert them is at the end of the parent block, where the register + // allocator would have inserted the copies that materialize the PHI. + if (auto *MatrixInst = dyn_cast(MatrixVal)) + Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator()); + return this->getMatrix(MatrixVal, SI, Builder); }; From e262f767a19d21b4dc3fa28137ed3ac3c19658c4 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 2 Jun 2025 11:44:50 -0700 Subject: [PATCH 12/26] test the other order of shape mismatch --- .../Transforms/LowerMatrixIntrinsics/phi.ll | 76 ++++++++++++++----- 1 file changed, 57 insertions(+), 19 deletions(-) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index c53735a850e24..738325a0f438d 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -385,18 +385,18 @@ exit: ret void } -define void @matrix_phi_two_preds_shape_mismatch(i1 %cond1, ptr %a, ptr %b, ptr %out) { -; CHECK-LABEL: @matrix_phi_two_preds_shape_mismatch( +define void @matrix_phi_two_preds_shape_mismatch1(i1 %cond1, ptr %a, ptr %b, ptr %out) { +; CHECK-LABEL: @matrix_phi_two_preds_shape_mismatch1( ; CHECK-NEXT: entry: ; CHECK-NEXT: br i1 [[COND1:%.*]], label [[BBA:%.*]], label [[BBB:%.*]] ; CHECK: bba: -; CHECK-NEXT: [[COL_LOAD16:%.*]] = load <3 x double>, ptr [[A:%.*]], align 8 -; CHECK-NEXT: [[VEC_GEP17:%.*]] = getelementptr double, ptr [[A]], i64 3 -; CHECK-NEXT: [[COL_LOAD18:%.*]] = load <3 x double>, ptr [[VEC_GEP17]], align 8 -; CHECK-NEXT: [[VEC_GEP19:%.*]] = getelementptr double, ptr [[A]], i64 6 -; CHECK-NEXT: [[COL_LOAD20:%.*]] = load <3 x double>, ptr [[VEC_GEP19]], align 8 -; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD16]], <3 x double> [[COL_LOAD18]], <6 x i32> -; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD20]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[A:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[A]], i64 3 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, ptr [[A]], i64 6 +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[VEC_GEP3]], align 8 +; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD1]], <3 x double> [[COL_LOAD2]], <6 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD4]], <3 x double> poison, <6 x i32> ; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> ; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <9 x double> [[TMP2]], <9 x double> poison, <9 x i32> ; CHECK-NEXT: br label [[EXIT:%.*]] @@ -404,15 +404,8 @@ define void @matrix_phi_two_preds_shape_mismatch(i1 %cond1, ptr %a, ptr %b, ptr ; CHECK-NEXT: [[COL_LOAD:%.*]] = load <9 x double>, ptr [[B:%.*]], align 8 ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: -; CHECK-NEXT: [[TMP11:%.*]] = phi <9 x double> [ [[SPLIT]], [[BBA]] ], [ [[COL_LOAD]], [[BBB]] ] -; CHECK-NEXT: [[SPLIT38:%.*]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT39:%.*]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[SPLIT40:%.*]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: store <3 x double> [[SPLIT38]], ptr [[OUT:%.*]], align 8 -; CHECK-NEXT: [[VEC_GEP41:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[SPLIT39]], ptr [[VEC_GEP41]], align 8 -; CHECK-NEXT: [[VEC_GEP42:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[SPLIT40]], ptr [[VEC_GEP42]], align 8 +; CHECK-NEXT: [[PHI5:%.*]] = phi <9 x double> [ [[SPLIT]], [[BBA]] ], [ [[COL_LOAD]], [[BBB]] ] +; CHECK-NEXT: store <9 x double> [[PHI5]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: ret void ; entry: @@ -428,7 +421,52 @@ bbb: exit: %phi = phi <9 x double> [%va, %bba], [%vb, %bbb] - call void @llvm.matrix.column.major.store(<9 x double> %phi, ptr %out, i64 3, i1 false, i32 3, i32 3) + store <9 x double> %phi, ptr %out + ret void +} + +define void @matrix_phi_two_preds_shape_mismatch2(i1 %cond1, ptr %a, ptr %b, ptr %out) { +; CHECK-LABEL: @matrix_phi_two_preds_shape_mismatch2( +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 [[COND1:%.*]], label [[BBA:%.*]], label [[BBB:%.*]] +; CHECK: bba: +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <9 x double>, ptr [[A:%.*]], align 8 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <9 x double> [[COL_LOAD4]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <9 x double> [[COL_LOAD4]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <9 x double> [[COL_LOAD4]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: br label [[EXIT:%.*]] +; CHECK: bbb: +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[B:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[B]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[B]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ [[SPLIT]], [[BBA]] ], [ [[COL_LOAD]], [[BBB]] ] +; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ [[SPLIT8]], [[BBA]] ], [ [[COL_LOAD1]], [[BBB]] ] +; CHECK-NEXT: [[PHI7:%.*]] = phi <3 x double> [ [[SPLIT9]], [[BBA]] ], [ [[COL_LOAD3]], [[BBB]] ] +; CHECK-NEXT: store <3 x double> [[PHI5]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: [[VEC_GEP10:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[PHI6]], ptr [[VEC_GEP10]], align 8 +; CHECK-NEXT: [[VEC_GEP11:%.*]] = getelementptr double, ptr [[OUT]], i64 6 +; CHECK-NEXT: store <3 x double> [[PHI7]], ptr [[VEC_GEP11]], align 16 +; CHECK-NEXT: ret void +; +entry: + br i1 %cond1, label %bba, label %bbb + +bba: + %va = call <9 x double> @llvm.matrix.column.major.load(ptr %a, i64 9, i1 false, i32 9, i32 1) + br label %exit + +bbb: + %vb = call <9 x double> @llvm.matrix.column.major.load(ptr %b, i64 3, i1 false, i32 3, i32 3) + br label %exit + +exit: + %phi = phi <9 x double> [%va, %bba], [%vb, %bbb] + store <9 x double> %phi, ptr %out ret void } From 2c86c2fb263dfe60ab270877b185773bf10a9561 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 9 Jun 2025 15:32:58 -0700 Subject: [PATCH 13/26] clang-format --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 00ba8651ec077..ec9ddb6e2b6fd 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2184,8 +2184,8 @@ class LowerMatrixIntrinsics { Builder.SetInsertPoint(PHI); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), - PHI->getNumIncomingValues(), - PHI->getName())); + PHI->getNumIncomingValues(), + PHI->getName())); Inst2ColumnMatrix[PHI] = PhiM; } From 2e5b2d406eef57683f6019feee862f2f48b5522c Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 10 Jun 2025 18:24:51 -0700 Subject: [PATCH 14/26] clang-format --- .../Scalar/LowerMatrixIntrinsics.cpp | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index bf5bf328908e7..3caa28d855cce 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1230,7 +1230,8 @@ class LowerMatrixIntrinsics { } /// Replace intrinsic calls. - MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { + MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI, + IRBuilder<> &Builder) { assert(Inst->getCalledFunction() && Inst->getCalledFunction()->isIntrinsic()); @@ -1338,7 +1339,8 @@ class LowerMatrixIntrinsics { /// Lower a load instruction with shape information. MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, - Value *Stride, bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { + Value *Stride, bool IsVolatile, ShapeInfo Shape, + IRBuilder<> &Builder) { return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape, Builder); } @@ -1415,7 +1417,8 @@ class LowerMatrixIntrinsics { Value *Stride = Inst->getArgOperand(2); return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, cast(Inst->getArgOperand(3))->isOne(), - {Inst->getArgOperand(4), Inst->getArgOperand(5)}, Builder); + {Inst->getArgOperand(4), Inst->getArgOperand(5)}, + Builder); } // Set elements I..I+NumElts-1 to Block @@ -2250,15 +2253,18 @@ class LowerMatrixIntrinsics { } /// Lower load instructions. - MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) { + MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, + IRBuilder<> &Builder) { return LowerLoad(Inst, Ptr, Inst->getAlign(), - Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, Builder); + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, + Builder); } MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, Value *Ptr, IRBuilder<> &Builder) { return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), - Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, Builder); + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, + Builder); } MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { @@ -2309,13 +2315,15 @@ class LowerMatrixIntrinsics { // finalizeLowering() may also insert instructions in some cases. The safe // place for those is at the end of the initial block of PHIs. auto IP = Inst->getInsertionPointAfterDef(); - assert(IP.has_value() && "expected to find a valid insertion point after the phi"); + assert(IP.has_value() && + "expected to find a valid insertion point after the phi"); Builder.SetInsertPoint(*IP); return PhiM; } /// Lower binary operators. - MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { + MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI, + IRBuilder<> &Builder) { Value *Lhs = Inst->getOperand(0); Value *Rhs = Inst->getOperand(1); @@ -2337,7 +2345,8 @@ class LowerMatrixIntrinsics { } /// Lower unary operators. - MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { + MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI, + IRBuilder<> &Builder) { Value *Op = Inst->getOperand(0); MatrixTy Result; @@ -2363,7 +2372,8 @@ class LowerMatrixIntrinsics { } /// Lower cast instructions. - MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape, IRBuilder<> &Builder) { + MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape, + IRBuilder<> &Builder) { Value *Op = Inst->getOperand(0); MatrixTy Result; From 4ba4e6643d411de590261e163b4dfd3b637ff1fc Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Thu, 12 Jun 2025 09:19:58 -0700 Subject: [PATCH 15/26] [Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch --- .../Scalar/LowerMatrixIntrinsics.cpp | 9 ++- .../LowerMatrixIntrinsics/select.ll | 61 +++++++++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index bb1216125b875..be192df8a2392 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2410,14 +2410,13 @@ class LowerMatrixIntrinsics { MatrixTy A = getMatrix(OpA, Shape, Builder); MatrixTy B = getMatrix(OpB, Shape, Builder); - Value *CondV[2]; + SmallVector CondV; if (isa(Cond->getType())) { MatrixTy C = getMatrix(Cond, Shape, Builder); - CondV[0] = C.getVector(0); - CondV[1] = C.getVector(1); + llvm::copy(C.vectors(), std::back_inserter(CondV)); } else { - CondV[0] = Cond; - CondV[1] = Cond; + CondV.resize(A.getNumVectors()); + std::fill(CondV.begin(), CondV.end(), Cond); } for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors())) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll index 70b0dfdb3e7e8..bd97915759aac 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll @@ -144,3 +144,64 @@ define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { store <4 x float> %op, ptr %out ret void } + +define void @select_2x2_vcond_shape4(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_vcond_shape4( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <4 x float>, ptr [[LHS:%.*]], align 16 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x float>, ptr [[RHS:%.*]], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[COL_LOAD1]], <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD2]] +; CHECK-NEXT: store <4 x float> [[TMP1]], ptr [[OUT:%.*]], align 16 +; CHECK-NEXT: ret void +; + %lhsv = load <4 x float>, ptr %lhs + %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1) + %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 4, i1 false, i32 4, i32 1) + %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv + store <4 x float> %op, ptr %out + ret void +} + +define void @select_2x2_vcond_shape5(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_vcond_shape5( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <1 x float>, ptr [[LHS:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 1 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <1 x float>, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[LHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <1 x float>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[LHS]], i64 3 +; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <1 x float>, ptr [[VEC_GEP4]], align 4 +; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <1 x i1>, ptr [[COND:%.*]], align 1 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr i1, ptr [[COND]], i64 1 +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <1 x i1>, ptr [[VEC_GEP7]], align 1 +; CHECK-NEXT: [[VEC_GEP9:%.*]] = getelementptr i1, ptr [[COND]], i64 2 +; CHECK-NEXT: [[COL_LOAD10:%.*]] = load <1 x i1>, ptr [[VEC_GEP9]], align 1 +; CHECK-NEXT: [[VEC_GEP11:%.*]] = getelementptr i1, ptr [[COND]], i64 3 +; CHECK-NEXT: [[COL_LOAD12:%.*]] = load <1 x i1>, ptr [[VEC_GEP11]], align 1 +; CHECK-NEXT: [[COL_LOAD13:%.*]] = load <1 x float>, ptr [[RHS:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP14:%.*]] = getelementptr float, ptr [[RHS]], i64 1 +; CHECK-NEXT: [[COL_LOAD15:%.*]] = load <1 x float>, ptr [[VEC_GEP14]], align 4 +; CHECK-NEXT: [[VEC_GEP16:%.*]] = getelementptr float, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD17:%.*]] = load <1 x float>, ptr [[VEC_GEP16]], align 4 +; CHECK-NEXT: [[VEC_GEP18:%.*]] = getelementptr float, ptr [[RHS]], i64 3 +; CHECK-NEXT: [[COL_LOAD19:%.*]] = load <1 x float>, ptr [[VEC_GEP18]], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = select <1 x i1> [[COL_LOAD6]], <1 x float> [[COL_LOAD]], <1 x float> [[COL_LOAD13]] +; CHECK-NEXT: [[TMP2:%.*]] = select <1 x i1> [[COL_LOAD8]], <1 x float> [[COL_LOAD1]], <1 x float> [[COL_LOAD15]] +; CHECK-NEXT: [[TMP3:%.*]] = select <1 x i1> [[COL_LOAD10]], <1 x float> [[COL_LOAD3]], <1 x float> [[COL_LOAD17]] +; CHECK-NEXT: [[TMP4:%.*]] = select <1 x i1> [[COL_LOAD12]], <1 x float> [[COL_LOAD5]], <1 x float> [[COL_LOAD19]] +; CHECK-NEXT: store <1 x float> [[TMP1]], ptr [[OUT:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP20:%.*]] = getelementptr float, ptr [[OUT]], i64 1 +; CHECK-NEXT: store <1 x float> [[TMP2]], ptr [[VEC_GEP20]], align 4 +; CHECK-NEXT: [[VEC_GEP21:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <1 x float> [[TMP3]], ptr [[VEC_GEP21]], align 8 +; CHECK-NEXT: [[VEC_GEP22:%.*]] = getelementptr float, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <1 x float> [[TMP4]], ptr [[VEC_GEP22]], align 4 +; CHECK-NEXT: ret void +; + %lhsv = load <4 x float>, ptr %lhs + %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 1, i1 false, i32 1, i32 4) + %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 1, i1 false, i32 1, i32 4) + %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv + store <4 x float> %op, ptr %out + ret void +} From 4cbc8397b00b52f69d72ee46d4ea36bd86c1fe89 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 16 Jun 2025 08:23:23 -0700 Subject: [PATCH 16/26] review feedback: parens for initializer --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index be192df8a2392..8ac8cef338fbd 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2285,7 +2285,7 @@ class LowerMatrixIntrinsics { if (I == Inst2ColumnMatrix.end()) { if (auto *PHI = dyn_cast(MatrixVal)) { auto *EltTy = cast(PHI->getType())->getElementType(); - MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy}; + MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy); Builder.SetInsertPoint(PHI); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) From 6f8ec49b73d0e391006c1d12ddebea43f280181f Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 16 Jun 2025 08:23:51 -0700 Subject: [PATCH 17/26] review feedback: rename to GetMatrix --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 8ac8cef338fbd..d3b27f8b81fca 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2278,7 +2278,7 @@ class LowerMatrixIntrinsics { MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { // Shim this->getMatrix to insert split phi's as needed. - auto getMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy { + auto GetMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy { IRBuilder<>::InsertPointGuard IPG(Builder); auto I = Inst2ColumnMatrix.find(MatrixVal); @@ -2306,14 +2306,14 @@ class LowerMatrixIntrinsics { return this->getMatrix(MatrixVal, SI, Builder); }; - MatrixTy PhiM = getMatrix(Inst); + MatrixTy PhiM = GetMatrix(Inst); for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues(); IncomingI != IncomingE; ++IncomingI) { Value *IncomingV = Inst->getIncomingValue(IncomingI); BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI); - MatrixTy OpM = getMatrix(IncomingV); + MatrixTy OpM = GetMatrix(IncomingV); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { PHINode *NewPHI = cast(PhiM.getVector(VI)); From 67ead3776dd02730918cc27c2473f583c9f09de5 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 16 Jun 2025 13:27:02 -0700 Subject: [PATCH 18/26] drop code for splitting constants, add test for it --- .../Scalar/LowerMatrixIntrinsics.cpp | 31 --------- .../LowerMatrixIntrinsics/constant.ll | 68 +++++++++++++++++++ 2 files changed, 68 insertions(+), 31 deletions(-) create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index d3b27f8b81fca..a4244e854f7a2 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -416,33 +416,6 @@ class LowerMatrixIntrinsics { addVector(PoisonValue::get(FixedVectorType::get( EltTy, isColumnMajor() ? NumRows : NumColumns))); } - MatrixTy(ConstantData *Constant, const ShapeInfo &SI) - : IsColumnMajor(SI.IsColumnMajor) { - Type *EltTy = cast(Constant->getType())->getElementType(); - Type *RowTy = VectorType::get(EltTy, ElementCount::getFixed(SI.NumRows)); - - for (unsigned J = 0, D = SI.getNumVectors(); J < D; ++J) { - if (auto *CDV = dyn_cast(Constant)) { - unsigned Width = SI.getStride(); - size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8; - StringRef Data = CDV->getRawDataValues().substr(J * Width * EltSize, - Width * EltSize); - addVector( - ConstantDataVector::getRaw(Data, Width, CDV->getElementType())); - } else if (isa(Constant)) - addVector(PoisonValue::get(RowTy)); - else if (isa(Constant)) - addVector(UndefValue::get(RowTy)); - else if (isa(Constant)) - addVector(ConstantAggregateZero::get(RowTy)); - else { -#ifndef NDEBUG - Constant->dump(); -#endif - report_fatal_error("unhandled ConstantData type"); - } - } - } Value *getVector(unsigned i) const { return Vectors[i]; } Value *getColumn(unsigned i) const { @@ -647,10 +620,6 @@ class LowerMatrixIntrinsics { MatrixVal = M.embedInVector(Builder); } - // If it's a constant, materialize the split version of it with this shape. - if (auto *IncomingConst = dyn_cast(MatrixVal)) - return MatrixTy(IncomingConst, SI); - // Otherwise split MatrixVal. SmallVector SplitVecs; for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll new file mode 100644 index 0000000000000..32a4d191897b1 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll @@ -0,0 +1,68 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define void @ramp_2x2(ptr %out) { +; CHECK-LABEL: @ramp_2x2( +; CHECK-NEXT: store <2 x i32> , ptr [[OUT:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x i32> , ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: ret void +; + call void @llvm.matrix.column.major.store(<4 x i32> , ptr %out, i64 2, i1 false, i32 2, i32 2) + ret void +} + +define void @poison_2x2(ptr %out) { +; CHECK-LABEL: @poison_2x2( +; CHECK-NEXT: store <2 x i32> poison, ptr [[OUT:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x i32> poison, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: ret void +; + call void @llvm.matrix.column.major.store(<4 x i32> poison, ptr %out, i64 2, i1 false, i32 2, i32 2) + ret void +} + +define void @undef_2x2(ptr %out) { +; CHECK-LABEL: @undef_2x2( +; CHECK-NEXT: store <2 x i32> undef, ptr [[OUT:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x i32> undef, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: ret void +; + call void @llvm.matrix.column.major.store(<4 x i32> undef, ptr %out, i64 2, i1 false, i32 2, i32 2) + ret void +} + +define void @zeroinitializer_2x2(ptr %out) { +; CHECK-LABEL: @zeroinitializer_2x2( +; CHECK-NEXT: store <2 x i32> zeroinitializer, ptr [[OUT:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x i32> zeroinitializer, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: ret void +; + call void @llvm.matrix.column.major.store(<4 x i32> zeroinitializer, ptr %out, i64 2, i1 false, i32 2, i32 2) + ret void +} + +define void @ramp_bitcast(ptr %out) { +; CHECK-LABEL: @ramp_bitcast( +; CHECK-NEXT: store <2 x float> , ptr [[OUT:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x float> , ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: store <4 x float> , ptr [[OUT]], align 4 +; CHECK-NEXT: store <1 x float> splat (float 0x36A0000000000000), ptr [[OUT]], align 4 +; CHECK-NEXT: [[VEC_GEP1:%.*]] = getelementptr float, ptr [[OUT]], i64 4 +; CHECK-NEXT: store <1 x float> splat (float 0x36B0000000000000), ptr [[VEC_GEP1]], align 4 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 8 +; CHECK-NEXT: store <1 x float> splat (float 0x36B8000000000000), ptr [[VEC_GEP2]], align 4 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[OUT]], i64 12 +; CHECK-NEXT: store <1 x float> splat (float 0x36C0000000000000), ptr [[VEC_GEP3]], align 4 +; CHECK-NEXT: ret void +; + %val = bitcast <4 x i32> to <4 x float> + call void @llvm.matrix.column.major.store(<4 x float> %val, ptr %out, i64 2, i1 false, i32 2, i32 2) + call void @llvm.matrix.column.major.store(<4 x float> %val, ptr %out, i64 4, i1 false, i32 4, i32 1) + call void @llvm.matrix.column.major.store(<4 x float> %val, ptr %out, i64 4, i1 false, i32 1, i32 4) + ret void +} From c9b39921d6830ec4e1e82d5da03069f4dbb38016 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 16 Jun 2025 14:04:01 -0700 Subject: [PATCH 19/26] split phi's in two phases --- .../Scalar/LowerMatrixIntrinsics.cpp | 60 +++++++++---------- .../Transforms/LowerMatrixIntrinsics/phi.ll | 40 +++++++------ 2 files changed, 52 insertions(+), 48 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index a4244e854f7a2..d6f570bdac803 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1138,7 +1138,27 @@ class LowerMatrixIntrinsics { Changed |= !FusedInsts.empty(); - // Fourth, lower remaining instructions with shape information. + // Fourth, pre-process all the PHINode's. The incoming values will be + // assigned later in VisitPHI. + for (Instruction *Inst : MatrixInsts) { + auto *PHI = dyn_cast(Inst); + if (!PHI) + continue; + + const ShapeInfo &SI = ShapeMap.at(Inst); + auto *EltTy = cast(PHI->getType())->getElementType(); + MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy); + + IRBuilder<> Builder(Inst); + for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) + PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), + PHI->getNumIncomingValues(), + PHI->getName())); + assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?"); + Inst2ColumnMatrix[PHI] = PhiM; + } + + // Fifth, lower remaining instructions with shape information. for (Instruction *Inst : MatrixInsts) { if (FusedInsts.count(Inst)) continue; @@ -2246,42 +2266,22 @@ class LowerMatrixIntrinsics { } MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { - // Shim this->getMatrix to insert split phi's as needed. - auto GetMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy { + // Shim this->getMatrix to adjust where it creates new instructions, which + // it may need to insert for re-shaping. + auto GetMatrix = [this, &Builder, SI, Inst](Value *MatrixVal) -> MatrixTy { IRBuilder<>::InsertPointGuard IPG(Builder); - - auto I = Inst2ColumnMatrix.find(MatrixVal); - if (I == Inst2ColumnMatrix.end()) { - if (auto *PHI = dyn_cast(MatrixVal)) { - auto *EltTy = cast(PHI->getType())->getElementType(); - MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy); - - Builder.SetInsertPoint(PHI); - for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) - PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), - PHI->getNumIncomingValues(), - PHI->getName())); - - Inst2ColumnMatrix[PHI] = PhiM; - } - } - - // getMatrix() may insert some instructions for reshaping. The safe place - // to insert them is at the end of the parent block, where the register - // allocator would have inserted the copies that materialize the PHI. - if (auto *MatrixInst = dyn_cast(MatrixVal)) - Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator()); + if (auto *MatrixInst = dyn_cast(MatrixVal)) { + if (auto MaybeIP = MatrixInst->getInsertionPointAfterDef()) + Builder.SetInsertPoint(*MaybeIP); + } else + Builder.SetInsertPoint(Inst->getIterator()); return this->getMatrix(MatrixVal, SI, Builder); }; MatrixTy PhiM = GetMatrix(Inst); - for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues(); - IncomingI != IncomingE; ++IncomingI) { - Value *IncomingV = Inst->getIncomingValue(IncomingI); - BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI); - + for (auto [IncomingV, IncomingB] : llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) { MatrixTy OpM = GetMatrix(IncomingV); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index 738325a0f438d..8470f33c2e298 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -26,11 +26,11 @@ define void @matrix_phi_loop(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> -; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: ; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 @@ -81,11 +81,11 @@ define void @matrix_phi_loop_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr ; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> -; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: ; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 @@ -171,11 +171,11 @@ define void @matrix_phi_loop_poison(ptr %in, i32 %count, ptr %out) { ; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> -; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: ; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 @@ -225,11 +225,11 @@ define void @matrix_phi_loop_cdv(ptr %in, i32 %count, ptr %out) { ; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <6 x double> [[TMP3]], <6 x double> [[TMP4]], <9 x i32> -; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: [[TMP0]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP1]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: [[TMP2]] = shufflevector <9 x double> [[TMP5]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: ; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 @@ -272,9 +272,9 @@ define void @matrix_phi_loop_delay(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[PHI14:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI15:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[PHI16:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[TMP0]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY]] ], [ [[TMP3:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[TMP1]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP4:%.*]], [[LOOP]] ] -; CHECK-NEXT: [[TMP2]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP5:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP0]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY]] ], [ [[SPLIT:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP1]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[SPLIT10:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP2]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[SPLIT11:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 8 ; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3 @@ -287,18 +287,21 @@ define void @matrix_phi_loop_delay(ptr %in1, ptr %in2, i32 %count, ptr %out) { ; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP7]], <6 x i32> ; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <3 x double> [[TMP8]], <3 x double> poison, <6 x i32> ; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <6 x double> [[TMP9]], <6 x double> [[TMP10]], <9 x i32> +; CHECK-NEXT: [[SPLIT]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT10]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[SPLIT11]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = fadd <3 x double> [[TMP6]], [[TMP0]] +; CHECK-NEXT: [[TMP13:%.*]] = fadd <3 x double> [[TMP7]], [[TMP1]] +; CHECK-NEXT: [[TMP14:%.*]] = fadd <3 x double> [[TMP8]], [[TMP2]] ; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 -; CHECK-NEXT: [[TMP3]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP4]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> -; CHECK-NEXT: [[TMP5]] = shufflevector <9 x double> [[TMP11]], <9 x double> poison, <3 x i32> ; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[OUT:%.*]], align 128 +; CHECK-NEXT: store <3 x double> [[TMP12]], ptr [[OUT:%.*]], align 128 ; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3 -; CHECK-NEXT: store <3 x double> [[TMP7]], ptr [[VEC_GEP12]], align 8 +; CHECK-NEXT: store <3 x double> [[TMP13]], ptr [[VEC_GEP12]], align 8 ; CHECK-NEXT: [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6 -; CHECK-NEXT: store <3 x double> [[TMP8]], ptr [[VEC_GEP13]], align 16 +; CHECK-NEXT: store <3 x double> [[TMP14]], ptr [[VEC_GEP13]], align 16 ; CHECK-NEXT: ret void ; entry: @@ -313,13 +316,14 @@ loop: %in2v = call <9 x double> @llvm.matrix.column.major.load(ptr %in2, i64 3, i1 false, i32 3, i32 3) %sum = fadd <9 x double> %phi2, %in2v + %sum2 = fadd <9 x double> %sum, %phi %dec = sub i32 %ctr, 1 %cmp = icmp eq i32 %dec, 0 br i1 %cmp, label %exit, label %loop exit: - store <9 x double> %sum, ptr %out + store <9 x double> %sum2, ptr %out ret void } From dd55682e33ed19ac5755b28c4c2ab1312fb4434c Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Mon, 16 Jun 2025 14:27:47 -0700 Subject: [PATCH 20/26] clang-format --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 5fb1403a7df82..f1288280d9167 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2280,7 +2280,8 @@ class LowerMatrixIntrinsics { MatrixTy PhiM = GetMatrix(Inst); - for (auto [IncomingV, IncomingB] : llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) { + for (auto [IncomingV, IncomingB] : + llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) { MatrixTy OpM = GetMatrix(IncomingV); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { From 5a991f7fd5dfb800dd81a9bd6e9487ff2ec1b2ea Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 17 Jun 2025 13:00:19 -0700 Subject: [PATCH 21/26] rm constant.ll --- .../LowerMatrixIntrinsics/constant.ll | 68 ------------------- 1 file changed, 68 deletions(-) delete mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll deleted file mode 100644 index 32a4d191897b1..0000000000000 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/constant.ll +++ /dev/null @@ -1,68 +0,0 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s - -define void @ramp_2x2(ptr %out) { -; CHECK-LABEL: @ramp_2x2( -; CHECK-NEXT: store <2 x i32> , ptr [[OUT:%.*]], align 4 -; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 -; CHECK-NEXT: store <2 x i32> , ptr [[VEC_GEP]], align 4 -; CHECK-NEXT: ret void -; - call void @llvm.matrix.column.major.store(<4 x i32> , ptr %out, i64 2, i1 false, i32 2, i32 2) - ret void -} - -define void @poison_2x2(ptr %out) { -; CHECK-LABEL: @poison_2x2( -; CHECK-NEXT: store <2 x i32> poison, ptr [[OUT:%.*]], align 4 -; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 -; CHECK-NEXT: store <2 x i32> poison, ptr [[VEC_GEP]], align 4 -; CHECK-NEXT: ret void -; - call void @llvm.matrix.column.major.store(<4 x i32> poison, ptr %out, i64 2, i1 false, i32 2, i32 2) - ret void -} - -define void @undef_2x2(ptr %out) { -; CHECK-LABEL: @undef_2x2( -; CHECK-NEXT: store <2 x i32> undef, ptr [[OUT:%.*]], align 4 -; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 -; CHECK-NEXT: store <2 x i32> undef, ptr [[VEC_GEP]], align 4 -; CHECK-NEXT: ret void -; - call void @llvm.matrix.column.major.store(<4 x i32> undef, ptr %out, i64 2, i1 false, i32 2, i32 2) - ret void -} - -define void @zeroinitializer_2x2(ptr %out) { -; CHECK-LABEL: @zeroinitializer_2x2( -; CHECK-NEXT: store <2 x i32> zeroinitializer, ptr [[OUT:%.*]], align 4 -; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 2 -; CHECK-NEXT: store <2 x i32> zeroinitializer, ptr [[VEC_GEP]], align 4 -; CHECK-NEXT: ret void -; - call void @llvm.matrix.column.major.store(<4 x i32> zeroinitializer, ptr %out, i64 2, i1 false, i32 2, i32 2) - ret void -} - -define void @ramp_bitcast(ptr %out) { -; CHECK-LABEL: @ramp_bitcast( -; CHECK-NEXT: store <2 x float> , ptr [[OUT:%.*]], align 4 -; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[OUT]], i64 2 -; CHECK-NEXT: store <2 x float> , ptr [[VEC_GEP]], align 4 -; CHECK-NEXT: store <4 x float> , ptr [[OUT]], align 4 -; CHECK-NEXT: store <1 x float> splat (float 0x36A0000000000000), ptr [[OUT]], align 4 -; CHECK-NEXT: [[VEC_GEP1:%.*]] = getelementptr float, ptr [[OUT]], i64 4 -; CHECK-NEXT: store <1 x float> splat (float 0x36B0000000000000), ptr [[VEC_GEP1]], align 4 -; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 8 -; CHECK-NEXT: store <1 x float> splat (float 0x36B8000000000000), ptr [[VEC_GEP2]], align 4 -; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[OUT]], i64 12 -; CHECK-NEXT: store <1 x float> splat (float 0x36C0000000000000), ptr [[VEC_GEP3]], align 4 -; CHECK-NEXT: ret void -; - %val = bitcast <4 x i32> to <4 x float> - call void @llvm.matrix.column.major.store(<4 x float> %val, ptr %out, i64 2, i1 false, i32 2, i32 2) - call void @llvm.matrix.column.major.store(<4 x float> %val, ptr %out, i64 4, i1 false, i32 4, i32 1) - call void @llvm.matrix.column.major.store(<4 x float> %val, ptr %out, i64 4, i1 false, i32 1, i32 4) - ret void -} From e9d0e62c106589e339fc94cf472914ea7e9b73fc Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 17 Jun 2025 18:27:03 -0700 Subject: [PATCH 22/26] test that shows reshape shuffles are inserted in the correct spot --- .../Transforms/LowerMatrixIntrinsics/phi.ll | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll index 8470f33c2e298..9fdb2ce4dfa74 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll @@ -327,6 +327,70 @@ exit: ret void } +define void @matrix_phi_loop_delay_reshape(ptr %in1, ptr %in2, ptr %in3, i32 %count, ptr %out) { +; CHECK-LABEL: @matrix_phi_loop_delay_reshape( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[IN3:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN3]], i64 2 +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: [[VEC_GEP1:%.*]] = getelementptr double, ptr [[IN3]], i64 4 +; CHECK-NEXT: [[COL_LOAD12:%.*]] = load <2 x double>, ptr [[VEC_GEP1]], align 8 +; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <2 x double> [[COL_LOAD1]], <2 x double> [[COL_LOAD8]], <4 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD12]], <2 x double> poison, <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x double> [[TMP0]], <4 x double> [[TMP1]], <6 x i32> +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <6 x double> [[TMP2]], <6 x double> poison, <3 x i32> +; CHECK-NEXT: [[COL_LOAD10:%.*]] = shufflevector <6 x double> [[TMP2]], <6 x double> poison, <3 x i32> +; CHECK-NEXT: [[COL_LOAD11:%.*]] = load <6 x double>, ptr [[IN2:%.*]], align 8 +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3 +; CHECK-NEXT: [[COL_LOAD14:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: [[PHI2:%.*]] = phi <3 x double> [ [[SPLIT]], [[ENTRY:%.*]] ], [ [[PHI1:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI39:%.*]] = phi <3 x double> [ [[COL_LOAD10]], [[ENTRY]] ], [ [[PHI4:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[PHI25:%.*]] = phi <6 x double> [ [[COL_LOAD11]], [[ENTRY]] ], [ [[PHI25]], [[LOOP]] ] +; CHECK-NEXT: [[PHI1]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY]] ], [ [[PHI2]], [[LOOP]] ] +; CHECK-NEXT: [[PHI4]] = phi <3 x double> [ [[COL_LOAD14]], [[ENTRY]] ], [ [[PHI39]], [[LOOP]] ] +; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <6 x double> [[PHI25]], <6 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <6 x double> [[PHI25]], <6 x double> poison, <3 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = fadd <3 x double> [[TMP3]], [[PHI1]] +; CHECK-NEXT: [[TMP8:%.*]] = fadd <3 x double> [[TMP4]], [[PHI4]] +; CHECK-NEXT: [[TMP5:%.*]] = fadd <3 x double> [[TMP7]], [[PHI2]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd <3 x double> [[TMP8]], [[PHI39]] +; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0 +; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK: exit: +; CHECK-NEXT: store <3 x double> [[TMP5]], ptr [[OUT:%.*]], align 64 +; CHECK-NEXT: [[VEC_GEP30:%.*]] = getelementptr double, ptr [[OUT]], i64 3 +; CHECK-NEXT: store <3 x double> [[TMP6]], ptr [[VEC_GEP30]], align 8 +; CHECK-NEXT: ret void +; +entry: + %in1v = call <6 x double> @llvm.matrix.column.major.load(ptr %in3, i64 2, i1 false, i32 2, i32 3) + %in2v = call <6 x double> @llvm.matrix.column.major.load(ptr %in2, i64 6, i1 false, i32 6, i32 1) + %in3v = call <6 x double> @llvm.matrix.column.major.load(ptr %in1, i64 3, i1 false, i32 3, i32 2) + br label %loop + +loop: + %phi = phi <6 x double> [%in1v, %entry], [%phi3, %loop] + %phi2 = phi <6 x double> [%in2v, %entry], [%phi2, %loop] + %phi3 = phi <6 x double> [%in3v, %entry], [%phi, %loop] + %ctr = phi i32 [%count, %entry], [%dec, %loop] + + %sum = fadd <6 x double> %phi2, %phi3 + %sum2 = fadd <6 x double> %sum, %phi + + %dec = sub i32 %ctr, 1 + %cmp = icmp eq i32 %dec, 0 + br i1 %cmp, label %exit, label %loop + +exit: + store <6 x double> %sum2, ptr %out + ret void +} + define void @matrix_phi_three_preds(i1 %cond1, i1 %cond2, ptr %a, ptr %b, ptr %c, ptr %out) { ; CHECK-LABEL: @matrix_phi_three_preds( ; CHECK-NEXT: entry: From a760840b2cb7704abcb141c63ff8374ba13034f2 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 17 Jun 2025 19:16:05 -0700 Subject: [PATCH 23/26] florian's suggestion is a little simpler: we already know it's a phi --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index f1288280d9167..db2a04354b498 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2292,10 +2292,7 @@ class LowerMatrixIntrinsics { // finalizeLowering() may also insert instructions in some cases. The safe // place for those is at the end of the initial block of PHIs. - auto IP = Inst->getInsertionPointAfterDef(); - assert(IP.has_value() && - "expected to find a valid insertion point after the phi"); - Builder.SetInsertPoint(*IP); + Builder.SetInsertPoint(Inst->getParent()->getFirstInsertionPt()); return PhiM; } From a62ef50d2ef206faf5b086c7f2e8c844aee8d54b Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 17 Jun 2025 19:21:19 -0700 Subject: [PATCH 24/26] also use getInsertionPointAtDef() for non-inst phi operand reshape placement --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index db2a04354b498..e5ce1a59d5888 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2272,8 +2272,8 @@ class LowerMatrixIntrinsics { if (auto *MatrixInst = dyn_cast(MatrixVal)) { if (auto MaybeIP = MatrixInst->getInsertionPointAfterDef()) Builder.SetInsertPoint(*MaybeIP); - } else - Builder.SetInsertPoint(Inst->getIterator()); + } else if (auto MaybeIP = Inst->getInsertionPointAfterDef()) + Builder.SetInsertPoint(*MaybeIP); return this->getMatrix(MatrixVal, SI, Builder); }; From 47dd2e28c86c76cccb21da7cec2efaef616c6810 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 17 Jun 2025 19:44:47 -0700 Subject: [PATCH 25/26] inline the lambda shim, to simplify --- .../Scalar/LowerMatrixIntrinsics.cpp | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index e5ce1a59d5888..678be0cf7301f 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2265,24 +2265,21 @@ class LowerMatrixIntrinsics { } MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { - // Shim this->getMatrix to adjust where it creates new instructions, which - // it may need to insert for re-shaping. - auto GetMatrix = [this, &Builder, SI, Inst](Value *MatrixVal) -> MatrixTy { - IRBuilder<>::InsertPointGuard IPG(Builder); - if (auto *MatrixInst = dyn_cast(MatrixVal)) { - if (auto MaybeIP = MatrixInst->getInsertionPointAfterDef()) - Builder.SetInsertPoint(*MaybeIP); - } else if (auto MaybeIP = Inst->getInsertionPointAfterDef()) - Builder.SetInsertPoint(*MaybeIP); - - return this->getMatrix(MatrixVal, SI, Builder); - }; - - MatrixTy PhiM = GetMatrix(Inst); + auto BlockIP = Inst->getParent()->getFirstInsertionPt(); + Builder.SetInsertPoint(BlockIP); + MatrixTy PhiM = getMatrix(Inst, SI, Builder); for (auto [IncomingV, IncomingB] : llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) { - MatrixTy OpM = GetMatrix(IncomingV); + // getMatrix() may insert some instructions to help with reshaping. The + // safest place for those is at the top of the block after the rest of the + // PHI's. Even better, if we can put it in the incoming block. + Builder.SetInsertPoint(BlockIP); + if (auto *IncomingInst = dyn_cast(IncomingV)) + if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef()) + Builder.SetInsertPoint(*MaybeIP); + + MatrixTy OpM = getMatrix(IncomingV, SI, Builder); for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { PHINode *NewPHI = cast(PhiM.getVector(VI)); @@ -2292,7 +2289,7 @@ class LowerMatrixIntrinsics { // finalizeLowering() may also insert instructions in some cases. The safe // place for those is at the end of the initial block of PHIs. - Builder.SetInsertPoint(Inst->getParent()->getFirstInsertionPt()); + Builder.SetInsertPoint(BlockIP); return PhiM; } From 0debc0841a0619d1c63c393faedfa548931dc795 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Wed, 18 Jun 2025 08:59:07 -0700 Subject: [PATCH 26/26] rm unnecessary include --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 678be0cf7301f..fa9e44617b7c8 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -32,7 +32,6 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h"