diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 7ef95800975db..90cd279e8a457 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1613,6 +1613,22 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { if (Instruction *Overflow = foldLShrOverflowBit(I)) return Overflow; + // Transform ((pow2 << x) >> cttz(pow2 << y)) -> ((1 << x) >> y) + Value *Shl0_Op0, *Shl0_Op1, *Shl1_Op1; + BinaryOperator *Shl1; + if (match(Op0, m_Shl(m_Value(Shl0_Op0), m_Value(Shl0_Op1))) && + match(Op1, m_Intrinsic(m_BinOp(Shl1))) && + match(Shl1, m_Shl(m_Specific(Shl0_Op0), m_Value(Shl1_Op1))) && + isKnownToBeAPowerOfTwo(Shl0_Op0, /*OrZero=*/true, 0, &I)) { + auto *Shl0 = cast(Op0); + bool HasNUW = Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap(); + bool HasNSW = Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap(); + if (HasNUW || HasNSW) { + Value *NewShl = Builder.CreateShl(ConstantInt::get(Shl1->getType(), 1), + Shl0_Op1, "", HasNUW, HasNSW); + return BinaryOperator::CreateLShr(NewShl, Shl1_Op1); + } + } return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/shift-cttz-ctlz.ll b/llvm/test/Transforms/InstCombine/shift-cttz-ctlz.ll index 63caec9501325..e82e33e9d7f04 100644 --- a/llvm/test/Transforms/InstCombine/shift-cttz-ctlz.ll +++ b/llvm/test/Transforms/InstCombine/shift-cttz-ctlz.ll @@ -103,4 +103,34 @@ entry: ret i32 %res } +define i64 @fold_cttz_64() vscale_range(1,16) { +; CHECK-LABEL: define i64 @fold_cttz_64( +; CHECK-SAME: ) #[[ATTR0:[0-9]+]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i64 4 +; +entry: + %vscale = tail call i64 @llvm.vscale.i64() + %shl0 = shl nuw nsw i64 %vscale, 4 + %shl1 = shl nuw nsw i64 %vscale, 2 + %cttz = tail call range(i64 2, 65) i64 @llvm.cttz.i64(i64 %shl1, i1 true) + %div1 = lshr i64 %shl0, %cttz + ret i64 %div1 +} + +define i32 @fold_cttz_32() vscale_range(1,16) { +; CHECK-LABEL: define i32 @fold_cttz_32( +; CHECK-SAME: ) #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i32 4 +; +entry: + %vscale = tail call i32 @llvm.vscale.i32() + %shl0 = shl nuw nsw i32 %vscale, 4 + %shl1 = shl nuw nsw i32 %vscale, 2 + %cttz = tail call range(i32 2, 65) i32 @llvm.cttz.i32(i32 %shl1, i1 true) + %div1 = lshr i32 %shl0, %cttz + ret i32 %div1 +} + declare void @use(i32)