//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As with #1, this allows NVPTX to emit more ld/st.global.
//    E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. At the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes
// don't cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  bool runOnKernelFunction(Function &F);
  bool runOnDeviceFunction(Function &F);

  // Handle byval parameters.
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  const NVPTXTargetMachine *TM;
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space on the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
// =============================================================================
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  PointerType *PType = dyn_cast<PointerType>(Arg->getType());

  assert(PType && "Expecting pointer type in handleByValParam");

  Type *StructType = PType->getElementType();
  const DataLayout &DL = Func->getParent()->getDataLayout();
  unsigned AS = DL.getAllocaAddrSpace();
  AllocaInst *AllocA =
      new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to the alignment of the byval parameter. This is
  // because later loads/stores assume that alignment, and we are going to
  // replace the uses of the byval parameter with this alloca instruction.
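  // If the parameter carries no explicit alignment, fall back to the
  // DataLayout's preferred alignment for the struct type.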
  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
                           .getValueOr(DL.getPrefTypeAlign(StructType)));
  Arg->replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Deciding where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr, PointerType::get(Ptr->getType()->getPointerElementType(),
                            ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace all uses of Ptr (other than PtrInGlobal itself) with PtrInGeneric.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy()) {
            Value *UO = GetUnderlyingObject(LI->getPointerOperand(),
                                            F.getParent()->getDataLayout());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
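                // Mark it as global: the addrspacecast pair emitted by
                // markPointerAsGlobal lets NVPTXInferAddressSpaces turn later
                // accesses through this pointer into ld.global/st.global
                // (transformation #2 in the file header).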
                markPointerAsGlobal(LI);
              }
            }
          }
        }
      }
    }
  }

  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(&Arg);
      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(&Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
}

FunctionPass *
llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
  return new NVPTXLowerArgs(TM);
}