//===-- AMDGPUPromoteKernelArguments.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass recursively promotes generic pointer arguments of a kernel
/// into the global address space.
///
/// The pass walks kernel's pointer arguments, then loads from them. If a loaded
/// value is a pointer and loaded pointer is unmodified in the kernel before the
/// load, then promote loaded pointer to global. Then recursively continue.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "amdgpu-promote-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPUPromoteKernelArguments : public FunctionPass {
  // MemorySSA for the current function; used to prove that the memory a
  // pointer is loaded from has not been written to before the load.
  MemorySSA *MSSA;

  // Insertion point in the entry block (after static allocas) where the
  // addrspacecast pairs for kernel arguments are created.
  Instruction *ArgCastInsertPt;

  // Worklist of pointers that are candidates for promotion to global.
  SmallVector<Value *> Ptrs;

  void enqueueUsers(Value *Ptr);

  bool promotePointer(Value *Ptr);

public:
  static char ID;

  AMDGPUPromoteKernelArguments() : FunctionPass(ID) {}

  bool run(Function &F, MemorySSA &MSSA);

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MemorySSAWrapperPass>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

/// Walk the users of \p Ptr (looking through GEPs, addrspacecasts and
/// bitcasts) and push onto the worklist every pointer loaded from \p Ptr
/// whose clobbering memory access is live-on-entry, i.e. the loaded-from
/// memory is not modified in this function before the load.
void AMDGPUPromoteKernelArguments::enqueueUsers(Value *Ptr) {
  // Seed with the direct users; derived pointers' users are appended below.
  SmallVector<User *> PtrUsers(Ptr->users());

  while (!PtrUsers.empty()) {
    Instruction *U = dyn_cast<Instruction>(PtrUsers.pop_back_val());
    if (!U)
      continue;

    switch (U->getOpcode()) {
    default:
      break;
    case Instruction::Load: {
      LoadInst *LD = cast<LoadInst>(U);
      PointerType *PT = dyn_cast<PointerType>(LD->getType());
      // Only interested in loads that produce a flat/global/constant pointer
      // and whose address is Ptr itself modulo inbounds offsets.
      if (!PT ||
          (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
           PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
           PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS) ||
          LD->getPointerOperand()->stripInBoundsOffsets() != Ptr)
        break;
      // If the clobbering access is the live-on-entry def, nothing in this
      // function stores to the loaded-from memory before this load, so the
      // loaded pointer still holds its on-entry value and may be promoted.
      const MemoryAccess *MA = MSSA->getWalker()->getClobberingMemoryAccess(LD);
      // TODO: This load probably can be promoted to constant address space.
      if (MSSA->isLiveOnEntryDef(MA))
        Ptrs.push_back(LD);
      break;
    }
    case Instruction::GetElementPtr:
    case Instruction::AddrSpaceCast:
    case Instruction::BitCast:
      // Look through address-preserving instructions and keep scanning their
      // users, provided they still address into the original object.
      if (U->getOperand(0)->stripInBoundsOffsets() == Ptr)
        PtrUsers.append(U->user_begin(), U->user_end());
      break;
    }
  }
}

/// Try to promote \p Ptr to the global address space and enqueue pointers
/// loaded from it for recursive processing. Returns true if IR was changed.
bool AMDGPUPromoteKernelArguments::promotePointer(Value *Ptr) {
  // Even if Ptr itself cannot be promoted (e.g. it is already global or
  // constant), pointers loaded from it may still be promotable.
  enqueueUsers(Ptr);

  PointerType *PT = cast<PointerType>(Ptr->getType());
  if (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS)
    return false;

  // Arguments are cast at the shared insertion point in the entry block;
  // loaded pointers are cast immediately after the load that produced them.
  bool IsArg = isa<Argument>(Ptr);
  IRBuilder<> B(IsArg ? ArgCastInsertPt
                      : &*std::next(cast<Instruction>(Ptr)->getIterator()));

  // Cast pointer to global address space and back to flat and let
  // Infer Address Spaces pass to do all necessary rewriting.
  PointerType *NewPT =
      PointerType::getWithSamePointeeType(PT, AMDGPUAS::GLOBAL_ADDRESS);
  Value *Cast =
      B.CreateAddrSpaceCast(Ptr, NewPT, Twine(Ptr->getName(), ".global"));
  Value *CastBack =
      B.CreateAddrSpaceCast(Cast, PT, Twine(Ptr->getName(), ".flat"));
  // Redirect all uses of Ptr to the round-tripped value, excluding the cast
  // just created (which must keep using Ptr to avoid a use cycle).
  Ptr->replaceUsesWithIf(CastBack,
                         [Cast](Use &U) { return U.getUser() != Cast; });

  return true;
}

// skip allocas
// Find the first point in \p BB where new instructions may be inserted,
// skipping over any leading static allocas.
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded kernargs,
    // so loads will need to be inserted before it.
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

/// Shared driver for both pass-manager interfaces: seed the worklist with
/// the kernel's pointer arguments and promote until the worklist is empty.
bool AMDGPUPromoteKernelArguments::run(Function &F, MemorySSA &MSSA) {
  if (skipFunction(F))
    return false;

  // Only kernel entry points with at least one argument are processed.
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  ArgCastInsertPt = &*getInsertPt(*F.begin());
  this->MSSA = &MSSA;

  for (Argument &Arg : F.args()) {
    if (Arg.use_empty())
      continue;

    // Global/constant arguments are enqueued too: they are not cast
    // themselves, but pointers loaded from them may be promotable.
    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
    if (!PT || (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
                PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
                PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS))
      continue;

    Ptrs.push_back(&Arg);
  }

  bool Changed = false;
  while (!Ptrs.empty()) {
    Value *Ptr = Ptrs.pop_back_val();
    Changed |= promotePointer(Ptr);
  }

  return Changed;
}

bool AMDGPUPromoteKernelArguments::runOnFunction(Function &F) {
  MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
  return run(F, MSSA);
}

INITIALIZE_PASS_BEGIN(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                      "AMDGPU Promote Kernel Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
INITIALIZE_PASS_END(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                    "AMDGPU Promote Kernel Arguments", false, false)

char AMDGPUPromoteKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPUPromoteKernelArgumentsPass() {
  return new AMDGPUPromoteKernelArguments();
}

/// New pass manager entry point. Reuses the legacy pass's run() and reports
/// CFG and MemorySSA as preserved when IR was changed (only casts are
/// inserted and uses rewritten; no control flow or memory ops are touched).
PreservedAnalyses
AMDGPUPromoteKernelArgumentsPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
  if (AMDGPUPromoteKernelArguments().run(F, MSSA)) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    PA.preserve<MemorySSAAnalysis>();
    return PA;
  }
  return PreservedAnalyses::all();
}