1 //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file This pass replaces accesses to kernel arguments with loads from 10 /// offsets from the kernarg base pointer. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "AMDGPU.h" 15 #include "GCNSubtarget.h" 16 #include "llvm/CodeGen/TargetPassConfig.h" 17 #include "llvm/IR/IntrinsicsAMDGPU.h" 18 #include "llvm/IR/MDBuilder.h" 19 #include "llvm/Target/TargetMachine.h" 20 #define DEBUG_TYPE "amdgpu-lower-kernel-arguments" 21 22 using namespace llvm; 23 24 namespace { 25 26 class AMDGPULowerKernelArguments : public FunctionPass{ 27 public: 28 static char ID; 29 30 AMDGPULowerKernelArguments() : FunctionPass(ID) {} 31 32 bool runOnFunction(Function &F) override; 33 34 void getAnalysisUsage(AnalysisUsage &AU) const override { 35 AU.addRequired<TargetPassConfig>(); 36 AU.setPreservesAll(); 37 } 38 }; 39 40 } // end anonymous namespace 41 42 // skip allocas 43 static BasicBlock::iterator getInsertPt(BasicBlock &BB) { 44 BasicBlock::iterator InsPt = BB.getFirstInsertionPt(); 45 for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) { 46 AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt); 47 48 // If this is a dynamic alloca, the value may depend on the loaded kernargs, 49 // so loads will need to be inserted before it. 50 if (!AI || !AI->isStaticAlloca()) 51 break; 52 } 53 54 return InsPt; 55 } 56 57 bool AMDGPULowerKernelArguments::runOnFunction(Function &F) { 58 CallingConv::ID CC = F.getCallingConv(); 59 if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty()) 60 return false; 61 62 auto &TPC = getAnalysis<TargetPassConfig>(); 63 64 const TargetMachine &TM = TPC.getTM<TargetMachine>(); 65 const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F); 66 LLVMContext &Ctx = F.getParent()->getContext(); 67 const DataLayout &DL = F.getParent()->getDataLayout(); 68 BasicBlock &EntryBlock = *F.begin(); 69 IRBuilder<> Builder(&*getInsertPt(EntryBlock)); 70 71 const Align KernArgBaseAlign(16); // FIXME: Increase if necessary 72 const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F); 73 74 Align MaxAlign; 75 // FIXME: Alignment is broken broken with explicit arg offset.; 76 const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign); 77 if (TotalKernArgSize == 0) 78 return false; 79 80 CallInst *KernArgSegment = 81 Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {}, 82 nullptr, F.getName() + ".kernarg.segment"); 83 84 KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); 85 KernArgSegment->addAttribute(AttributeList::ReturnIndex, 86 Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize)); 87 88 unsigned AS = KernArgSegment->getType()->getPointerAddressSpace(); 89 uint64_t ExplicitArgOffset = 0; 90 91 for (Argument &Arg : F.args()) { 92 const bool IsByRef = Arg.hasByRefAttr(); 93 Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType(); 94 MaybeAlign ABITypeAlign = IsByRef ? Arg.getParamAlign() : None; 95 if (!ABITypeAlign) 96 ABITypeAlign = DL.getABITypeAlign(ArgTy); 97 98 uint64_t Size = DL.getTypeSizeInBits(ArgTy); 99 uint64_t AllocSize = DL.getTypeAllocSize(ArgTy); 100 101 uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset; 102 ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize; 103 104 if (Arg.use_empty()) 105 continue; 106 107 // If this is byval, the loads are already explicit in the function. We just 108 // need to rewrite the pointer values. 109 if (IsByRef) { 110 Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64( 111 Builder.getInt8Ty(), KernArgSegment, EltOffset, 112 Arg.getName() + ".byval.kernarg.offset"); 113 114 Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( 115 ArgOffsetPtr, Arg.getType()); 116 Arg.replaceAllUsesWith(CastOffsetPtr); 117 continue; 118 } 119 120 if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) { 121 // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing 122 // modes on SI to know the high bits are 0 so pointer adds don't wrap. We 123 // can't represent this with range metadata because it's only allowed for 124 // integer types. 125 if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS || 126 PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) && 127 !ST.hasUsableDSOffset()) 128 continue; 129 130 // FIXME: We can replace this with equivalent alias.scope/noalias 131 // metadata, but this appears to be a lot of work. 132 if (Arg.hasNoAliasAttr()) 133 continue; 134 } 135 136 auto *VT = dyn_cast<FixedVectorType>(ArgTy); 137 bool IsV3 = VT && VT->getNumElements() == 3; 138 bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType(); 139 140 VectorType *V4Ty = nullptr; 141 142 int64_t AlignDownOffset = alignDown(EltOffset, 4); 143 int64_t OffsetDiff = EltOffset - AlignDownOffset; 144 Align AdjustedAlign = commonAlignment( 145 KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset); 146 147 Value *ArgPtr; 148 Type *AdjustedArgTy; 149 if (DoShiftOpt) { // FIXME: Handle aggregate types 150 // Since we don't have sub-dword scalar loads, avoid doing an extload by 151 // loading earlier than the argument address, and extracting the relevant 152 // bits. 153 // 154 // Additionally widen any sub-dword load to i32 even if suitably aligned, 155 // so that CSE between different argument loads works easily. 156 ArgPtr = Builder.CreateConstInBoundsGEP1_64( 157 Builder.getInt8Ty(), KernArgSegment, AlignDownOffset, 158 Arg.getName() + ".kernarg.offset.align.down"); 159 AdjustedArgTy = Builder.getInt32Ty(); 160 } else { 161 ArgPtr = Builder.CreateConstInBoundsGEP1_64( 162 Builder.getInt8Ty(), KernArgSegment, EltOffset, 163 Arg.getName() + ".kernarg.offset"); 164 AdjustedArgTy = ArgTy; 165 } 166 167 if (IsV3 && Size >= 32) { 168 V4Ty = FixedVectorType::get(VT->getElementType(), 4); 169 // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads 170 AdjustedArgTy = V4Ty; 171 } 172 173 ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS), 174 ArgPtr->getName() + ".cast"); 175 LoadInst *Load = 176 Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign); 177 Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {})); 178 179 MDBuilder MDB(Ctx); 180 181 if (isa<PointerType>(ArgTy)) { 182 if (Arg.hasNonNullAttr()) 183 Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {})); 184 185 uint64_t DerefBytes = Arg.getDereferenceableBytes(); 186 if (DerefBytes != 0) { 187 Load->setMetadata( 188 LLVMContext::MD_dereferenceable, 189 MDNode::get(Ctx, 190 MDB.createConstant( 191 ConstantInt::get(Builder.getInt64Ty(), DerefBytes)))); 192 } 193 194 uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes(); 195 if (DerefOrNullBytes != 0) { 196 Load->setMetadata( 197 LLVMContext::MD_dereferenceable_or_null, 198 MDNode::get(Ctx, 199 MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 200 DerefOrNullBytes)))); 201 } 202 203 unsigned ParamAlign = Arg.getParamAlignment(); 204 if (ParamAlign != 0) { 205 Load->setMetadata( 206 LLVMContext::MD_align, 207 MDNode::get(Ctx, 208 MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 209 ParamAlign)))); 210 } 211 } 212 213 // TODO: Convert noalias arg to !noalias 214 215 if (DoShiftOpt) { 216 Value *ExtractBits = OffsetDiff == 0 ? 217 Load : Builder.CreateLShr(Load, OffsetDiff * 8); 218 219 IntegerType *ArgIntTy = Builder.getIntNTy(Size); 220 Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy); 221 Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy, 222 Arg.getName() + ".load"); 223 Arg.replaceAllUsesWith(NewVal); 224 } else if (IsV3) { 225 Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2}, 226 Arg.getName() + ".load"); 227 Arg.replaceAllUsesWith(Shuf); 228 } else { 229 Load->setName(Arg.getName() + ".load"); 230 Arg.replaceAllUsesWith(Load); 231 } 232 } 233 234 KernArgSegment->addAttribute( 235 AttributeList::ReturnIndex, 236 Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign))); 237 238 return true; 239 } 240 241 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE, 242 "AMDGPU Lower Kernel Arguments", false, false) 243 INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments", 244 false, false) 245 246 char AMDGPULowerKernelArguments::ID = 0; 247 248 FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() { 249 return new AMDGPULowerKernelArguments(); 250 } 251