10b57cec5SDimitry Andric //===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===// 20b57cec5SDimitry Andric // 30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60b57cec5SDimitry Andric // 70b57cec5SDimitry Andric //===----------------------------------------------------------------------===// 80b57cec5SDimitry Andric // 90b57cec5SDimitry Andric // 100b57cec5SDimitry Andric // Arguments to kernel and device functions are passed via param space, 110b57cec5SDimitry Andric // which imposes certain restrictions: 120b57cec5SDimitry Andric // http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces 130b57cec5SDimitry Andric // 140b57cec5SDimitry Andric // Kernel parameters are read-only and accessible only via ld.param 150b57cec5SDimitry Andric // instruction, directly or via a pointer. Pointers to kernel 160b57cec5SDimitry Andric // arguments can't be converted to generic address space. 170b57cec5SDimitry Andric // 180b57cec5SDimitry Andric // Device function parameters are directly accessible via 190b57cec5SDimitry Andric // ld.param/st.param, but taking the address of one returns a pointer 200b57cec5SDimitry Andric // to a copy created in local space which *can't* be used with 210b57cec5SDimitry Andric // ld.param/st.param. 220b57cec5SDimitry Andric // 230b57cec5SDimitry Andric // Copying a byval struct into local memory in IR allows us to enforce 240b57cec5SDimitry Andric // the param space restrictions, gives the rest of IR a pointer w/o 250b57cec5SDimitry Andric // param space restrictions, and gives us an opportunity to eliminate 260b57cec5SDimitry Andric // the copy. 
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. Like #1, this allows NVPTX to emit more ld/st.global.
E.g., 590b57cec5SDimitry Andric // 600b57cec5SDimitry Andric // struct S { 610b57cec5SDimitry Andric // int *x; 620b57cec5SDimitry Andric // int *y; 630b57cec5SDimitry Andric // }; 640b57cec5SDimitry Andric // __global__ void foo(S s) { 650b57cec5SDimitry Andric // int *b = s.y; 660b57cec5SDimitry Andric // // use b 670b57cec5SDimitry Andric // } 680b57cec5SDimitry Andric // 690b57cec5SDimitry Andric // "b" points to the global address space. In the IR level, 700b57cec5SDimitry Andric // 710b57cec5SDimitry Andric // define void @foo({i32*, i32*}* byval %input) { 720b57cec5SDimitry Andric // %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1 730b57cec5SDimitry Andric // %b = load i32*, i32** %b_ptr 740b57cec5SDimitry Andric // ; use %b 750b57cec5SDimitry Andric // } 760b57cec5SDimitry Andric // 770b57cec5SDimitry Andric // becomes 780b57cec5SDimitry Andric // 790b57cec5SDimitry Andric // define void @foo({i32*, i32*}* byval %input) { 800b57cec5SDimitry Andric // %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1 810b57cec5SDimitry Andric // %b = load i32*, i32** %b_ptr 820b57cec5SDimitry Andric // %b_global = addrspacecast i32* %b to i32 addrspace(1)* 830b57cec5SDimitry Andric // %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32* 840b57cec5SDimitry Andric // ; use %b_generic 850b57cec5SDimitry Andric // } 860b57cec5SDimitry Andric // 870b57cec5SDimitry Andric // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't 880b57cec5SDimitry Andric // cancel the addrspacecast pair this pass emits. 
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include <numeric>
#include <queue>

#define DEBUG_TYPE "nvptx-lower-args"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
// Legacy FunctionPass that lowers kernel and device-function arguments as
// described in the file header: byval aggregates are either rewritten to be
// accessed directly in the .param address space or copied into a local
// alloca, and (for the CUDA driver interface) pointer arguments are marked
// as pointing to the global address space.
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  // Kernel entry points and ordinary device functions need different
  // lowering; runOnFunction dispatches to one of these.
  bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F);
  bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F);

  // handle byval parameters
  void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs() : FunctionPass(ID) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The target machine is fetched through TargetPassConfig in
    // runOnFunction, so declare the dependency here.
    AU.addRequired<TargetPassConfig>();
  }
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
                      "Lower arguments (NVPTX)", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
                    "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space in the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
