10b57cec5SDimitry Andric //===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===// 20b57cec5SDimitry Andric // 30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60b57cec5SDimitry Andric // 70b57cec5SDimitry Andric //===----------------------------------------------------------------------===// 80b57cec5SDimitry Andric // 90b57cec5SDimitry Andric // 100b57cec5SDimitry Andric // Arguments to kernel and device functions are passed via param space, 110b57cec5SDimitry Andric // which imposes certain restrictions: 120b57cec5SDimitry Andric // http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces 130b57cec5SDimitry Andric // 140b57cec5SDimitry Andric // Kernel parameters are read-only and accessible only via ld.param 150b57cec5SDimitry Andric // instruction, directly or via a pointer. Pointers to kernel 160b57cec5SDimitry Andric // arguments can't be converted to generic address space. 170b57cec5SDimitry Andric // 180b57cec5SDimitry Andric // Device function parameters are directly accessible via 190b57cec5SDimitry Andric // ld.param/st.param, but taking the address of one returns a pointer 200b57cec5SDimitry Andric // to a copy created in local space which *can't* be used with 210b57cec5SDimitry Andric // ld.param/st.param. 220b57cec5SDimitry Andric // 230b57cec5SDimitry Andric // Copying a byval struct into local memory in IR allows us to enforce 240b57cec5SDimitry Andric // the param space restrictions, gives the rest of IR a pointer w/o 250b57cec5SDimitry Andric // param space restrictions, and gives us an opportunity to eliminate 260b57cec5SDimitry Andric // the copy. 
270b57cec5SDimitry Andric // 280b57cec5SDimitry Andric // Pointer arguments to kernel functions need more work to be lowered: 290b57cec5SDimitry Andric // 300b57cec5SDimitry Andric // 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the 310b57cec5SDimitry Andric // global address space. This allows later optimizations to emit 320b57cec5SDimitry Andric // ld.global.*/st.global.* for accessing these pointer arguments. For 330b57cec5SDimitry Andric // example, 340b57cec5SDimitry Andric // 350b57cec5SDimitry Andric // define void @foo(float* %input) { 360b57cec5SDimitry Andric // %v = load float, float* %input, align 4 370b57cec5SDimitry Andric // ... 380b57cec5SDimitry Andric // } 390b57cec5SDimitry Andric // 400b57cec5SDimitry Andric // becomes 410b57cec5SDimitry Andric // 420b57cec5SDimitry Andric // define void @foo(float* %input) { 430b57cec5SDimitry Andric // %input2 = addrspacecast float* %input to float addrspace(1)* 440b57cec5SDimitry Andric // %input3 = addrspacecast float addrspace(1)* %input2 to float* 450b57cec5SDimitry Andric // %v = load float, float* %input3, align 4 460b57cec5SDimitry Andric // ... 470b57cec5SDimitry Andric // } 480b57cec5SDimitry Andric // 490b57cec5SDimitry Andric // Later, NVPTXInferAddressSpaces will optimize it to 500b57cec5SDimitry Andric // 510b57cec5SDimitry Andric // define void @foo(float* %input) { 520b57cec5SDimitry Andric // %input2 = addrspacecast float* %input to float addrspace(1)* 530b57cec5SDimitry Andric // %v = load float, float addrspace(1)* %input2, align 4 540b57cec5SDimitry Andric // ... 550b57cec5SDimitry Andric // } 560b57cec5SDimitry Andric // 570b57cec5SDimitry Andric // 2. Convert pointers in a byval kernel parameter to pointers in the global 580b57cec5SDimitry Andric // address space. As with #1, it allows NVPTX to emit more ld/st.global. 
E.g., 590b57cec5SDimitry Andric // 600b57cec5SDimitry Andric // struct S { 610b57cec5SDimitry Andric // int *x; 620b57cec5SDimitry Andric // int *y; 630b57cec5SDimitry Andric // }; 640b57cec5SDimitry Andric // __global__ void foo(S s) { 650b57cec5SDimitry Andric // int *b = s.y; 660b57cec5SDimitry Andric // // use b 670b57cec5SDimitry Andric // } 680b57cec5SDimitry Andric // 690b57cec5SDimitry Andric // "b" points to the global address space. In the IR level, 700b57cec5SDimitry Andric // 710b57cec5SDimitry Andric // define void @foo({i32*, i32*}* byval %input) { 720b57cec5SDimitry Andric // %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1 730b57cec5SDimitry Andric // %b = load i32*, i32** %b_ptr 740b57cec5SDimitry Andric // ; use %b 750b57cec5SDimitry Andric // } 760b57cec5SDimitry Andric // 770b57cec5SDimitry Andric // becomes 780b57cec5SDimitry Andric // 790b57cec5SDimitry Andric // define void @foo({i32*, i32*}* byval %input) { 800b57cec5SDimitry Andric // %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1 810b57cec5SDimitry Andric // %b = load i32*, i32** %b_ptr 820b57cec5SDimitry Andric // %b_global = addrspacecast i32* %b to i32 addrspace(1)* 830b57cec5SDimitry Andric // %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32* 840b57cec5SDimitry Andric // ; use %b_generic 850b57cec5SDimitry Andric // } 860b57cec5SDimitry Andric // 870b57cec5SDimitry Andric // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't 880b57cec5SDimitry Andric // cancel the addrspacecast pair this pass emits. 
890b57cec5SDimitry Andric //===----------------------------------------------------------------------===// 900b57cec5SDimitry Andric 910b57cec5SDimitry Andric #include "NVPTX.h" 920b57cec5SDimitry Andric #include "NVPTXTargetMachine.h" 930b57cec5SDimitry Andric #include "NVPTXUtilities.h" 940b57cec5SDimitry Andric #include "MCTargetDesc/NVPTXBaseInfo.h" 950b57cec5SDimitry Andric #include "llvm/Analysis/ValueTracking.h" 960b57cec5SDimitry Andric #include "llvm/IR/Function.h" 970b57cec5SDimitry Andric #include "llvm/IR/Instructions.h" 980b57cec5SDimitry Andric #include "llvm/IR/Module.h" 990b57cec5SDimitry Andric #include "llvm/IR/Type.h" 1000b57cec5SDimitry Andric #include "llvm/Pass.h" 1010b57cec5SDimitry Andric 102*fe6060f1SDimitry Andric #define DEBUG_TYPE "nvptx-lower-args" 103*fe6060f1SDimitry Andric 1040b57cec5SDimitry Andric using namespace llvm; 1050b57cec5SDimitry Andric 1060b57cec5SDimitry Andric namespace llvm { 1070b57cec5SDimitry Andric void initializeNVPTXLowerArgsPass(PassRegistry &); 1080b57cec5SDimitry Andric } 1090b57cec5SDimitry Andric 1100b57cec5SDimitry Andric namespace { 1110b57cec5SDimitry Andric class NVPTXLowerArgs : public FunctionPass { 1120b57cec5SDimitry Andric bool runOnFunction(Function &F) override; 1130b57cec5SDimitry Andric 1140b57cec5SDimitry Andric bool runOnKernelFunction(Function &F); 1150b57cec5SDimitry Andric bool runOnDeviceFunction(Function &F); 1160b57cec5SDimitry Andric 1170b57cec5SDimitry Andric // handle byval parameters 1180b57cec5SDimitry Andric void handleByValParam(Argument *Arg); 1190b57cec5SDimitry Andric // Knowing Ptr must point to the global address space, this function 1200b57cec5SDimitry Andric // addrspacecasts Ptr to global and then back to generic. This allows 1210b57cec5SDimitry Andric // NVPTXInferAddressSpaces to fold the global-to-generic cast into 1220b57cec5SDimitry Andric // loads/stores that appear later. 
1230b57cec5SDimitry Andric void markPointerAsGlobal(Value *Ptr); 1240b57cec5SDimitry Andric 1250b57cec5SDimitry Andric public: 1260b57cec5SDimitry Andric static char ID; // Pass identification, replacement for typeid 1270b57cec5SDimitry Andric NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr) 1280b57cec5SDimitry Andric : FunctionPass(ID), TM(TM) {} 1290b57cec5SDimitry Andric StringRef getPassName() const override { 1300b57cec5SDimitry Andric return "Lower pointer arguments of CUDA kernels"; 1310b57cec5SDimitry Andric } 1320b57cec5SDimitry Andric 1330b57cec5SDimitry Andric private: 1340b57cec5SDimitry Andric const NVPTXTargetMachine *TM; 1350b57cec5SDimitry Andric }; 1360b57cec5SDimitry Andric } // namespace 1370b57cec5SDimitry Andric 1380b57cec5SDimitry Andric char NVPTXLowerArgs::ID = 1; 1390b57cec5SDimitry Andric 1400b57cec5SDimitry Andric INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args", 1410b57cec5SDimitry Andric "Lower arguments (NVPTX)", false, false) 1420b57cec5SDimitry Andric 1430b57cec5SDimitry Andric // ============================================================================= 1440b57cec5SDimitry Andric // If the function had a byval struct ptr arg, say foo(%struct.x* byval %d), 145*fe6060f1SDimitry Andric // and we can't guarantee that the only accesses are loads, 1460b57cec5SDimitry Andric // then add the following instructions to the first basic block: 1470b57cec5SDimitry Andric // 1480b57cec5SDimitry Andric // %temp = alloca %struct.x, align 8 1490b57cec5SDimitry Andric // %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)* 1500b57cec5SDimitry Andric // %tv = load %struct.x addrspace(101)* %tempd 1510b57cec5SDimitry Andric // store %struct.x %tv, %struct.x* %temp, align 8 1520b57cec5SDimitry Andric // 1530b57cec5SDimitry Andric // The above code allocates some space in the stack and copies the incoming 1540b57cec5SDimitry Andric // struct from param space to local space. 
1550b57cec5SDimitry Andric // Then replace all occurrences of %d by %temp. 156*fe6060f1SDimitry Andric // 157*fe6060f1SDimitry Andric // In case we know that all users are GEPs or Loads, replace them with the same 158*fe6060f1SDimitry Andric // ones in parameter AS, so we can access them using ld.param. 1590b57cec5SDimitry Andric // ============================================================================= 160*fe6060f1SDimitry Andric 161*fe6060f1SDimitry Andric // Replaces the \p OldUser instruction with the same in parameter AS. 162*fe6060f1SDimitry Andric // Only Load and GEP are supported. 163*fe6060f1SDimitry Andric static void convertToParamAS(Value *OldUser, Value *Param) { 164*fe6060f1SDimitry Andric Instruction *I = dyn_cast<Instruction>(OldUser); 165*fe6060f1SDimitry Andric assert(I && "OldUser must be an instruction"); 166*fe6060f1SDimitry Andric struct IP { 167*fe6060f1SDimitry Andric Instruction *OldInstruction; 168*fe6060f1SDimitry Andric Value *NewParam; 169*fe6060f1SDimitry Andric }; 170*fe6060f1SDimitry Andric SmallVector<IP> ItemsToConvert = {{I, Param}}; 171*fe6060f1SDimitry Andric SmallVector<Instruction *> InstructionsToDelete; 172*fe6060f1SDimitry Andric 173*fe6060f1SDimitry Andric auto CloneInstInParamAS = [](const IP &I) -> Value * { 174*fe6060f1SDimitry Andric if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) { 175*fe6060f1SDimitry Andric LI->setOperand(0, I.NewParam); 176*fe6060f1SDimitry Andric return LI; 177*fe6060f1SDimitry Andric } 178*fe6060f1SDimitry Andric if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) { 179*fe6060f1SDimitry Andric SmallVector<Value *, 4> Indices(GEP->indices()); 180*fe6060f1SDimitry Andric auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(), 181*fe6060f1SDimitry Andric I.NewParam, Indices, 182*fe6060f1SDimitry Andric GEP->getName(), GEP); 183*fe6060f1SDimitry Andric NewGEP->setIsInBounds(GEP->isInBounds()); 184*fe6060f1SDimitry Andric return NewGEP; 185*fe6060f1SDimitry Andric 
} 186*fe6060f1SDimitry Andric if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) { 187*fe6060f1SDimitry Andric auto *NewBCType = PointerType::getWithSamePointeeType( 188*fe6060f1SDimitry Andric cast<PointerType>(BC->getType()), ADDRESS_SPACE_PARAM); 189*fe6060f1SDimitry Andric return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType, 190*fe6060f1SDimitry Andric BC->getName(), BC); 191*fe6060f1SDimitry Andric } 192*fe6060f1SDimitry Andric if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) { 193*fe6060f1SDimitry Andric assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM); 194*fe6060f1SDimitry Andric (void)ASC; 195*fe6060f1SDimitry Andric // Just pass through the argument, the old ASC is no longer needed. 196*fe6060f1SDimitry Andric return I.NewParam; 197*fe6060f1SDimitry Andric } 198*fe6060f1SDimitry Andric llvm_unreachable("Unsupported instruction"); 199*fe6060f1SDimitry Andric }; 200*fe6060f1SDimitry Andric 201*fe6060f1SDimitry Andric while (!ItemsToConvert.empty()) { 202*fe6060f1SDimitry Andric IP I = ItemsToConvert.pop_back_val(); 203*fe6060f1SDimitry Andric Value *NewInst = CloneInstInParamAS(I); 204*fe6060f1SDimitry Andric 205*fe6060f1SDimitry Andric if (NewInst && NewInst != I.OldInstruction) { 206*fe6060f1SDimitry Andric // We've created a new instruction. Queue users of the old instruction to 207*fe6060f1SDimitry Andric // be converted and the instruction itself to be deleted. We can't delete 208*fe6060f1SDimitry Andric // the old instruction yet, because it's still in use by a load somewhere. 
209*fe6060f1SDimitry Andric llvm::for_each( 210*fe6060f1SDimitry Andric I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) { 211*fe6060f1SDimitry Andric ItemsToConvert.push_back({cast<Instruction>(V), NewInst}); 212*fe6060f1SDimitry Andric }); 213*fe6060f1SDimitry Andric 214*fe6060f1SDimitry Andric InstructionsToDelete.push_back(I.OldInstruction); 215*fe6060f1SDimitry Andric } 216*fe6060f1SDimitry Andric } 217*fe6060f1SDimitry Andric 218*fe6060f1SDimitry Andric // Now we know that all argument loads are using addresses in parameter space 219*fe6060f1SDimitry Andric // and we can finally remove the old instructions in generic AS. Instructions 220*fe6060f1SDimitry Andric // scheduled for removal should be processed in reverse order so the ones 221*fe6060f1SDimitry Andric // closest to the load are deleted first. Otherwise they may still be in use. 222*fe6060f1SDimitry Andric // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will 223*fe6060f1SDimitry Andric // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by 224*fe6060f1SDimitry Andric // the BitCast. 
225*fe6060f1SDimitry Andric llvm::for_each(reverse(InstructionsToDelete), 226*fe6060f1SDimitry Andric [](Instruction *I) { I->eraseFromParent(); }); 227*fe6060f1SDimitry Andric } 228*fe6060f1SDimitry Andric 2290b57cec5SDimitry Andric void NVPTXLowerArgs::handleByValParam(Argument *Arg) { 2300b57cec5SDimitry Andric Function *Func = Arg->getParent(); 2310b57cec5SDimitry Andric Instruction *FirstInst = &(Func->getEntryBlock().front()); 2320b57cec5SDimitry Andric PointerType *PType = dyn_cast<PointerType>(Arg->getType()); 2330b57cec5SDimitry Andric 2340b57cec5SDimitry Andric assert(PType && "Expecting pointer type in handleByValParam"); 2350b57cec5SDimitry Andric 2360b57cec5SDimitry Andric Type *StructType = PType->getElementType(); 237*fe6060f1SDimitry Andric 238*fe6060f1SDimitry Andric auto IsALoadChain = [&](Value *Start) { 239*fe6060f1SDimitry Andric SmallVector<Value *, 16> ValuesToCheck = {Start}; 240*fe6060f1SDimitry Andric auto IsALoadChainInstr = [](Value *V) -> bool { 241*fe6060f1SDimitry Andric if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V)) 242*fe6060f1SDimitry Andric return true; 243*fe6060f1SDimitry Andric // ASC to param space are OK, too -- we'll just strip them. 
244*fe6060f1SDimitry Andric if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) { 245*fe6060f1SDimitry Andric if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM) 246*fe6060f1SDimitry Andric return true; 247*fe6060f1SDimitry Andric } 248*fe6060f1SDimitry Andric return false; 249*fe6060f1SDimitry Andric }; 250*fe6060f1SDimitry Andric 251*fe6060f1SDimitry Andric while (!ValuesToCheck.empty()) { 252*fe6060f1SDimitry Andric Value *V = ValuesToCheck.pop_back_val(); 253*fe6060f1SDimitry Andric if (!IsALoadChainInstr(V)) { 254*fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V 255*fe6060f1SDimitry Andric << "\n"); 256*fe6060f1SDimitry Andric (void)Arg; 257*fe6060f1SDimitry Andric return false; 258*fe6060f1SDimitry Andric } 259*fe6060f1SDimitry Andric if (!isa<LoadInst>(V)) 260*fe6060f1SDimitry Andric llvm::append_range(ValuesToCheck, V->users()); 261*fe6060f1SDimitry Andric } 262*fe6060f1SDimitry Andric return true; 263*fe6060f1SDimitry Andric }; 264*fe6060f1SDimitry Andric 265*fe6060f1SDimitry Andric if (llvm::all_of(Arg->users(), IsALoadChain)) { 266*fe6060f1SDimitry Andric // Convert all loads and intermediate operations to use parameter AS and 267*fe6060f1SDimitry Andric // skip creation of a local copy of the argument. 
268*fe6060f1SDimitry Andric SmallVector<User *, 16> UsersToUpdate(Arg->users()); 269*fe6060f1SDimitry Andric Value *ArgInParamAS = new AddrSpaceCastInst( 270*fe6060f1SDimitry Andric Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(), 271*fe6060f1SDimitry Andric FirstInst); 272*fe6060f1SDimitry Andric llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) { 273*fe6060f1SDimitry Andric convertToParamAS(V, ArgInParamAS); 274*fe6060f1SDimitry Andric }); 275*fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n"); 276*fe6060f1SDimitry Andric return; 277*fe6060f1SDimitry Andric } 278*fe6060f1SDimitry Andric 279*fe6060f1SDimitry Andric // Otherwise we have to create a temporary copy. 2805ffd83dbSDimitry Andric const DataLayout &DL = Func->getParent()->getDataLayout(); 2815ffd83dbSDimitry Andric unsigned AS = DL.getAllocaAddrSpace(); 2820b57cec5SDimitry Andric AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst); 2830b57cec5SDimitry Andric // Set the alignment to alignment of the byval parameter. This is because, 2840b57cec5SDimitry Andric // later load/stores assume that alignment, and we are going to replace 2850b57cec5SDimitry Andric // the use of the byval parameter with this alloca instruction. 2865ffd83dbSDimitry Andric AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo()) 2875ffd83dbSDimitry Andric .getValueOr(DL.getPrefTypeAlign(StructType))); 2880b57cec5SDimitry Andric Arg->replaceAllUsesWith(AllocA); 2890b57cec5SDimitry Andric 2900b57cec5SDimitry Andric Value *ArgInParam = new AddrSpaceCastInst( 2910b57cec5SDimitry Andric Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(), 2920b57cec5SDimitry Andric FirstInst); 293e8d8bef9SDimitry Andric // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX 294e8d8bef9SDimitry Andric // addrspacecast preserves alignment. Since params are constant, this load is 295e8d8bef9SDimitry Andric // definitely not volatile. 
2960b57cec5SDimitry Andric LoadInst *LI = 297e8d8bef9SDimitry Andric new LoadInst(StructType, ArgInParam, Arg->getName(), 298e8d8bef9SDimitry Andric /*isVolatile=*/false, AllocA->getAlign(), FirstInst); 2990b57cec5SDimitry Andric new StoreInst(LI, AllocA, FirstInst); 3000b57cec5SDimitry Andric } 3010b57cec5SDimitry Andric 3020b57cec5SDimitry Andric void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) { 3030b57cec5SDimitry Andric if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL) 3040b57cec5SDimitry Andric return; 3050b57cec5SDimitry Andric 3060b57cec5SDimitry Andric // Deciding where to emit the addrspacecast pair. 3070b57cec5SDimitry Andric BasicBlock::iterator InsertPt; 3080b57cec5SDimitry Andric if (Argument *Arg = dyn_cast<Argument>(Ptr)) { 3090b57cec5SDimitry Andric // Insert at the functon entry if Ptr is an argument. 3100b57cec5SDimitry Andric InsertPt = Arg->getParent()->getEntryBlock().begin(); 3110b57cec5SDimitry Andric } else { 3120b57cec5SDimitry Andric // Insert right after Ptr if Ptr is an instruction. 3130b57cec5SDimitry Andric InsertPt = ++cast<Instruction>(Ptr)->getIterator(); 3140b57cec5SDimitry Andric assert(InsertPt != InsertPt->getParent()->end() && 3150b57cec5SDimitry Andric "We don't call this function with Ptr being a terminator."); 3160b57cec5SDimitry Andric } 3170b57cec5SDimitry Andric 3180b57cec5SDimitry Andric Instruction *PtrInGlobal = new AddrSpaceCastInst( 319*fe6060f1SDimitry Andric Ptr, 320*fe6060f1SDimitry Andric PointerType::getWithSamePointeeType(cast<PointerType>(Ptr->getType()), 3210b57cec5SDimitry Andric ADDRESS_SPACE_GLOBAL), 3220b57cec5SDimitry Andric Ptr->getName(), &*InsertPt); 3230b57cec5SDimitry Andric Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(), 3240b57cec5SDimitry Andric Ptr->getName(), &*InsertPt); 3250b57cec5SDimitry Andric // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal. 
3260b57cec5SDimitry Andric Ptr->replaceAllUsesWith(PtrInGeneric); 3270b57cec5SDimitry Andric PtrInGlobal->setOperand(0, Ptr); 3280b57cec5SDimitry Andric } 3290b57cec5SDimitry Andric 3300b57cec5SDimitry Andric // ============================================================================= 3310b57cec5SDimitry Andric // Main function for this pass. 3320b57cec5SDimitry Andric // ============================================================================= 3330b57cec5SDimitry Andric bool NVPTXLowerArgs::runOnKernelFunction(Function &F) { 3340b57cec5SDimitry Andric if (TM && TM->getDrvInterface() == NVPTX::CUDA) { 3350b57cec5SDimitry Andric // Mark pointers in byval structs as global. 3360b57cec5SDimitry Andric for (auto &B : F) { 3370b57cec5SDimitry Andric for (auto &I : B) { 3380b57cec5SDimitry Andric if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { 3390b57cec5SDimitry Andric if (LI->getType()->isPointerTy()) { 340e8d8bef9SDimitry Andric Value *UO = getUnderlyingObject(LI->getPointerOperand()); 3410b57cec5SDimitry Andric if (Argument *Arg = dyn_cast<Argument>(UO)) { 3420b57cec5SDimitry Andric if (Arg->hasByValAttr()) { 3430b57cec5SDimitry Andric // LI is a load from a pointer within a byval kernel parameter. 
3440b57cec5SDimitry Andric markPointerAsGlobal(LI); 3450b57cec5SDimitry Andric } 3460b57cec5SDimitry Andric } 3470b57cec5SDimitry Andric } 3480b57cec5SDimitry Andric } 3490b57cec5SDimitry Andric } 3500b57cec5SDimitry Andric } 3510b57cec5SDimitry Andric } 3520b57cec5SDimitry Andric 353*fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n"); 3540b57cec5SDimitry Andric for (Argument &Arg : F.args()) { 3550b57cec5SDimitry Andric if (Arg.getType()->isPointerTy()) { 3560b57cec5SDimitry Andric if (Arg.hasByValAttr()) 3570b57cec5SDimitry Andric handleByValParam(&Arg); 3580b57cec5SDimitry Andric else if (TM && TM->getDrvInterface() == NVPTX::CUDA) 3590b57cec5SDimitry Andric markPointerAsGlobal(&Arg); 3600b57cec5SDimitry Andric } 3610b57cec5SDimitry Andric } 3620b57cec5SDimitry Andric return true; 3630b57cec5SDimitry Andric } 3640b57cec5SDimitry Andric 3650b57cec5SDimitry Andric // Device functions only need to copy byval args into local memory. 3660b57cec5SDimitry Andric bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) { 367*fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n"); 3680b57cec5SDimitry Andric for (Argument &Arg : F.args()) 3690b57cec5SDimitry Andric if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) 3700b57cec5SDimitry Andric handleByValParam(&Arg); 3710b57cec5SDimitry Andric return true; 3720b57cec5SDimitry Andric } 3730b57cec5SDimitry Andric 3740b57cec5SDimitry Andric bool NVPTXLowerArgs::runOnFunction(Function &F) { 3750b57cec5SDimitry Andric return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F); 3760b57cec5SDimitry Andric } 3770b57cec5SDimitry Andric 3780b57cec5SDimitry Andric FunctionPass * 3790b57cec5SDimitry Andric llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) { 3800b57cec5SDimitry Andric return new NVPTXLowerArgs(TM); 3810b57cec5SDimitry Andric } 382