Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ struct GenericToNVVMPass : PassInfoMixin<GenericToNVVMPass> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

/// Function pass that gives every pointer-typed byval argument of a function
/// its own local copy. Exposed to the pass registry under the name
/// "nvptx-copy-byval-args".
struct NVPTXCopyByValArgsPass : PassInfoMixin<NVPTXCopyByValArgsPass> {
/// Processes \p F. Reports PreservedAnalyses::none() when at least one
/// argument was copied and PreservedAnalyses::all() when nothing changed.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

namespace NVPTX {
enum DrvInterface {
NVCL,
Expand Down
73 changes: 48 additions & 25 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,33 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
PI.setAborted(&II);
}
}; // struct ArgUseChecker

/// Create a local copy of the byval argument \p Arg at the top of \p F's entry
/// block and make all existing users of \p Arg refer to that copy instead.
///
/// The emitted sequence, in order, is: an alloca of the byval type, an
/// addrspacecast of \p Arg into the param address space, and a memcpy from the
/// cast pointer into the alloca.
void copyByValParam(Function &F, Argument &Arg) {
  LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");

  // Everything below is inserted in front of whatever instruction currently
  // starts the entry block.
  BasicBlock::iterator InsertPt = F.getEntryBlock().begin();
  const DataLayout &Layout = F.getDataLayout();
  Type *ByValTy = Arg.getParamByValType();

  AllocaInst *LocalCopy = new AllocaInst(ByValTy, Layout.getAllocaAddrSpace(),
                                         Arg.getName(), InsertPt);

  // The alloca takes over every use of the byval parameter, and those
  // loads/stores were written assuming the parameter's alignment, so the
  // alloca must carry it too. Fall back to the preferred type alignment when
  // the parameter has no explicit one.
  const auto FallbackAlign = Layout.getPrefTypeAlign(ByValTy);
  LocalCopy->setAlignment(
      F.getParamAlign(Arg.getArgNo()).value_or(FallbackAlign));

  // Redirect users first; the cast created next must keep reading the real
  // argument rather than the new alloca.
  Arg.replaceAllUsesWith(LocalCopy);

  Type *ParamPtrTy = PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM);
  Value *ParamView =
      new AddrSpaceCastInst(&Arg, ParamPtrTy, Arg.getName(), InsertPt);

  // LLVM doesn't know that an NVPTX addrspacecast preserves alignment, so the
  // alloca's alignment is stated explicitly for the source of the copy as
  // well as for the destination.
  const auto CopyBytes = *LocalCopy->getAllocationSize(Layout);
  IRBuilder<> Builder(&*InsertPt);
  Builder.CreateMemCpy(LocalCopy, LocalCopy->getAlign(), ParamView,
                       LocalCopy->getAlign(), CopyBytes);
}
} // namespace

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Expand All @@ -558,7 +585,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,

ArgUseChecker AUC(DL, IsGridConstant);
ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
// Easy case, accessing parameter directly is fine.
if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
// Convert all loads and intermediate operations to use parameter AS and
Expand Down Expand Up @@ -587,7 +614,6 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
// However, we're still not allowed to write to it. If the user specified
// `__grid_constant__` for the argument, we'll consider escaped pointer as
// read-only.
unsigned AS = DL.getAllocaAddrSpace();
if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
// Replace all argument pointer uses (which might include a device function
Expand All @@ -612,29 +638,8 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,

// Do not replace Arg in the cast to param space
CastToParam->setOperand(0, Arg);
} else {
LLVM_DEBUG(dbgs() << "Creating a local copy of " << *Arg << "\n");
// Otherwise we have to create a temporary copy.
AllocaInst *AllocA =
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
// Set the alignment to alignment of the byval parameter. This is because,
// later load/stores assume that alignment, and we are going to replace
// the use of the byval parameter with this alloca instruction.
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
.value_or(DL.getPrefTypeAlign(StructType)));
Arg->replaceAllUsesWith(AllocA);

Value *ArgInParam = new AddrSpaceCastInst(
Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
Arg->getName(), FirstInst);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load
// is definitely not volatile.
const auto ArgSize = *AllocA->getAllocationSize(DL);
IRBuilder<> IRB(&*FirstInst);
IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
ArgSize);
}
} else
copyByValParam(*Func, *Arg);
}

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
Expand Down Expand Up @@ -734,3 +739,21 @@ bool NVPTXLowerArgs::runOnFunction(Function &F) {
}

FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }

/// Walk the arguments of \p F and create a local copy (via copyByValParam) of
/// each one that is a pointer carrying the byval attribute.
///
/// \returns true if at least one argument was copied.
static bool copyFunctionByValArgs(Function &F) {
  LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
                    << "\n");
  bool CopiedAny = false;
  for (Argument &Param : F.args()) {
    // Only pointer arguments marked byval are of interest.
    if (!Param.getType()->isPointerTy() || !Param.hasByValAttr())
      continue;
    copyByValParam(F, Param);
    CopiedAny = true;
  }
  return CopiedAny;
}

/// Pass entry point: copy the byval arguments of \p F and report which
/// analyses survive. \p AM is not consulted.
PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
                                              FunctionAnalysisManager &AM) {
  // Nothing was rewritten, so every analysis is still valid.
  if (!copyFunctionByValArgs(F))
    return PreservedAnalyses::all();
  // The function body changed; conservatively invalidate everything.
  return PreservedAnalyses::none();
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA())
#endif
FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
FUNCTION_PASS("nvvm-reflect", NVVMReflectPass())
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
#undef FUNCTION_PASS
7 changes: 7 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ static cl::opt<bool> UseShortPointersOpt(
"Use 32-bit pointers for accessing const/local/shared address spaces."),
cl::init(false), cl::Hidden);

// Hidden, off-by-default switch (-nvptx-early-byval-copy). When set, the
// function pass pipeline assembled in registerPassBuilderCallbacks gets an
// NVPTXCopyByValArgsPass appended right after NVVMIntrRangePass.
static cl::opt<bool> EarlyByValArgsCopy(
"nvptx-early-byval-copy",
cl::desc("Create a copy of byval function arguments early."),
cl::init(false), cl::Hidden);

namespace llvm {

void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
Expand Down Expand Up @@ -236,6 +241,8 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
// Note: NVVMIntrRangePass was causing numerical discrepancies at one
// point, if issues crop up, consider disabling.
FPM.addPass(NVVMIntrRangePass());
if (EarlyByValArgsCopy)
FPM.addPass(NVPTXCopyByValArgsPass());
PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
});
}
Expand Down
Loading
Loading