From e10966f1484d655c9ca532e3942cde38acc0328a Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Tue, 28 May 2024 09:22:05 -0700
Subject: [PATCH 1/5] [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC

I plan to add other combines on TRUNCATE_VECTOR_VL.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 103 ++++++++++----------
 1 file changed, 53 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..47b1cc1ba6460 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+  // This would be benefit for the cases where X and Y are both the same value
+  // type of low precision vectors. Since the truncate would be lowered into
+  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+  // restriction, such pattern would be expanded into a series of "vsetvli"
+  // and "vnsrl" instructions later to reach this point.
+  auto IsTruncNode = [](SDValue V) {
+    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL = V.getOperand(2);
+    auto *C = dyn_cast<ConstantSDNode>(VL);
+    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                           (isa<RegisterSDNode>(VL) &&
+                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  };
+
+  SDValue Op = N->getOperand(0);
+
+  // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+  // to distinguish such pattern.
+  while (IsTruncNode(Op)) {
+    if (!Op.hasOneUse())
+      return SDValue();
+    Op = Op.getOperand(0);
+  }
+
+  if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+    return SDValue();
+
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
+  if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+      N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N10 = N1.getOperand(0);
+  if (!N00.getValueType().isVector() ||
+      N00.getValueType() != N10.getValueType() ||
+      N->getValueType(0) != N10.getValueType())
+    return SDValue();
+
+  unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+  SDValue SMin =
+      DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+                  DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+  return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
-  case RISCVISD::TRUNCATE_VECTOR_VL: {
-    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-    // This would be benefit for the cases where X and Y are both the same value
-    // type of low precision vectors. Since the truncate would be lowered into
-    // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-    // restriction, such pattern would be expanded into a series of "vsetvli"
-    // and "vnsrl" instructions later to reach this point.
-    auto IsTruncNode = [](SDValue V) {
-      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-        return false;
-      SDValue VL = V.getOperand(2);
-      auto *C = dyn_cast<ConstantSDNode>(VL);
-      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                             (isa<RegisterSDNode>(VL) &&
-                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
-             IsVLMAXForVMSET;
-    };
-
-    SDValue Op = N->getOperand(0);
-
-    // We need to first find the inner level of TRUNCATE_VECTOR_VL node
-    // to distinguish such pattern.
-    while (IsTruncNode(Op)) {
-      if (!Op.hasOneUse())
-        return SDValue();
-      Op = Op.getOperand(0);
-    }
-
-    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
-      SDValue N0 = Op.getOperand(0);
-      SDValue N1 = Op.getOperand(1);
-      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
-          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
-        SDValue N00 = N0.getOperand(0);
-        SDValue N10 = N1.getOperand(0);
-        if (N00.getValueType().isVector() &&
-            N00.getValueType() == N10.getValueType() &&
-            N->getValueType(0) == N10.getValueType()) {
-          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
-          SDValue SMin = DAG.getNode(
-              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
-              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
-          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
-        }
-      }
-    }
-    break;
-  }
+  case RISCVISD::TRUNCATE_VECTOR_VL:
+    return combineTruncOfSraSext(N, DAG);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:

From 4e227983c1e3c290724f09e4968610e7b0c21689 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Tue, 28 May 2024 09:55:48 -0700
Subject: [PATCH 2/5] [RISCV] Verify the VL and Mask on the outer TRUNCATE_VECTOR_VL in combineTruncOfSraSext.

We checked the VL and mask of any additional TRUNCATE_VECTOR_VL nodes
we peek through, but not the outermost. This moves the check to the
outer node and then verifies all the additional nodes have the same VL
and Mask.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 ++++++++++++---------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 47b1cc1ba6460..288e874276e07 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16088,22 +16088,25 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
 }
 
 static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  bool IsVLMAX = isAllOnesConstant(VL) ||
+                 (isa<RegisterSDNode>(VL) &&
+                  cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+  if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
+      Mask.getOperand(0) != VL)
+    return SDValue();
+
   // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
   // This would be benefit for the cases where X and Y are both the same value
   // type of low precision vectors. Since the truncate would be lowered into
   // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
   // restriction, such pattern would be expanded into a series of "vsetvli"
   // and "vnsrl" instructions later to reach this point.
-  auto IsTruncNode = [](SDValue V) {
-    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-      return false;
-    SDValue VL = V.getOperand(2);
-    auto *C = dyn_cast<ConstantSDNode>(VL);
-    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                           (isa<RegisterSDNode>(VL) &&
-                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  auto IsTruncNode = [&](SDValue V) {
+    return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
+           V.getOperand(1) == Mask && V.getOperand(2) == VL;
   };
 
   SDValue Op = N->getOperand(0);

From 78777ec5442cd9a71c639ce685512e699241200b Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Tue, 28 May 2024 16:12:22 -0700
Subject: [PATCH 3/5] fixup! move comment.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6a9a61e480294..f4da46f82a810 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16128,6 +16128,12 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+// This would be benefit for the cases where X and Y are both the same value
+// type of low precision vectors. Since the truncate would be lowered into
+// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+// restriction, such pattern would be expanded into a series of "vsetvli"
+// and "vnsrl" instructions later to reach this point.
 static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
   SDValue Mask = N->getOperand(1);
   SDValue VL = N->getOperand(2);
@@ -16139,12 +16145,6 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
       Mask.getOperand(0) != VL)
     return SDValue();
 
-  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-  // This would be benefit for the cases where X and Y are both the same value
-  // type of low precision vectors. Since the truncate would be lowered into
-  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-  // restriction, such pattern would be expanded into a series of "vsetvli"
-  // and "vnsrl" instructions later to reach this point.
   auto IsTruncNode = [&](SDValue V) {
     return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
            V.getOperand(1) == Mask && V.getOperand(2) == VL;

From 0a0682e168dd275e2fd7139bbd6c5ca472418630 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Tue, 28 May 2024 22:58:30 -0700
Subject: [PATCH 4/5] fixup! Update new test.

---
 llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
index 8dbb57fd15cf1..1bd83734a03cb 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
 
 define <vscale x 1 x i8> @vsra_vv_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
 ; CHECK-LABEL: vsra_vv_nxv1i8:
@@ -937,13 +937,17 @@ define <vscale x 8 x i32> @vsra_vi_mask_nxv8i32(<vscale x 8 x i32> %va,
-define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 %evl) {
+define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext_mixed_trunc:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a0, 7
-; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
-; CHECK-NEXT:    vmin.vx v9, v8, a0
-; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf4 v9, v8
+; CHECK-NEXT:    vzext.vf4 v10, v8
+; CHECK-NEXT:    vsra.vv v8, v9, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0, v0.t
 ; CHECK-NEXT:    ret
   %sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
   %zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>

From fc4e2d4153f12c722dff70e41ed8f1e421f6c2f4 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Wed, 29 May 2024 09:03:11 -0700
Subject: [PATCH 5/5] fixup! Remove unnecessary check-prefix.

---
 llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
index 1bd83734a03cb..382c8297473b7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
 
 define <vscale x 1 x i8> @vsra_vv_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
 ; CHECK-LABEL: vsra_vv_nxv1i8: