Skip to content

[DAG] fold avgs(sext(x), sext(y)) -> sext(avgs(x, y)) #95365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
DAG.getShiftAmountConstant(1, VT, DL));

// fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
// fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
if (sd_match(
N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
X.getValueType() == Y.getValueType() &&
Expand All @@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
}
if (sd_match(
N, m_BinOp(ISD::AVGFLOORS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
X.getValueType() == Y.getValueType() &&
hasOperation(ISD::AVGFLOORS, X.getValueType())) {
SDValue AvgFloorS = DAG.getNode(ISD::AVGFLOORS, DL, X.getValueType(), X, Y);
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorS);
}
if (sd_match(
N, m_BinOp(ISD::AVGCEILS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
X.getValueType() == Y.getValueType() &&
hasOperation(ISD::AVGCEILS, X.getValueType())) {
SDValue AvgCeilS = DAG.getNode(ISD::AVGCEILS, DL, X.getValueType(), X, Y);
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilS);
}

// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
Expand Down
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ define <8 x i16> @urhadd_sext(<8 x i8> %a0, <8 x i8> %a1) {
define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
; CHECK-LABEL: hadds_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: shadd v0.8b, v0.8b, v1.8b
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NEXT: shadd v0.8h, v0.8h, v1.8h
; CHECK-NEXT: bic v0.8h, #254, lsl #8
; CHECK-NEXT: ret
%x0 = sext <8 x i8> %a0 to <8 x i16>
Expand All @@ -110,9 +109,8 @@ define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
define <8 x i16> @shaddu_sext(<8 x i8> %a0, <8 x i8> %a1) {
; CHECK-LABEL: shaddu_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: srhadd v0.8b, v0.8b, v1.8b
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NEXT: srhadd v0.8h, v0.8h, v1.8h
; CHECK-NEXT: bic v0.8h, #254, lsl #8
; CHECK-NEXT: ret
%x0 = sext <8 x i8> %a0 to <8 x i16>
Expand Down
78 changes: 78 additions & 0 deletions llvm/test/CodeGen/AArch64/avg.ll
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,81 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
%avg = sub <16 x i16> %or, %shift
ret <16 x i16> %avg
}

; Positive test: both operands are sign-extended from the same narrow type
; (<16 x i8>) before a signed floor-average is computed at <16 x i16>.
; The CHECK lines expect a single shadd on the narrow .16b vectors followed by
; sshll/sshll2 to widen the result, i.e. the average is performed before the
; sign extension rather than after it.
define <16 x i16> @sext_avgfloors(<16 x i8> %a0, <16 x i8> %a1) {
; CHECK-LABEL: sext_avgfloors:
; CHECK: // %bb.0:
; CHECK-NEXT: shadd v0.16b, v0.16b, v1.16b
; CHECK-NEXT: sshll2 v1.8h, v0.16b, #0
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
; CHECK-NEXT: ret
%x0 = sext <16 x i8> %a0 to <16 x i16>
%x1 = sext <16 x i8> %a1 to <16 x i16>
; Floor-average idiom: (x0 & x1) + ((x0 ^ x1) ashr 1).
%and = and <16 x i16> %x0, %x1
%xor = xor <16 x i16> %x0, %x1
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%avg = add <16 x i16> %and, %shift
ret <16 x i16> %avg
}

; Negative test: the two operands are sign-extended from different narrow
; types (<16 x i8> and <16 x i4>) before the signed floor-average.
; The CHECK lines expect both shadd instructions to stay at the wide .8h
; width, with each operand extended first (the i4 operand via shl/sshr #12).
; NOTE(review): presumably the narrowing fold is skipped because the
; pre-extension types differ - confirm against the combine's type check.
define <16 x i16> @sext_avgfloors_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
; CHECK-LABEL: sext_avgfloors_mismatch:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll2 v2.8h, v1.16b, #0
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: sshll v3.8h, v0.8b, #0
; CHECK-NEXT: sshll2 v0.8h, v0.16b, #0
; CHECK-NEXT: shl v1.8h, v1.8h, #12
; CHECK-NEXT: shl v2.8h, v2.8h, #12
; CHECK-NEXT: sshr v4.8h, v1.8h, #12
; CHECK-NEXT: sshr v1.8h, v2.8h, #12
; CHECK-NEXT: shadd v1.8h, v0.8h, v1.8h
; CHECK-NEXT: shadd v0.8h, v3.8h, v4.8h
; CHECK-NEXT: ret
%x0 = sext <16 x i8> %a0 to <16 x i16>
%x1 = sext <16 x i4> %a1 to <16 x i16>
; Floor-average idiom: (x0 & x1) + ((x0 ^ x1) ashr 1).
%and = and <16 x i16> %x0, %x1
%xor = xor <16 x i16> %x0, %x1
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%avg = add <16 x i16> %and, %shift
ret <16 x i16> %avg
}

; Positive test: both operands are sign-extended from the same narrow type
; (<16 x i8>) before a signed ceiling-average is computed at <16 x i16>.
; The CHECK lines expect a single srhadd on the narrow .16b vectors followed
; by sshll/sshll2 to widen the result, i.e. the rounding average is performed
; before the sign extension rather than after it.
define <16 x i16> @sext_avgceils(<16 x i8> %a0, <16 x i8> %a1) {
; CHECK-LABEL: sext_avgceils:
; CHECK: // %bb.0:
; CHECK-NEXT: srhadd v0.16b, v0.16b, v1.16b
; CHECK-NEXT: sshll2 v1.8h, v0.16b, #0
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
; CHECK-NEXT: ret
%x0 = sext <16 x i8> %a0 to <16 x i16>
%x1 = sext <16 x i8> %a1 to <16 x i16>
; Ceiling-average idiom: (x0 | x1) - ((x0 ^ x1) ashr 1).
%or = or <16 x i16> %x0, %x1
%xor = xor <16 x i16> %x0, %x1
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%avg = sub <16 x i16> %or, %shift
ret <16 x i16> %avg
}

; Negative test: the two operands are sign-extended from different narrow
; types (<16 x i4> and <16 x i8>) before the signed ceiling-average.
; The CHECK lines expect both srhadd instructions to stay at the wide .8h
; width, with each operand extended first (the i4 operand via shl/sshr #12).
; NOTE(review): presumably the narrowing fold is skipped because the
; pre-extension types differ - confirm against the combine's type check.
define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
; CHECK-LABEL: sext_avgceils_mismatch:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-NEXT: sshll v3.8h, v1.8b, #0
; CHECK-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NEXT: shl v2.8h, v2.8h, #12
; CHECK-NEXT: shl v0.8h, v0.8h, #12
; CHECK-NEXT: sshr v2.8h, v2.8h, #12
; CHECK-NEXT: sshr v0.8h, v0.8h, #12
; CHECK-NEXT: srhadd v1.8h, v0.8h, v1.8h
; CHECK-NEXT: srhadd v0.8h, v2.8h, v3.8h
; CHECK-NEXT: ret
%x0 = sext <16 x i4> %a0 to <16 x i16>
%x1 = sext <16 x i8> %a1 to <16 x i16>
; Ceiling-average idiom: (x0 | x1) - ((x0 ^ x1) ashr 1).
%or = or <16 x i16> %x0, %x1
%xor = xor <16 x i16> %x0, %x1
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%avg = sub <16 x i16> %or, %shift
ret <16 x i16> %avg
}
Loading