diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index f6992119280c1..f8be8edf21829 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -785,6 +785,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final void handlePotentiallyDeadBlocks(SmallVectorImpl &Worklist); void handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc); void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); + + // Take the exact integer log2 of the value. If DoFold is true, create the + // actual instructions, otherwise return a non-null dummy value. Return + // nullptr on failure. Note, if DoFold is true the caller must ensure that + // takeLog2 will succeed, otherwise it may create stray instructions. + Value *takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero, bool DoFold); + + Value *tryGetLog2(Value *Op, bool AssumeNonZero) { + if (takeLog2(Op, /*Depth=*/0, AssumeNonZero, /*DoFold=*/false)) + return takeLog2(Op, /*Depth=*/0, AssumeNonZero, /*DoFold=*/true); + return nullptr; + } }; class Negator final { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 0c34cf01bdf1a..1c5070a1b867c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -185,9 +185,6 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, return nullptr; } -static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool AssumeNonZero, bool DoFold); - Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = @@ -531,19 +528,13 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // (shl Op1, Log2(Op0)) // if Log2(Op1) folds away -> // (shl Op0, Log2(Op1)) - if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, - /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, - /*DoFold*/ true); + if (Value *Res = tryGetLog2(Op0, /*AssumeNonZero=*/false)) { BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res); // We can only propegate nuw flag. Shl->setHasNoUnsignedWrap(HasNUW); return Shl; } - if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, - /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, - /*DoFold*/ true); + if (Value *Res = tryGetLog2(Op1, /*AssumeNonZero=*/false)) { BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res); // We can only propegate nuw flag. Shl->setHasNoUnsignedWrap(HasNUW); @@ -1407,13 +1398,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { return nullptr; } -static const unsigned MaxDepth = 6; - -// Take the exact integer log2 of the value. If DoFold is true, create the -// actual instructions, otherwise return a non-null dummy value. Return nullptr -// on failure. -static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool AssumeNonZero, bool DoFold) { +Value *InstCombinerImpl::takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero, + bool DoFold) { auto IfFold = [DoFold](function_ref Fn) { if (!DoFold) return reinterpret_cast(-1); @@ -1432,14 +1418,14 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, }); // The remaining tests are all recursive, so bail out if we hit the limit. - if (Depth++ == MaxDepth) + if (Depth++ == MaxAnalysisRecursionDepth) return nullptr; // log2(zext X) -> zext log2(X) // FIXME: Require one use? Value *X, *Y; if (match(Op, m_ZExt(m_Value(X)))) - if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); // log2(trunc x) -> trunc log2(X) @@ -1447,7 +1433,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, if (match(Op, m_Trunc(m_Value(X)))) { auto *TI = cast(Op); if (AssumeNonZero || TI->hasNoUnsignedWrap()) - if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateTrunc(LogX, Op->getType(), "", /*IsNUW=*/TI->hasNoUnsignedWrap()); @@ -1460,7 +1446,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, auto *BO = cast(Op); // nuw will be set if the `shl` is trivially non-zero. if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) - if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); } @@ -1469,26 +1455,25 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, if (match(Op, m_LShr(m_Value(X), m_Value(Y)))) { auto *PEO = cast(Op); if (AssumeNonZero || PEO->isExact()) - if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSub(LogX, Y); }); } // log2(X & Y) -> either log2(X) or log2(Y) // This requires `AssumeNonZero` as `X & Y` may be zero when X != Y. if (AssumeNonZero && match(Op, m_And(m_Value(X), m_Value(Y)))) { - if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return LogX; }); - if (Value *LogY = takeLog2(Builder, Y, Depth, AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Y, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return LogY; }); } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) // FIXME: Require one use? if (SelectInst *SI = dyn_cast(Op)) - if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, - AssumeNonZero, DoFold)) - if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, - AssumeNonZero, DoFold)) + if (Value *LogX = takeLog2(SI->getOperand(1), Depth, AssumeNonZero, DoFold)) + if (Value *LogY = + takeLog2(SI->getOperand(2), Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); }); @@ -1499,9 +1484,9 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) { // Use AssumeNonZero as false here. Otherwise we can hit case where // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow). - if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, + if (Value *LogX = takeLog2(MinMax->getLHS(), Depth, /*AssumeNonZero*/ false, DoFold)) - if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, + if (Value *LogY = takeLog2(MinMax->getRHS(), Depth, /*AssumeNonZero*/ false, DoFold)) return IfFold([&]() { return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX, @@ -1614,13 +1599,9 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. - if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true, - /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, - /*AssumeNonZero*/ true, /*DoFold*/ true); + if (Value *Res = tryGetLog2(Op1, /*AssumeNonZero=*/true)) return replaceInstUsesWith( I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); - } return nullptr; }