Skip to content

Commit f68dd9d

Browse files
committed
[RISCV] Implement EmitTargetCodeForStrcmp for unaligned case.
When the strings are unaligned and one of the arguments is a known constant string, we specialize the `strcmp` function. First, we check the above two conditions in `EmitTargetCodeForStrcmp`, and if they are satisfied we emit the target node `RISCVISD::STRCMP`. The node has an additional argument to indicate which of the strings (first or second) was constant. During `ISel` we match it to the pseudo instruction `PseudoSTRCMPI`. Finally, during `FinalizeLowering` we expand the pseudo into code. This optimization is triggered about 2000 times on the C/C++ SPEC2017 benchmarks, but unfortunately it doesn't have any noticeable performance impact on the dynamic instruction count. This optimization is off by default. Note that gcc already does this.
1 parent d1a461d commit f68dd9d

File tree

9 files changed

+723
-1
lines changed

9 files changed

+723
-1
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ add_llvm_target(RISCVCodeGen
5252
RISCVPushPopOptimizer.cpp
5353
RISCVRegisterInfo.cpp
5454
RISCVSubtarget.cpp
55+
RISCVSelectionDAGTargetInfo.cpp
5556
RISCVTargetMachine.cpp
5657
RISCVTargetObjectFile.cpp
5758
RISCVTargetTransformInfo.cpp

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17655,6 +17655,167 @@ static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
1765517655
return DoneMBB;
1765617656
}
1765717657

