//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As #1, it allows NVPTX to emit more ld/st.global. E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. In the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include <queue>

#define DEBUG_TYPE "nvptx-lower-args"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
// Function pass that lowers kernel/device-function arguments per the scheme
// described in the file header: byval aggregates are either rewritten to be
// accessed directly in the param address space or copied into a local alloca,
// and (for the CUDA driver interface) pointer arguments are marked as global.
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  // Kernel entry points additionally get their pointer args marked global;
  // device functions only need byval handling.
  bool runOnKernelFunction(Function &F);
  bool runOnDeviceFunction(Function &F);

  // handle byval parameters
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  // May be null. Without it, driver-interface-dependent lowering (global
  // marking) and TLI-based alignment tuning are skipped.
  const NVPTXTargetMachine *TM;
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space in the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
//
// In case we know that all users are GEPs or Loads, replace them with the same
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

// Replaces the \p OldUser instruction with the same in parameter AS.
// Only Load and GEP are supported.
//
// \p OldUser must be an instruction that (transitively) addresses the byval
// argument; \p Param is the equivalent address already in the param AS.
// Walks OldUser's use chain with an explicit worklist, rewriting each
// GEP/bitcast/addrspacecast into the param AS and redirecting loads.
static void convertToParamAS(Value *OldUser, Value *Param) {
  Instruction *I = dyn_cast<Instruction>(OldUser);
  assert(I && "OldUser must be an instruction");
  // Worklist entry: an old generic-AS instruction paired with the param-AS
  // value that should replace its pointer operand.
  struct IP {
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  // Returns the param-AS replacement for the entry's instruction. A load is
  // mutated in place (its pointer operand is redirected), so the returned
  // value is the load itself; other instructions get a fresh clone in the
  // param AS and the original is left for later deletion.
  auto CloneInstInParamAS = [](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(),
                                               I.NewParam, Indices,
                                               GEP->getName(), GEP);
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }

    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      // Same pointee type, but in the param address space.
      auto *NewBCType = PointerType::getWithSamePointeeType(
          cast<PointerType>(BC->getType()), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC);
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
      (void)ASC;
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
      for (Value *V : I.OldInstruction->users())
        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Now we know that all argument loads are using addresses in parameter space
  // and we can finally remove the old instructions in generic AS. Instructions
  // scheduled for removal should be processed in reverse order so the ones
  // closest to the load are deleted first. Otherwise they may still be in use.
  // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
  // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
  // the BitCast.
  for (Instruction *I : llvm::reverse(InstructionsToDelete))
    I->eraseFromParent();
}