//
// In case we know that all users are GEPs or Loads, replace them with the same
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

// Replaces the \p OldUser instruction with the same in parameter AS.
// Only Load and GEP are supported.
//
// Starting from \p OldUser (a direct user of the byval argument), walks the
// use chain (GEP/BitCast/AddrSpaceCast/Load, as pre-validated by
// handleByValParam's IsALoadChain) and re-creates each intermediate
// instruction with a pointer operand in ADDRESS_SPACE_PARAM. Loads are
// converted in place by swapping their pointer operand; the superseded
// generic-AS instructions are erased at the end.
static void convertToParamAS(Value *OldUser, Value *Param) {
  Instruction *I = dyn_cast<Instruction>(OldUser);
  assert(I && "OldUser must be an instruction");
  // Work item: an old generic-AS instruction plus the param-AS value that
  // replaces its pointer operand.
  struct IP {
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  // Clones one instruction into the param AS. Returns the old instruction
  // itself when it was mutated in place (Load), a fresh instruction for
  // GEP/BitCast, or just the incoming param value for an AddrSpaceCast
  // (which becomes redundant).
  auto CloneInstInParamAS = [](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      // Loads keep their identity; only the pointer operand changes.
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(),
                                               I.NewParam, Indices,
                                               GEP->getName(), GEP);
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }
    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC);
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
      (void)ASC;
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
      for (Value *V : I.OldInstruction->users())
        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Now we know that all argument loads are using addresses in parameter space
  // and we can finally remove the old instructions in generic AS. Instructions
  // scheduled for removal should be processed in reverse order so the ones
  // closest to the load are deleted first. Otherwise they may still be in use.
  // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
  // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
  // the BitCast.
  for (Instruction *I : llvm::reverse(InstructionsToDelete))
    I->eraseFromParent();
}

// Adjust alignment of arguments passed byval in .param address space. We can
// increase alignment of such arguments in a way that ensures that we can
// effectively vectorize their loads. We should also traverse all loads from
// byval pointer and adjust their alignment, if those were using known offset.
// Such alignment changes must be kept in sync with the parameter store and
// load in NVPTXTargetLowering::LowerCall.
// Raises the alignment attribute of byval argument \p Arg to the optimized
// value chosen by the target lowering, then walks all loads reachable from
// \p ArgInParamAS (the param-AS view of the argument) and raises each load's
// alignment to what the new argument alignment guarantees at that load's
// constant byte offset.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  Function *Func = Arg->getParent();
  Type *StructType = Arg->getParamByValType();
  const DataLayout DL(Func->getParent());

  uint64_t NewArgAlign =
      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
  uint64_t CurArgAlign =
      Arg->getAttribute(Attribute::Alignment).getValueAsInt();

  // Nothing to do if the argument is already at least this aligned.
  if (CurArgAlign >= NewArgAlign)
    return;

  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');

  // Replace the old alignment attribute with the raised one.
  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  // A load discovered during the walk, with its constant byte offset from
  // the start of the argument.
  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };

  // A pointer value still to be walked, with the accumulated constant byte
  // offset at which it points into the argument.
  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };

  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});

  // BFS over users: loads are collected, bitcasts keep the current offset,
  // GEPs with a computable constant offset extend it. Anything else is
  // impossible here because handleByValParam only calls this after
  // establishing that all users form load chains.
  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();

    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        Worklist.push({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));

        // GEPs with non-constant offsets are skipped: their loads keep
        // whatever alignment they already have.
        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;

        uint64_t OffsetLimit = -1;
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }

      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr.");
    }
  }

  for (Load &CurLoad : Loads) {
    // A load at byte Offset from a pointer aligned to NewArgAlign is aligned
    // to gcd(NewArgAlign, Offset); gcd(N, 0) == N covers offset zero. Never
    // lower an alignment the load already has.
    Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}