17658+
/// Expands PseudoSTRCMPI into an unrolled, byte-by-byte comparison of a
/// run-time string against a constant C string.
///
/// Operands of \p MI:
///   0 - result register,
///   1 - register holding the address of the non-constant string,
///   2 - global address of the constant string,
///   3 - immediate: 0 if the constant string was the first strcmp argument,
///       1 if it was the second one.
///
/// The generated code computes the result as if the constant string were the
/// second argument (i.e. "loaded byte - constant byte") and negates it in the
/// exit block when the constant string was actually the first argument.
///
/// Resulting block layout (n = length of the constant string):
///   MBB, Eq0, NewMBBs[0], Eq1, ..., NewMBBs[n-2], Eq(n-1),
///   CheckNullByteMBB, ReturnEarlyNullByteMBB, ExitMBB
///
/// \returns the block that now holds everything that followed \p MI.
static MachineBasicBlock *emitSTRCMPI(MachineInstr &MI, MachineBasicBlock *MBB,
                                      const RISCVSubtarget &Subtarget) {
  const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
  MachineFunction &MF = *MBB->getParent();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  DebugLoc DL = MI.getDebugLoc();

  const GlobalVariable *GV = cast<GlobalVariable>(MI.getOperand(2).getGlobal());
  StringRef Str = cast<ConstantDataArray>(GV->getInitializer())->getAsCString();
  const int NumOfBytes = static_cast<int>(Str.size());
  // The expansion below indexes Str[0] and NewMBBs[NumOfBytes - 1], so an
  // empty constant string cannot be handled here.
  assert(NumOfBytes > 0 && "empty constant strings are not supported");
  // The terminator is loaded with LBU at immediate offset NumOfBytes, which
  // has to fit into a signed 12-bit immediate.
  assert(NumOfBytes < 2048 && "constant string too long for LBU offset");

  const BasicBlock *LLVM_BB = MBB->getBasicBlock();

  // Split the block: everything after MI moves to ExitMBB.
  MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(LLVM_BB);
  MF.insert(std::next(MBB->getIterator()), ExitMBB);
  ExitMBB->splice(ExitMBB->end(), MBB, std::next(MI.getIterator()), MBB->end());
  ExitMBB->transferSuccessorsAndUpdatePHIs(MBB);

  // In the code below we assume that the constant string is the second
  // argument and negate the result if needed.
  const bool NeedToNegateResult = MI.getOperand(3).getImm() == 0;
  const Register ResultReg = MI.getOperand(0).getReg();
  const Register PHIReg = NeedToNegateResult
                              ? MRI.createVirtualRegister(&RISCV::GPRRegClass)
                              : ResultReg;
  MachineInstrBuilder PHI_MIB =
      BuildMI(*ExitMBB, ExitMBB->begin(), DL, TII.get(RISCV::PHI), PHIReg);
  if (NeedToNegateResult) {
    BuildMI(*ExitMBB, std::next(ExitMBB->begin()), DL, TII.get(RISCV::SUB),
            ResultReg)
        .addReg(RISCV::X0)
        .addReg(PHIReg);
  }

  // Reached when the run-time string ends before the constant one does: the
  // run-time string compares less, so the (un-negated) result is -1. The
  // block has no terminator and falls through to ExitMBB.
  MachineBasicBlock *ReturnEarlyNullByteMBB =
      MF.CreateMachineBasicBlock(LLVM_BB);
  MF.insert(ExitMBB->getIterator(), ReturnEarlyNullByteMBB);
  Register NegReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
  BuildMI(*ReturnEarlyNullByteMBB, ReturnEarlyNullByteMBB->end(), DL,
          TII.get(RISCV::ADDI), NegReg)
      .addReg(RISCV::X0)
      .addImm(-1);
  ReturnEarlyNullByteMBB->addSuccessor(ExitMBB);
  PHI_MIB.addReg(NegReg).addMBB(ReturnEarlyNullByteMBB);

  Register BaseReg = MI.getOperand(1).getReg();
  MachineMemOperand &MMO = *MI.memoperands()[0];

  // Builds the memory operand for a one-byte load at the given offset from the
  // start of the run-time string.
  auto makeByteMMO = [&](int64_t Off) {
    MachineMemOperand *ByteMMO = MF.getMachineMemOperand(
        MMO.getPointerInfo(), MachineMemOperand::MOLoad, LLT(MVT::i8),
        Align(1));
    ByteMMO->setOffset(Off);
    return ByteMMO;
  };

  // Reached when all non-null bytes matched. The constant string's byte at
  // this position is its terminator, so "loaded byte - constant byte" is just
  // the zero-extended loaded byte: 0 if both strings end here, positive
  // otherwise. It must NOT be negated here; the sign convention is the same as
  // for the per-byte differences below.
  MachineBasicBlock *CheckNullByteMBB = MF.CreateMachineBasicBlock(LLVM_BB);
  MF.insert(ReturnEarlyNullByteMBB->getIterator(), CheckNullByteMBB);
  Register LoadedLastByteReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
  BuildMI(*CheckNullByteMBB, CheckNullByteMBB->end(), DL, TII.get(RISCV::LBU),
          LoadedLastByteReg)
      .addReg(BaseReg)
      .addImm(NumOfBytes)
      .addMemOperand(makeByteMMO(NumOfBytes));
  BuildMI(*CheckNullByteMBB, CheckNullByteMBB->end(), DL,
          TII.get(RISCV::PseudoBR))
      .addMBB(ExitMBB);
  CheckNullByteMBB->addSuccessor(ExitMBB);
  PHI_MIB.addReg(LoadedLastByteReg).addMBB(CheckNullByteMBB);

  // First byte will be processed in the original MBB.
  // Create NewMBBs for all other (non-null) bytes.
  MachineFunction::iterator NewMBBI = CheckNullByteMBB->getIterator();
  SmallVector<MachineBasicBlock *> NewMBBs(NumOfBytes);
  for (int I = NumOfBytes - 2; I >= 0; --I) {
    MachineBasicBlock *NewMBB = MF.CreateMachineBasicBlock(LLVM_BB);
    NewMBBs[I] = NewMBB;
    MF.insert(NewMBBI, NewMBB);
    NewMBBI = NewMBB->getIterator();
  }
  // The CheckNullByteMBB will be a fall-through successor
  // of the block checking the last non-null byte.
  NewMBBs[NumOfBytes - 1] = CheckNullByteMBB;

  // Emits the check for the byte at Offset:
  //   CurrMBB:  load the byte; branch to ReturnEarlyNullByteMBB if it is 0.
  //   EqMBB:    diff = loaded - constant byte; branch to ExitMBB if diff != 0,
  //             otherwise fall through to NextMBB.
  // EqMBB is inserted right before NextMBB.
  auto emitByteCheck = [&](MachineBasicBlock *CurrMBB,
                           MachineBasicBlock::iterator InsertPt,
                           MachineBasicBlock *NextMBB, int64_t Offset) {
    Register LoadedByteReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    BuildMI(*CurrMBB, InsertPt, DL, TII.get(RISCV::LBU), LoadedByteReg)
        .addReg(BaseReg)
        .addImm(Offset)
        .addMemOperand(makeByteMMO(Offset));
    BuildMI(*CurrMBB, InsertPt, DL, TII.get(RISCV::BEQ))
        .addReg(LoadedByteReg)
        .addReg(RISCV::X0)
        .addMBB(ReturnEarlyNullByteMBB);

    MachineBasicBlock *CheckBytesEqualMBB = MF.CreateMachineBasicBlock(LLVM_BB);
    MF.insert(NextMBB->getIterator(), CheckBytesEqualMBB);
    CurrMBB->addSuccessor(ReturnEarlyNullByteMBB);
    CurrMBB->addSuccessor(CheckBytesEqualMBB);

    // LBU zero-extends, so the constant byte has to be taken as unsigned too;
    // going through plain `char` would flip the sign for bytes >= 0x80 on
    // hosts where `char` is signed.
    const int64_t ExpectedByte = static_cast<unsigned char>(Str[Offset]);
    Register DiffReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    BuildMI(*CheckBytesEqualMBB, CheckBytesEqualMBB->end(), DL,
            TII.get(RISCV::ADDI), DiffReg)
        .addReg(LoadedByteReg)
        .addImm(-ExpectedByte);
    BuildMI(*CheckBytesEqualMBB, CheckBytesEqualMBB->end(), DL,
            TII.get(RISCV::BNE))
        .addReg(DiffReg)
        .addReg(RISCV::X0)
        .addMBB(ExitMBB);

    CheckBytesEqualMBB->addSuccessor(ExitMBB);
    PHI_MIB.addReg(DiffReg).addMBB(CheckBytesEqualMBB);
    CheckBytesEqualMBB->addSuccessor(NextMBB);
  };

  // Check the first byte in the original block, right after MI.
  emitByteCheck(MBB, std::next(MI.getIterator()), NewMBBs[0], 0);

  // Check all other non-null bytes. On the last iteration NextMBB is
  // CheckNullByteMBB, so it becomes the fall-through successor of the block
  // checking the last non-null byte.
  for (int I = 1; I < NumOfBytes; ++I) {
    MachineBasicBlock *CurrMBB = NewMBBs[I - 1];
    emitByteCheck(CurrMBB, CurrMBB->begin(), NewMBBs[I], I);
  }

  MI.eraseFromParent();
  return ExitMBB;
}
17818+
1765817819
MachineBasicBlock *
1765917820
RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
1766017821
MachineBasicBlock *BB) const {
@@ -17737,6 +17898,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
1773717898
case RISCV::PseudoFROUND_D_INX:
1773817899
case RISCV::PseudoFROUND_D_IN32X:
1773917900
return emitFROUND(MI, BB, Subtarget);
17901+
case RISCV::PseudoSTRCMPI:
17902+
return emitSTRCMPI(MI, BB, Subtarget);
1774017903
case TargetOpcode::STATEPOINT:
1774117904
case TargetOpcode::STACKMAP:
1774217905
case TargetOpcode::PATCHPOINT:
@@ -19512,6 +19675,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1951219675
NODE_NAME_CASE(SWAP_CSR)
1951319676
NODE_NAME_CASE(CZERO_EQZ)
1951419677
NODE_NAME_CASE(CZERO_NEZ)
19678+
NODE_NAME_CASE(STRCMP)
1951519679
NODE_NAME_CASE(SF_VC_XV_SE)
1951619680
NODE_NAME_CASE(SF_VC_IV_SE)
1951719681
NODE_NAME_CASE(SF_VC_VV_SE)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ enum NodeType : unsigned {
456456
TH_LDD,
457457
TH_SWD,
458458
TH_SDD,
459+
STRCMP
459460
};
460461
// clang-format on
461462
} // namespace RISCVISD

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,29 @@ def : Pat<(shl (zext GPR:$rs), uimm5:$shamt),
19521952
(SRLI (i64 (SLLI GPR:$rs, 32)), (ImmSubFrom32 uimm5:$shamt))>;
19531953
}
19541954