// Adjust alignment of arguments passed byval in .param address space. We can
// increase alignment of such arguments in a way that ensures that we can
// effectively vectorize their loads. We should also traverse all loads from
// byval pointer and adjust their alignment, if those were using known offset.
// Such alignment changes must be conformed with parameter store and load in
// NVPTXTargetLowering::LowerCall.
// \p Arg is a byval argument already known to be accessed only through
// \p ArgInParamAS (the param-AS view created by handleByValParam). Raises the
// argument's alignment attribute to the TLI-preferred value, then bumps the
// alignment of every load reachable from ArgInParamAS whose constant offset
// from the argument start is known.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  Function *Func = Arg->getParent();
  Type *StructType = Arg->getParamByValType();
  const DataLayout DL(Func->getParent());

  uint64_t NewArgAlign =
      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
  uint64_t CurArgAlign =
      Arg->getAttribute(Attribute::Alignment).getValueAsInt();

  // Only ever increase alignment; never weaken an existing guarantee.
  if (CurArgAlign >= NewArgAlign)
    return;

  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');

  // Attributes are immutable: replace the old alignment attribute wholesale.
  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  // A load found at a known constant byte offset from the argument start.
  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };

  // A pointer value to scan, plus its accumulated offset from the argument.
  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };

  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});

  // BFS over the use chain; only loads, bitcasts and GEPs can appear here
  // (guaranteed by the IsALoadChain check in handleByValParam).
  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();

    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        // Bitcast does not change the address: propagate the same offset.
        Worklist.push({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));

        // GEPs with non-constant offsets are skipped; their loads keep the
        // alignment they already have.
        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;

        uint64_t OffsetLimit = -1;
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }

      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr.");
    }
  }

  for (Load &CurLoad : Loads) {
    // gcd(arg alignment, offset) is the best alignment provable for an access
    // at that offset from a NewArgAlign-aligned base.
    Align NewLoadAlign(greatestCommonDivisor(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}

// Lower one byval argument. If every (transitive) use is a load chain, the
// uses are rewritten to read the argument directly in the param AS; otherwise
// a local alloca copy is created and all uses are redirected to it. See the
// block comment above convertToParamAS for the copy layout.
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  // All new instructions are inserted before the first instruction of the
  // entry block, so the argument is materialized before any use.
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

  // True iff Start and everything reachable from it consists only of loads,
  // GEPs, bitcasts, and addrspacecasts into the param AS -- i.e. the argument
  // is never written or escaped, so no local copy is needed.
  auto IsALoadChain = [&](Value *Start) {
    SmallVector<Value *, 16> ValuesToCheck = {Start};
    auto IsALoadChainInstr = [](Value *V) -> bool {
      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASC to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
      return false;
    };

    while (!ValuesToCheck.empty()) {
      Value *V = ValuesToCheck.pop_back_val();
      if (!IsALoadChainInstr(V)) {
        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
                          << "\n");
        (void)Arg;
        return false;
      }
      // Loads terminate a chain; anything else must be followed further.
      if (!isa<LoadInst>(V))
        llvm::append_range(ValuesToCheck, V->users());
    }
    return true;
  };

  if (llvm::all_of(Arg->users(), IsALoadChain)) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
    // Snapshot the users first: convertToParamAS mutates the use list.
    SmallVector<User *, 16> UsersToUpdate(Arg->users());
    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
    for (Value *V : UsersToUpdate)
      convertToParamAS(V, ArgInParamAS);
    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");

    // Further optimizations require target lowering info.
    if (!TM)
      return;

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM->getSubtargetImpl()->getTargetLowering());

    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);

    return;
  }

  // Otherwise we have to create a temporary copy.
  const DataLayout &DL = Func->getParent()->getDataLayout();
  unsigned AS = DL.getAllocaAddrSpace();
  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to alignment of the byval parameter. This is because,
  // later load/stores assume that alignment, and we are going to replace
  // the use of the byval parameter with this alloca instruction.
  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
                           .value_or(DL.getPrefTypeAlign(StructType)));
  // Redirect all existing uses to the copy BEFORE creating the param-AS cast
  // below, so the cast itself keeps using the original argument.
  Arg->replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
  // addrspacecast preserves alignment. Since params are constant, this load is
  // definitely not volatile.
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(),
                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

// Emit the global/generic addrspacecast pair for Ptr (an argument or an
// instruction known to point into the global AS) and redirect Ptr's uses to
// the generic result. NVPTXInferAddressSpaces can later fold the pair away.
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  // Already explicitly global -- nothing to do.
  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Deciding where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the functon entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr,
      PointerType::getWithSamePointeeType(cast<PointerType>(Ptr->getType()),
                                          ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
  // RAUW also rewrites PtrInGlobal's own operand (creating a cycle), so the
  // setOperand below restores it to the original Ptr.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  // Phase 1 (CUDA driver interface only): loads of pointers that live inside
  // a byval kernel parameter yield global pointers -- mark them as such.
  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                markPointerAsGlobal(LI);
              }
            }
          }
        }
      }
    }
  }

  // Phase 2: lower each pointer argument -- byval aggregates via
  // handleByValParam, plain pointers (CUDA only) via markPointerAsGlobal.
  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(&Arg);
      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(&Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
}

// Public factory used by the NVPTX target to register this pass.
FunctionPass *
llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
  return new NVPTXLowerArgs(TM);
}