diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 03d7ad12c12d8..eacc9a25086cb 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -264,6 +264,19 @@ simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                        const SimplifyQuery &Q, bool AllowRefinement,
                        SmallVectorImpl<Instruction *> *DropFlags = nullptr);
 
+/// In the case of a binary operation with a select instruction as an operand,
+/// try to simplify the binop by seeing whether evaluating it on both branches
+/// of the select results in the same value. Returns the common value if so,
+/// otherwise returns null.
+///
+/// If the result is only valid after dropping the poison-generating flags of
+/// an instruction, that instruction is stored to *DropFlags; if DropFlags is
+/// null, null is returned instead. *DropFlags may be written even when null
+/// is returned, so only inspect it after a non-null result.
+Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
+                             Value *RHS, const SimplifyQuery &Q,
+                             Instruction **DropFlags);
+
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///
 /// This first performs a normal RAUW of I with SimpleV. It then recursively
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 9ff3faff79902..cee84793d4326 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -393,7 +393,8 @@ static Value *simplifyAssociativeBinOp(Instruction::BinaryOps Opcode,
 /// otherwise returns null.
 static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q,
-                                    unsigned MaxRecurse) {
+                                    unsigned MaxRecurse,
+                                    Instruction **DropFlags = nullptr) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
     return nullptr;
@@ -447,6 +448,16 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
       Value *UnsimplifiedBranch = FV ? SI->getTrueValue() : SI->getFalseValue();
       Value *UnsimplifiedLHS = SI == LHS ? UnsimplifiedBranch : LHS;
       Value *UnsimplifiedRHS = SI == LHS ? RHS : UnsimplifiedBranch;
+
+      // Simplified is about to be returned in place of the binop being
+      // threaded, which is not known to carry Simplified's poison-generating
+      // flags. The replacement is only valid if the caller drops them.
+      if (Simplified->hasPoisonGeneratingFlags()) {
+        if (!DropFlags)
+          return nullptr;
+        *DropFlags = Simplified;
+      }
+
       if (Simplified->getOperand(0) == UnsimplifiedLHS &&
           Simplified->getOperand(1) == UnsimplifiedRHS)
         return Simplified;
@@ -460,6 +471,13 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
   return nullptr;
 }
 
+Value *llvm::threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
+                                   Value *RHS, const SimplifyQuery &Q,
+                                   Instruction **DropFlags) {
+  return ::threadBinOpOverSelect(Opcode, LHS, RHS, Q, RecursionLimit,
+                                 DropFlags);
+}
+
 /// In the case of a comparison with a select instruction, try to simplify the
 /// comparison by seeing whether both branches of the select result in the same
 /// value. Returns the common value if so, otherwise returns null.
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 7c40fb4fc8608..d3ac2762d64f4 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1964,6 +1964,18 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
   if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
     if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
       return NewSel;
+
+    const SimplifyQuery SQ = getSimplifyQuery().getWithInstruction(&I);
+    Instruction *DropFlags = nullptr;
+    if (Value *V = threadBinOpOverSelect(I.getOpcode(), I.getOperand(0),
+                                         I.getOperand(1), SQ, &DropFlags)) {
+      if (DropFlags) {
+        DropFlags->dropPoisonGeneratingFlags();
+        // The instruction was modified, so have InstCombine revisit it.
+        Worklist.add(DropFlags);
+      }
+      return replaceInstUsesWith(I, V);
+    }
   } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
     if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
       return NewPhi;
diff --git a/llvm/test/Transforms/InstCombine/pr87042.ll b/llvm/test/Transforms/InstCombine/pr87042.ll
new file mode 100644
index 0000000000000..b3d6821e43c3d
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr87042.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define i64 @test_disjoint_or(i1 %cond, i64 %x) {
+; CHECK-LABEL: define i64 @test_disjoint_or(
+; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[OR2:%.*]] = or i64 [[X]], 7
+; CHECK-NEXT:    ret i64 [[OR2]]
+;
+  %or1 = or disjoint i64 %x, 7
+  %sel1 = select i1 %cond, i64 %or1, i64 %x
+  %or2 = or i64 %sel1, 7
+  ret i64 %or2
+}
+
+define i64 @test_or(i1 %cond, i64 %x) {
+; CHECK-LABEL: define i64 @test_or(
+; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[OR2:%.*]] = or i64 [[X]], 7
+; CHECK-NEXT:    ret i64 [[OR2]]
+;
+  %or1 = or i64 %x, 7
+  %sel1 = select i1 %cond, i64 %or1, i64 %x
+  %or2 = or i64 %sel1, 7
+  ret i64 %or2
+}
+
+define i64 @pr87042(i64 %x) {
+; CHECK-LABEL: define i64 @pr87042(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[AND1:%.*]] = and i64 [[X]], 65535
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[AND1]], 0
+; CHECK-NEXT:    [[OR1:%.*]] = or i64 [[X]], 7
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i64 [[OR1]], i64 [[X]]
+; CHECK-NEXT:    [[AND2:%.*]] = and i64 [[SEL1]], 16776960
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i64 [[AND2]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i64 [[OR1]], i64 [[SEL1]]
+; CHECK-NEXT:    ret i64 [[SEL2]]
+;
+  %and1 = and i64 %x, 65535
+  %cmp1 = icmp eq i64 %and1, 0
+  %or1 = or disjoint i64 %x, 7
+  %sel1 = select i1 %cmp1, i64 %or1, i64 %x
+  %and2 = and i64 %sel1, 16776960
+  %cmp2 = icmp eq i64 %and2, 0
+  %or2 = or i64 %sel1, 7
+  %sel2 = select i1 %cmp2, i64 %or2, i64 %sel1
+  ret i64 %sel2
+}