//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As #1, this allows NVPTX to emit more ld/st.global. E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. In the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include <numeric>
#include <queue>

#define DEBUG_TYPE "nvptx-lower-args"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
// Function pass that lowers kernel/device-function arguments for NVPTX.
// Kernels get their pointer arguments marked as global and their byval
// aggregates either accessed directly in param space or copied to a local
// alloca; device functions only need the byval copy (see runOnFunction).
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  // Kernel entry points: handle byval params and mark pointer args global.
  bool runOnKernelFunction(Function &F);
  // Non-kernel device functions: only byval params need lowering.
  bool runOnDeviceFunction(Function &F);

  // handle byval parameters
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  // TM may be null (e.g. from opt without a target machine); some
  // optimizations below (alignment adjustment, global marking) require it.
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  const NVPTXTargetMachine *TM;
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space in the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
//
// In case we know that all users are GEPs or Loads, replace them with the same
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

// Replaces the \p OldUser instruction with the same in parameter AS.
// Only Load and GEP are supported.
//
// Starting from (OldUser, Param), walks the use chain with an explicit
// worklist: every GEP/BitCast is re-created with a pointer operand in the
// param address space, every AddrSpaceCast to param AS is stripped, and every
// Load is redirected in place to the param-AS address (loads are the leaves
// of the chain, so nothing past them needs converting).
static void convertToParamAS(Value *OldUser, Value *Param) {
  Instruction *I = dyn_cast<Instruction>(OldUser);
  assert(I && "OldUser must be an instruction");
  // Worklist item: an old generic-AS instruction plus the param-AS value that
  // should replace its pointer operand.
  struct IP {
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  // Returns the param-AS replacement for I.OldInstruction, or the old
  // instruction itself when it was updated in place (LoadInst).
  auto CloneInstInParamAS = [](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      // Loads are modified in place: just repoint the address operand.
      // Returning LI (== OldInstruction) tells the caller not to recurse.
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      // Re-create the GEP with the same indices but a param-AS base pointer.
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(),
                                               I.NewParam, Indices,
                                               GEP->getName(), GEP);
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }
    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      // Re-create the bitcast with the same pointee type but in param AS.
      auto *NewBCType = PointerType::getWithSamePointeeType(
          cast<PointerType>(BC->getType()), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC);
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
      (void)ASC;
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
      for (Value *V : I.OldInstruction->users())
        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Now we know that all argument loads are using addresses in parameter space
  // and we can finally remove the old instructions in generic AS. Instructions
  // scheduled for removal should be processed in reverse order so the ones
  // closest to the load are deleted first. Otherwise they may still be in use.
  // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
  // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
  // the BitCast.
  for (Instruction *I : llvm::reverse(InstructionsToDelete))
    I->eraseFromParent();
}

// Adjust alignment of arguments passed byval in .param address space. We can
// increase alignment of such arguments in a way that ensures that we can
// effectively vectorize their loads. We should also traverse all loads from
// byval pointer and adjust their alignment, if those were using known offset.
// Such alignment changes must be conformed with parameter store and load in
// NVPTXTargetLowering::LowerCall.
// Raises the alignment attribute of byval argument \p Arg (whose param-AS
// alias is \p ArgInParamAS) to the target-preferred value, then propagates
// the improved alignment to every load reachable through GEP/bitcast chains
// at a compile-time-constant offset from the argument.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  Function *Func = Arg->getParent();
  Type *StructType = Arg->getParamByValType();
  const DataLayout DL(Func->getParent());

  uint64_t NewArgAlign =
      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
  uint64_t CurArgAlign =
      Arg->getAttribute(Attribute::Alignment).getValueAsInt();

  // Nothing to do if the argument is already at least as aligned.
  if (CurArgAlign >= NewArgAlign)
    return;

  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');

  // Replace the argument's alignment attribute with the larger value.
  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  // A load found while walking the use chain, with its constant byte offset
  // from the start of the argument.
  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };

  // Worklist entry: a pointer value derived from the argument, with the
  // accumulated constant byte offset at which it points.
  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };

  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});

  // BFS over users, tracking the constant offset of each derived pointer.
  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();

    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        // Bitcasts don't change the offset.
        Worklist.push({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));

        // GEPs with non-constant offsets are skipped; their loads keep
        // whatever alignment they already have.
        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;

        uint64_t OffsetLimit = -1; // i.e. UINT64_MAX
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }

      // handleByValParam only calls this after verifying the whole use chain
      // consists of loads, bitcasts and GEPs.
      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr.");
    }
  }

  for (Load &CurLoad : Loads) {
    // A load at constant offset O from a NewArgAlign-aligned base is aligned
    // to gcd(NewArgAlign, O); never decrease an existing alignment.
    Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}