1955+
// SelectionDAG node for RISCVISD::STRCMP: one result and two operands, with a
// chain and a memory operand; the node may load.
// NOTE(review): constraint indices count results first, so SDTCisPtrTy<0> and
// SDTCisPtrTy<1> constrain the result and the first operand; the second
// operand is left unconstrained. Confirm whether <1>/<2> was intended.
def riscv_strcmp : SDNode<
1956+
"RISCVISD::STRCMP",
1957+
SDTypeProfile<1, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>,
1958+
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]
1959+
>;
1960+
1961+
// Pseudo instruction expanded by a custom inserter. $str1 is the register
// operand, $str2 receives the constant string's global address in the patterns
// below, and $constant_str_idx records which strcmp argument (0 = first,
// 1 = second) was the constant string.
let usesCustomInserter = 1, mayLoad = 1, mayStore = 0, hasSideEffects = 0 in
1962+
def PseudoSTRCMPI : Pseudo<
1963+
(outs GPR:$rd),
1964+
(ins GPR:$str1, i64imm:$str2, i64imm:$constant_str_idx),
1965+
[]
1966+
>;
1967+
1968+
// Constant string is the first strcmp argument: pass the other string in the
// register operand and mark the constant index as 0.
def : Pat<
1969+
(XLenVT (riscv_strcmp tglobaladdr:$str1, iPTR:$str2)),
1970+
(PseudoSTRCMPI GPR:$str2, tglobaladdr:$str1, 0)
1971+
>;
1972+
1973+
// Constant string is the second strcmp argument: operands keep their order and
// the constant index is 1.
def : Pat<
1974+
(XLenVT (riscv_strcmp iPTR:$str1, tglobaladdr:$str2)),
1975+
(PseudoSTRCMPI GPR:$str1, tglobaladdr:$str2, 1)
1976+
>;
1977+
19551978
//===----------------------------------------------------------------------===//
19561979
// Standard extensions
19571980
//===----------------------------------------------------------------------===//
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//===-- RISCVSelectionDAGTargetInfo.cpp - RISCV SelectionDAG Info
2+
//-----------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file implements the RISCVSelectionDAGTargetInfo class.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "RISCVSelectionDAGTargetInfo.h"
15+
#include "RISCVSubtarget.h"
16+
#include "llvm/CodeGen/SelectionDAG.h"
17+
#include "llvm/IR/GlobalValue.h"
18+
#include "llvm/IR/GlobalVariable.h"
19+
#include "llvm/IR/Type.h"
20+
21+
using namespace llvm;
22+
23+
#define DEBUG_TYPE "riscv-selectiondag-target-info"
24+
25+
// Exclusive upper bound on the length of a constant string for which strcmp is
// specialized. The default of 0 makes EmitTargetCodeForStrcmp return early, so
// the optimization is disabled unless this option is set.
static cl::opt<unsigned> MaxStrcmpSpecializeLength(
26+
"riscv-max-strcmp-specialize-length", cl::Hidden,
27+
cl::desc("Do not specialize strcmp if the length of constant string is "
28+
"greater or equal to this parameter"),
29+
cl::init(0));
30+
31+
static bool canSpecializeStrcmp(const GlobalAddressSDNode *GA) {
32+
const GlobalVariable *GV = dyn_cast<GlobalVariable>(GA->getGlobal());
33+
if (!GV || !GV->isConstant() || !GV->hasInitializer())
34+
return false;
35+
// NOTE: this doesn't work for empty strings
36+
const ConstantDataArray *CDA =
37+
dyn_cast<ConstantDataArray>(GV->getInitializer());
38+
if (!CDA || !CDA->isCString())
39+
return false;
40+
41+
StringRef CString = CDA->getAsCString();
42+
if (CString.str().length() >= MaxStrcmpSpecializeLength)
43+
return false;
44+
45+
return true;
46+
}
47+
48+
std::pair<SDValue, SDValue>
49+
RISCVSelectionDAGTargetInfo::EmitTargetCodeForStrcmp(
50+
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Src1,
51+
SDValue Src2, MachinePointerInfo Op1PtrInfo,
52+
MachinePointerInfo Op2PtrInfo) const {
53+
// This is the default setting, so exit early if the optimization is turned
54+
// off.
55+
if (MaxStrcmpSpecializeLength == 0)
56+
return std::make_pair(SDValue(), Chain);
57+
58+
const RISCVSubtarget &Subtarget =
59+
DAG.getMachineFunction().getSubtarget<RISCVSubtarget>();
60+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
61+
MVT XLenVT = Subtarget.getXLenVT();
62+
const DataLayout &DLayout = DAG.getDataLayout();
63+
64+
Align NeededAlignment = Align(XLenVT.getSizeInBits() / 8);
65+
Align Src1Align;
66+
Align Src2Align;
67+
if (const Value *Src1V = dyn_cast_if_present<const Value *>(Op1PtrInfo.V)) {
68+
Src1Align = Src1V->getPointerAlignment(DLayout);
69+
}
70+
if (const Value *Src2V = dyn_cast_if_present<const Value *>(Op2PtrInfo.V)) {
71+
Src2Align = Src2V->getPointerAlignment(DLayout);
72+
}
73+
if (!(Src1Align < NeededAlignment || Src2Align < NeededAlignment))
74+
return std::make_pair(SDValue(), Chain);
75+
76+
const GlobalAddressSDNode *CStringGA = nullptr;
77+
SDValue Other;
78+
MachinePointerInfo MPI;
79+
bool ConstantStringIsSecond = false;
80+
81+
const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Src1);
82+
if (GA && canSpecializeStrcmp(GA)) {
83+
CStringGA = GA;
84+
Other = Src2;
85+
MPI = Op2PtrInfo;
86+
}
87+
if (!CStringGA) {
88+
GA = dyn_cast<GlobalAddressSDNode>(Src2);
89+
if (GA && canSpecializeStrcmp(GA)) {
90+
ConstantStringIsSecond = true;
91+
CStringGA = GA;
92+
Other = Src1;
93+
MPI = Op1PtrInfo;
94+
}
95+
}
96+
97+
if (!CStringGA)
98+
return std::make_pair(SDValue(), Chain);
99+
100+
// It could be that the non-constant string is actually aligned, but
101+
// we can't prove it, so getPointerAlignment will return Align(1).
102+
// In this case, if the constant string is sufficiently aligned, It is better
103+
// to call to libc's strcmp?
104+
Align ConstantStrAlignment = ConstantStringIsSecond ? Src2Align : Src1Align;
105+
if (ConstantStrAlignment >= NeededAlignment)
106+
return std::make_pair(SDValue(), Chain);
107+
108+
SDValue TGA = DAG.getTargetGlobalAddress(CStringGA->getGlobal(), DL,
109+
TLI.getPointerTy(DLayout), 0,
110+
CStringGA->getTargetFlags());
111+
112+
SDValue Str1 = TGA;
113+
SDValue Str2 = Other;
114+
if (ConstantStringIsSecond)
115+
std::swap(Str1, Str2);
116+
117+
MachineFunction &MF = DAG.getMachineFunction();
118+
MachineMemOperand *MMO = MF.getMachineMemOperand(
119+
MPI, MachineMemOperand::MOLoad, LLT(MVT::i8), Align(1));
120+
// TODO: what should be the MemVT?
121+
SDValue STRCMPNode = DAG.getMemIntrinsicNode(
122+
RISCVISD::STRCMP, DL, DAG.getVTList(XLenVT, MVT::Other),
123+
{Chain, Str1, Str2}, MVT::i8, MMO);
124+
125+
SDValue ChainOut = STRCMPNode.getValue(1);
126+
return std::make_pair(STRCMPNode, ChainOut);
127+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===-- RISCVSelectionDAGTargetInfo.h - RISCV SelectionDAG Info ---*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file defines the RISCV subclass for SelectionDAGTargetInfo.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_LIB_TARGET_RISCV_RISCVSELECTIONDAGINFO_H
15+
#define LLVM_LIB_TARGET_RISCV_RISCVSELECTIONDAGINFO_H
16+
17+
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
18+
19+
namespace llvm {
20+
21+
class RISCVSelectionDAGTargetInfo : public SelectionDAGTargetInfo {
22+
public:
23+
explicit RISCVSelectionDAGTargetInfo() = default;
24+
std::pair<SDValue, SDValue>
25+
EmitTargetCodeForStrcmp(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
26+
SDValue Src1, SDValue Src2,
27+
MachinePointerInfo Op1PtrInfo,
28+
MachinePointerInfo Op2PtrInfo) const override;
29+
};
30+
31+
} // end namespace llvm
32+
33+
#endif

0 commit comments

Comments
 (0)