1 //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file This pass replaces accesses to kernel arguments with loads from 10 /// offsets from the kernarg base pointer. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "AMDGPU.h" 15 #include "GCNSubtarget.h" 16 #include "llvm/CodeGen/TargetPassConfig.h" 17 #include "llvm/IR/IntrinsicsAMDGPU.h" 18 #include "llvm/IR/IRBuilder.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Target/TargetMachine.h" 21 #define DEBUG_TYPE "amdgpu-lower-kernel-arguments" 22 23 using namespace llvm; 24 25 namespace { 26 27 class AMDGPULowerKernelArguments : public FunctionPass{ 28 public: 29 static char ID; 30 31 AMDGPULowerKernelArguments() : FunctionPass(ID) {} 32 33 bool runOnFunction(Function &F) override; 34 35 void getAnalysisUsage(AnalysisUsage &AU) const override { 36 AU.addRequired<TargetPassConfig>(); 37 AU.setPreservesAll(); 38 } 39 }; 40 41 } // end anonymous namespace 42 43 // skip allocas 44 static BasicBlock::iterator getInsertPt(BasicBlock &BB) { 45 BasicBlock::iterator InsPt = BB.getFirstInsertionPt(); 46 for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) { 47 AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt); 48 49 // If this is a dynamic alloca, the value may depend on the loaded kernargs, 50 // so loads will need to be inserted before it. 51 if (!AI || !AI->isStaticAlloca()) 52 break; 53 } 54 55 return InsPt; 56 } 57 58 bool AMDGPULowerKernelArguments::runOnFunction(Function &F) { 59 CallingConv::ID CC = F.getCallingConv(); 60 if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty()) 61 return false; 62 63 auto &TPC = getAnalysis<TargetPassConfig>(); 64 65 const TargetMachine &TM = TPC.getTM<TargetMachine>(); 66 const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F); 67 LLVMContext &Ctx = F.getParent()->getContext(); 68 const DataLayout &DL = F.getParent()->getDataLayout(); 69 BasicBlock &EntryBlock = *F.begin(); 70 IRBuilder<> Builder(&*getInsertPt(EntryBlock)); 71 72 const Align KernArgBaseAlign(16); // FIXME: Increase if necessary 73 const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F); 74 75 Align MaxAlign; 76 // FIXME: Alignment is broken broken with explicit arg offset.; 77 const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign); 78 if (TotalKernArgSize == 0) 79 return false; 80 81 CallInst *KernArgSegment = 82 Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {}, 83 nullptr, F.getName() + ".kernarg.segment"); 84 85 KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); 86 KernArgSegment->addAttribute(AttributeList::ReturnIndex, 87 Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize)); 88 89 unsigned AS = KernArgSegment->getType()->getPointerAddressSpace(); 90 uint64_t ExplicitArgOffset = 0; 91 92 for (Argument &Arg : F.args()) { 93 const bool IsByRef = Arg.hasByRefAttr(); 94 Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType(); 95 MaybeAlign ABITypeAlign = IsByRef ? Arg.getParamAlign() : None; 96 if (!ABITypeAlign) 97 ABITypeAlign = DL.getABITypeAlign(ArgTy); 98 99 uint64_t Size = DL.getTypeSizeInBits(ArgTy); 100 uint64_t AllocSize = DL.getTypeAllocSize(ArgTy); 101 102 uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset; 103 ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize; 104 105 if (Arg.use_empty()) 106 continue; 107 108 // If this is byval, the loads are already explicit in the function. We just 109 // need to rewrite the pointer values. 110 if (IsByRef) { 111 Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64( 112 Builder.getInt8Ty(), KernArgSegment, EltOffset, 113 Arg.getName() + ".byval.kernarg.offset"); 114 115 Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( 116 ArgOffsetPtr, Arg.getType()); 117 Arg.replaceAllUsesWith(CastOffsetPtr); 118 continue; 119 } 120 121 if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) { 122 // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing 123 // modes on SI to know the high bits are 0 so pointer adds don't wrap. We 124 // can't represent this with range metadata because it's only allowed for 125 // integer types. 126 if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS || 127 PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) && 128 !ST.hasUsableDSOffset()) 129 continue; 130 131 // FIXME: We can replace this with equivalent alias.scope/noalias 132 // metadata, but this appears to be a lot of work. 133 if (Arg.hasNoAliasAttr()) 134 continue; 135 } 136 137 auto *VT = dyn_cast<FixedVectorType>(ArgTy); 138 bool IsV3 = VT && VT->getNumElements() == 3; 139 bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType(); 140 141 VectorType *V4Ty = nullptr; 142 143 int64_t AlignDownOffset = alignDown(EltOffset, 4); 144 int64_t OffsetDiff = EltOffset - AlignDownOffset; 145 Align AdjustedAlign = commonAlignment( 146 KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset); 147 148 Value *ArgPtr; 149 Type *AdjustedArgTy; 150 if (DoShiftOpt) { // FIXME: Handle aggregate types 151 // Since we don't have sub-dword scalar loads, avoid doing an extload by 152 // loading earlier than the argument address, and extracting the relevant 153 // bits. 154 // 155 // Additionally widen any sub-dword load to i32 even if suitably aligned, 156 // so that CSE between different argument loads works easily. 157 ArgPtr = Builder.CreateConstInBoundsGEP1_64( 158 Builder.getInt8Ty(), KernArgSegment, AlignDownOffset, 159 Arg.getName() + ".kernarg.offset.align.down"); 160 AdjustedArgTy = Builder.getInt32Ty(); 161 } else { 162 ArgPtr = Builder.CreateConstInBoundsGEP1_64( 163 Builder.getInt8Ty(), KernArgSegment, EltOffset, 164 Arg.getName() + ".kernarg.offset"); 165 AdjustedArgTy = ArgTy; 166 } 167 168 if (IsV3 && Size >= 32) { 169 V4Ty = FixedVectorType::get(VT->getElementType(), 4); 170 // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads 171 AdjustedArgTy = V4Ty; 172 } 173 174 ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS), 175 ArgPtr->getName() + ".cast"); 176 LoadInst *Load = 177 Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign); 178 Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {})); 179 180 MDBuilder MDB(Ctx); 181 182 if (isa<PointerType>(ArgTy)) { 183 if (Arg.hasNonNullAttr()) 184 Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {})); 185 186 uint64_t DerefBytes = Arg.getDereferenceableBytes(); 187 if (DerefBytes != 0) { 188 Load->setMetadata( 189 LLVMContext::MD_dereferenceable, 190 MDNode::get(Ctx, 191 MDB.createConstant( 192 ConstantInt::get(Builder.getInt64Ty(), DerefBytes)))); 193 } 194 195 uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes(); 196 if (DerefOrNullBytes != 0) { 197 Load->setMetadata( 198 LLVMContext::MD_dereferenceable_or_null, 199 MDNode::get(Ctx, 200 MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 201 DerefOrNullBytes)))); 202 } 203 204 unsigned ParamAlign = Arg.getParamAlignment(); 205 if (ParamAlign != 0) { 206 Load->setMetadata( 207 LLVMContext::MD_align, 208 MDNode::get(Ctx, 209 MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 210 ParamAlign)))); 211 } 212 } 213 214 // TODO: Convert noalias arg to !noalias 215 216 if (DoShiftOpt) { 217 Value *ExtractBits = OffsetDiff == 0 ? 218 Load : Builder.CreateLShr(Load, OffsetDiff * 8); 219 220 IntegerType *ArgIntTy = Builder.getIntNTy(Size); 221 Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy); 222 Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy, 223 Arg.getName() + ".load"); 224 Arg.replaceAllUsesWith(NewVal); 225 } else if (IsV3) { 226 Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2}, 227 Arg.getName() + ".load"); 228 Arg.replaceAllUsesWith(Shuf); 229 } else { 230 Load->setName(Arg.getName() + ".load"); 231 Arg.replaceAllUsesWith(Load); 232 } 233 } 234 235 KernArgSegment->addAttribute( 236 AttributeList::ReturnIndex, 237 Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign))); 238 239 return true; 240 } 241 242 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE, 243 "AMDGPU Lower Kernel Arguments", false, false) 244 INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments", 245 false, false) 246 247 char AMDGPULowerKernelArguments::ID = 0; 248 249 FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() { 250 return new AMDGPULowerKernelArguments(); 251 } 252