// Lower one byval argument. If every use is a load-only chain (GEP/bitcast/
// load, plus strippable casts to param AS), rewrite the uses to read the
// parameter directly in param space; otherwise fall back to copying the
// aggregate into a local alloca and replacing all uses with the alloca.
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

  // Returns true if every transitive user of Start only reads the argument
  // (no stores, no escaping pointers), so no local copy is needed.
  auto IsALoadChain = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
    auto IsALoadChainInstr = [](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASC to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
      if (!IsALoadChainInstr(V)) {
        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
                          << "\n");
        (void)Arg; // Only used inside LLVM_DEBUG; silence unused warning.
        return false;
      }
      // Loads terminate the pointer chain; anything else may produce a
      // derived pointer whose users must be checked too.
      if (!isa<LoadInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

  if (llvm::all_of(Arg->users(), IsALoadChain)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
    SmallVector<User *, 16> UsersToUpdate(Arg->users());
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
    for (Value *V : UsersToUpdate)
      convertToParamAS(V, ArgInParamAS);
    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");

    // Further optimizations require target lowering info.
    if (!TM)
      return;

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM->getSubtargetImpl()->getTargetLowering());

    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);

    return;
  }

  // Otherwise we have to create a temporary copy.
  const DataLayout &DL = Func->getParent()->getDataLayout();
  unsigned AS = DL.getAllocaAddrSpace();
  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to alignment of the byval parameter. This is because,
  // later load/stores assume that alignment, and we are going to replace
  // the use of the byval parameter with this alloca instruction.
  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
                           .value_or(DL.getPrefTypeAlign(StructType)));
  Arg->replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
  // addrspacecast preserves alignment. Since params are constant, this load is
  // definitely not volatile.
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(),
                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

// Emits a global-then-back-to-generic addrspacecast pair after \p Ptr and
// rewrites Ptr's other uses to go through the pair; see the class comment.
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  // Already in the global address space: nothing to do.
  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Deciding where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr,
      PointerType::getWithSamePointeeType(cast<PointerType>(Ptr->getType()),
                                          ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
  // (RAUW also rewires PtrInGlobal's operand, so restore it afterwards.)
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                markPointerAsGlobal(LI);
              }
            }
          }
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(&Arg);
      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
4560b57cec5SDimitry Andric bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) { 457fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n"); 4580b57cec5SDimitry Andric for (Argument &Arg : F.args()) 4590b57cec5SDimitry Andric if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) 4600b57cec5SDimitry Andric handleByValParam(&Arg); 4610b57cec5SDimitry Andric return true; 4620b57cec5SDimitry Andric } 4630b57cec5SDimitry Andric 4640b57cec5SDimitry Andric bool NVPTXLowerArgs::runOnFunction(Function &F) { 4650b57cec5SDimitry Andric return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F); 4660b57cec5SDimitry Andric } 4670b57cec5SDimitry Andric 4680b57cec5SDimitry Andric FunctionPass * 4690b57cec5SDimitry Andric llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) { 4700b57cec5SDimitry Andric return new NVPTXLowerArgs(TM); 4710b57cec5SDimitry Andric } 472