-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[DAG] fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
#95365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-selectiondag Author: None (c8ef) Changes: Follow-up to #95134. Context: #95134 (comment). Full diff: https://github.com/llvm/llvm-project/pull/95365.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 78970bc4fe4ab..0d4df4a7ecda5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
DAG.getShiftAmountConstant(1, VT, DL));
// fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
+ // fold avgu(sext(x), sext(y)) -> sext(avgu(x, y))
if (sd_match(
N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
X.getValueType() == Y.getValueType() &&
@@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
}
+ if (sd_match(
+ N, m_BinOp(ISD::AVGFLOORU, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+ X.getValueType() == Y.getValueType() &&
+ hasOperation(ISD::AVGFLOORU, X.getValueType())) {
+ SDValue AvgFloorU = DAG.getNode(ISD::AVGFLOORU, DL, X.getValueType(), X, Y);
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorU);
+ }
+ if (sd_match(
+ N, m_BinOp(ISD::AVGCEILU, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+ X.getValueType() == Y.getValueType() &&
+ hasOperation(ISD::AVGCEILU, X.getValueType())) {
+ SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilU);
+ }
// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index dc87708555987..e61b47772b7d7 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -68,3 +68,87 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
%avg = sub <16 x i16> %or, %shift
ret <16 x i16> %avg
}
+
+define <16 x i16> @sext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgflooru:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT: shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT: shadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT: ret
+ %x0 = sext <16 x i8> %a0 to <16 x i16>
+ %x1 = sext <16 x i8> %a1 to <16 x i16>
+ %and = and <16 x i16> %x0, %x1
+ %xor = xor <16 x i16> %x0, %x1
+ %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+ %avg = add <16 x i16> %and, %shift
+ ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
+; CHECK-LABEL: sext_avgflooru_mismatch:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll2 v2.8h, v1.16b, #0
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: sshll v3.8h, v0.8b, #0
+; CHECK-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT: shl v1.8h, v1.8h, #12
+; CHECK-NEXT: shl v2.8h, v2.8h, #12
+; CHECK-NEXT: sshr v4.8h, v1.8h, #12
+; CHECK-NEXT: sshr v1.8h, v2.8h, #12
+; CHECK-NEXT: shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT: shadd v0.8h, v3.8h, v4.8h
+; CHECK-NEXT: ret
+ %x0 = sext <16 x i8> %a0 to <16 x i16>
+ %x1 = sext <16 x i4> %a1 to <16 x i16>
+ %and = and <16 x i16> %x0, %x1
+ %xor = xor <16 x i16> %x0, %x1
+ %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+ %avg = add <16 x i16> %and, %shift
+ ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceilu:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT: srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT: srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT: ret
+ %x0 = sext <16 x i8> %a0 to <16 x i16>
+ %x1 = sext <16 x i8> %a1 to <16 x i16>
+ %or = or <16 x i16> %x0, %x1
+ %xor = xor <16 x i16> %x0, %x1
+ %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+ %avg = sub <16 x i16> %or, %shift
+ ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceilu_mismatch:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT: shl v2.8h, v2.8h, #12
+; CHECK-NEXT: shl v0.8h, v0.8h, #12
+; CHECK-NEXT: sshr v2.8h, v2.8h, #12
+; CHECK-NEXT: sshr v0.8h, v0.8h, #12
+; CHECK-NEXT: srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT: srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT: ret
+ %x0 = sext <16 x i4> %a0 to <16 x i16>
+ %x1 = sext <16 x i8> %a1 to <16 x i16>
+ %or = or <16 x i16> %x0, %x1
+ %xor = xor <16 x i16> %x0, %x1
+ %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+ %avg = sub <16 x i16> %or, %shift
+ ret <16 x i16> %avg
+}
|
No, this needs to be: avgs(sext(x), sext(y)) -> sext(avgs(x, y)) |
avgu(sext(x), sext(y)) -> sext(avgu(x, y))
avgs(sext(x), sext(y)) -> sext(avgs(x, y))
Sorry, my bad. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
Context: #95365 (comment)
The current implementation of `m_SExt` matches both `ISD::SIGN_EXTEND` and `ISD::SIGN_EXTEND_INREG`. However, in cases where we specifically need to match _only_ `ISD::SIGN_EXTEND`, such as in the SelectionDAG graph below, this can lead to issues and unintended combinations.
```
SelectionDAG has 13 nodes:
  t0: ch,glue = EntryToken
  t2: v2i32,ch = CopyFromReg t0, Register:v2i32 %0
  t21: v2i32 = sign_extend_inreg t2, ValueType:ch:v2i8
  t4: v2i32,ch = CopyFromReg t0, Register:v2i32 %1
  t22: v2i32 = sign_extend_inreg t4, ValueType:ch:v2i8
  t23: v2i32 = avgfloors t21, t22
  t24: v2i32 = sign_extend_inreg t23, ValueType:ch:v2i8
  t15: ch,glue = CopyToReg t0, Register:v2i32 $d0, t24
  t16: ch = AArch64ISD::RET_GLUE t15, Register:v2i32 $d0, t15:1
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please can you rebase + regenerate the tests now that #95415 has landed?
Done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - cheers
Follow-up to #95134.
Context: #95134 (comment).