// Lowers one byval argument. If every use of the argument is a plain load
// chain (GEP/BitCast/param-AS ASC/Load), the uses are rewritten to access the
// .param address space directly and no copy is made; otherwise the argument
// is copied into a local alloca that replaces all of its uses.
void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                      Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

  // Returns true iff every transitive user of Start is a GEP, BitCast, Load,
  // or an addrspacecast into the param AS — i.e. the argument is only read,
  // never written or escaped, so it can be accessed via ld.param directly.
  auto IsALoadChain = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
    auto IsALoadChainInstr = [](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASC to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
      if (!IsALoadChainInstr(V)) {
        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
                          << "\n");
        (void)Arg;
        return false;
      }
      // Loads terminate a chain; anything else must have its own users
      // checked as well.
      if (!isa<LoadInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

  if (llvm::all_of(Arg->users(), IsALoadChain)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
    SmallVector<User *, 16> UsersToUpdate(Arg->users());
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
    for (Value *V : UsersToUpdate)
      convertToParamAS(V, ArgInParamAS);
    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());

    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);

    return;
  }

  // Otherwise we have to create a temporary copy.
  const DataLayout &DL = Func->getParent()->getDataLayout();
  unsigned AS = DL.getAllocaAddrSpace();
  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to alignment of the byval parameter. This is because,
  // later load/stores assume that alignment, and we are going to replace
  // the use of the byval parameter with this alloca instruction.
  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
                           .value_or(DL.getPrefTypeAlign(StructType)));
  Arg->replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
  // addrspacecast preserves alignment. Since params are constant, this load is
  // definitely not volatile.
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(),
                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

// Emits "generic -> global -> generic" addrspacecasts for \p Ptr and reroutes
// all of Ptr's uses through the round-trip pair, so that later passes can
// fold the casts into global-AS memory accesses (see the class comment).
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  // Only generic pointers need (or can take) the round trip.
  if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
    return;

  // Deciding where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
  // replaceAllUsesWith also rewrites PtrInGlobal's operand, so restore it
  // afterwards to keep the cast chain rooted at the original Ptr.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
                                         Function &F) {
  // Copying of byval aggregates + SROA may result in pointers being loaded as
  // integers, followed by inttoptr. We may want to mark those as global, too,
  // but only if the loaded integer is used exclusively for conversion to a
  // pointer with inttoptr.
  auto HandleIntToPtr = [this](Value &V) {
    if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
      SmallVector<User *, 16> UsersToUpdate(V.users());
      for (User *U : UsersToUpdate)
        markPointerAsGlobal(U);
    }
  };
  if (TM.getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          // Pointer-typed loads are marked directly; integer-typed loads may
          // be pointers in disguise (see HandleIntToPtr above).
          if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                if (LI->getType()->isPointerTy())
                  markPointerAsGlobal(LI);
                else
                  HandleIntToPtr(*LI);
              }
            }
          }
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      // Byval aggregates get the full lowering; plain pointer args are only
      // marked global under the CUDA driver interface.
      if (Arg.hasByValAttr())
        handleByValParam(TM, &Arg);
      else if (TM.getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    } else if (Arg.getType()->isIntegerTy() &&
               TM.getDrvInterface() == NVPTX::CUDA) {
      // Integer args used only via inttoptr are treated as pointers, too.
      HandleIntToPtr(Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
47206c3fb27SDimitry Andric bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM, 47306c3fb27SDimitry Andric Function &F) { 474fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n"); 4750b57cec5SDimitry Andric for (Argument &Arg : F.args()) 4760b57cec5SDimitry Andric if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) 47706c3fb27SDimitry Andric handleByValParam(TM, &Arg); 4780b57cec5SDimitry Andric return true; 4790b57cec5SDimitry Andric } 4800b57cec5SDimitry Andric 4810b57cec5SDimitry Andric bool NVPTXLowerArgs::runOnFunction(Function &F) { 48206c3fb27SDimitry Andric auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>(); 48306c3fb27SDimitry Andric 48406c3fb27SDimitry Andric return isKernelFunction(F) ? runOnKernelFunction(TM, F) 48506c3fb27SDimitry Andric : runOnDeviceFunction(TM, F); 4860b57cec5SDimitry Andric } 4870b57cec5SDimitry Andric 48806c3fb27SDimitry Andric FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); } 489