From 9707684e7714b1d485a31906ff3e17d6f6f94ace Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 1 Apr 2024 19:20:07 +0800 Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests for PR87042. NFC. --- llvm/test/Transforms/InstCombine/pr87042.ll | 49 +++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 llvm/test/Transforms/InstCombine/pr87042.ll diff --git a/llvm/test/Transforms/InstCombine/pr87042.ll b/llvm/test/Transforms/InstCombine/pr87042.ll new file mode 100644 index 0000000000000..d9624faaedfa5 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/pr87042.ll @@ -0,0 +1,49 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt -S -passes=instcombine < %s | FileCheck %s + +define i64 @test_disjoint_or(i1 %cond, i64 %x) { +; CHECK-LABEL: define i64 @test_disjoint_or( +; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) { +; CHECK-NEXT: [[OR2:%.*]] = or disjoint i64 [[X]], 7 +; CHECK-NEXT: ret i64 [[OR2]] +; + %or1 = or disjoint i64 %x, 7 + %sel1 = select i1 %cond, i64 %or1, i64 %x + %or2 = or i64 %sel1, 7 + ret i64 %or2 +} + +define i64 @test_or(i1 %cond, i64 %x) { +; CHECK-LABEL: define i64 @test_or( +; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) { +; CHECK-NEXT: [[OR2:%.*]] = or i64 [[X]], 7 +; CHECK-NEXT: ret i64 [[OR2]] +; + %or1 = or i64 %x, 7 + %sel1 = select i1 %cond, i64 %or1, i64 %x + %or2 = or i64 %sel1, 7 + ret i64 %or2 +} + +define i64 @pr87042(i64 %x) { +; CHECK-LABEL: define i64 @pr87042( +; CHECK-SAME: i64 [[X:%.*]]) { +; CHECK-NEXT: [[AND1:%.*]] = and i64 [[X]], 65535 +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[AND1]], 0 +; CHECK-NEXT: [[OR1:%.*]] = or disjoint i64 [[X]], 7 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i64 [[OR1]], i64 [[X]] +; CHECK-NEXT: [[AND2:%.*]] = and i64 [[SEL1]], 16776960 +; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i64 [[AND2]], 0 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i64 [[OR1]], i64 [[SEL1]] +; CHECK-NEXT: ret i64 [[SEL2]] +; + %and1 
= and i64 %x, 65535 + %cmp1 = icmp eq i64 %and1, 0 + %or1 = or disjoint i64 %x, 7 + %sel1 = select i1 %cmp1, i64 %or1, i64 %x + %and2 = and i64 %sel1, 16776960 + %cmp2 = icmp eq i64 %and2, 0 + %or2 = or i64 %sel1, 7 + %sel2 = select i1 %cmp2, i64 %or2, i64 %sel1 + ret i64 %sel2 +} From 900106267f135c4e3c5a6f744096e61dd94f3e44 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 1 Apr 2024 19:21:02 +0800 Subject: [PATCH 2/2] [InstCombine] Drop poison-generating flags in `threadBinOpOverSelect` --- .../include/llvm/Analysis/InstructionSimplify.h | 11 +++++++++++ llvm/lib/Analysis/InstructionSimplify.cpp | 17 ++++++++++++++++- .../InstCombine/InstructionCombining.cpp | 9 +++++++++ llvm/test/Transforms/InstCombine/pr87042.ll | 4 ++-- 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h index 03d7ad12c12d8..eacc9a25086cb 100644 --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -264,6 +264,17 @@ simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement, SmallVectorImpl<Instruction *> *DropFlags = nullptr); +/// In the case of a binary operation with a select instruction as an operand, +/// try to simplify the binop by seeing whether evaluating it on both branches +/// of the select results in the same value. Returns the common value if so, +/// otherwise returns null. +/// +/// If DropFlags is passed, then the result is only valid if +/// poison-generating flags/metadata on the instruction are dropped. +Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + Instruction **DropFlags); + /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively. /// /// This first performs a normal RAUW of I with SimpleV. 
It then recursively diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 9ff3faff79902..cee84793d4326 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -393,7 +393,8 @@ static Value *simplifyAssociativeBinOp(Instruction::BinaryOps Opcode, /// otherwise returns null. static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, - unsigned MaxRecurse) { + unsigned MaxRecurse, + Instruction **DropFlags = nullptr) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -447,6 +448,13 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, Value *UnsimplifiedBranch = FV ? SI->getTrueValue() : SI->getFalseValue(); Value *UnsimplifiedLHS = SI == LHS ? UnsimplifiedBranch : LHS; Value *UnsimplifiedRHS = SI == LHS ? RHS : UnsimplifiedBranch; + + if (Simplified->hasPoisonGeneratingFlags()) { + if (!DropFlags) + return nullptr; + *DropFlags = Simplified; + } + if (Simplified->getOperand(0) == UnsimplifiedLHS && Simplified->getOperand(1) == UnsimplifiedRHS) return Simplified; @@ -460,6 +468,13 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, return nullptr; } +Value *llvm::threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + Instruction **DropFlags) { + return ::threadBinOpOverSelect(Opcode, LHS, RHS, Q, RecursionLimit, + DropFlags); +} + /// In the case of a comparison with a select instruction, try to simplify the /// comparison by seeing whether both branches of the select result in the same /// value. Returns the common value if so, otherwise returns null. 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 7c40fb4fc8608..d3ac2762d64f4 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1964,6 +1964,15 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) return NewSel; + + const SimplifyQuery SQ = getSimplifyQuery().getWithInstruction(&I); + Instruction *DropFlags = nullptr; + if (Value *V = threadBinOpOverSelect(I.getOpcode(), I.getOperand(0), + I.getOperand(1), SQ, &DropFlags)) { + if (DropFlags) + DropFlags->dropPoisonGeneratingFlags(); + return replaceInstUsesWith(I, V); + } } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) { if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) return NewPhi; diff --git a/llvm/test/Transforms/InstCombine/pr87042.ll b/llvm/test/Transforms/InstCombine/pr87042.ll index d9624faaedfa5..b3d6821e43c3d 100644 --- a/llvm/test/Transforms/InstCombine/pr87042.ll +++ b/llvm/test/Transforms/InstCombine/pr87042.ll @@ -4,7 +4,7 @@ define i64 @test_disjoint_or(i1 %cond, i64 %x) { ; CHECK-LABEL: define i64 @test_disjoint_or( ; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) { -; CHECK-NEXT: [[OR2:%.*]] = or disjoint i64 [[X]], 7 +; CHECK-NEXT: [[OR2:%.*]] = or i64 [[X]], 7 ; CHECK-NEXT: ret i64 [[OR2]] ; %or1 = or disjoint i64 %x, 7 @@ -30,7 +30,7 @@ define i64 @pr87042(i64 %x) { ; CHECK-SAME: i64 [[X:%.*]]) { ; CHECK-NEXT: [[AND1:%.*]] = and i64 [[X]], 65535 ; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[AND1]], 0 -; CHECK-NEXT: [[OR1:%.*]] = or disjoint i64 [[X]], 7 +; CHECK-NEXT: [[OR1:%.*]] = or i64 [[X]], 7 ; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i64 [[OR1]], i64 [[X]] ; CHECK-NEXT: [[AND2:%.*]] = and i64 [[SEL1]], 16776960 ; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i64 [[AND2]], 0