//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads at
/// explicit offsets from the kernarg segment base pointer.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Target/TargetMachine.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

// Skip allocas when choosing the insertion point for the kernarg loads.
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded
    // kernargs, so loads will need to be inserted before it.
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

static bool lowerKernelArguments(Function &F, const TargetMachine &TM) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&EntryBlock, getInsertPt(EntryBlock));

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset();

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {},
                              nullptr, F.getName() + ".kernarg.segment");
  KernArgSegment->addRetAttr(Attribute::NonNull);
  KernArgSegment->addRetAttr(
      Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  uint64_t ExplicitArgOffset = 0;
  for (Argument &Arg : F.args()) {
    const bool IsByRef = Arg.hasByRefAttr();
    Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
    MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : std::nullopt;
    Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);

    uint64_t Size = DL.getTypeSizeInBits(ArgTy);
    uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);

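    // Illustrative walk-through (hypothetical signature, assuming BaseOffset
    // is 0): for kernel arguments (i32, i8, <2 x i32>) this computes
    //   i32       -> EltOffset 0, ExplicitArgOffset 4
    //   i8        -> EltOffset 4, ExplicitArgOffset 5
    //   <2 x i32> -> EltOffset 8, ExplicitArgOffset 16 (start realigned to 8)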
    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;

    // Skip inreg arguments which should be preloaded.
    if (Arg.use_empty() || Arg.hasInRegAttr())
      continue;

    // If this is byref, the loads are already explicit in the function. We
    // just need to rewrite the pointer values.
    if (IsByRef) {
      Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".byval.kernarg.offset");

      Value *CastOffsetPtr =
          Builder.CreateAddrSpaceCast(ArgOffsetPtr, Arg.getType());
      Arg.replaceAllUsesWith(CastOffsetPtr);
      continue;
    }

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    auto *VT = dyn_cast<FixedVectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      // TODO: Update this for GFX12 which does have scalar sub-dword loads.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
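      //
      // Illustration (a hypothetical 'half' argument whose EltOffset is 6, an
      // assumption for this example): AlignDownOffset is 4 and OffsetDiff is
      // 2, so we load an i32 at offset 4, shift it right by 16 bits, truncate
      // to i16, and bitcast the result back to half.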
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = FixedVectorType::get(VT->getElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      AdjustedArgTy = V4Ty;
    }

    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    if (Arg.hasAttribute(Attribute::NoUndef))
      Load->setMetadata(LLVMContext::MD_noundef, MDNode::get(Ctx, {}));

    if (Arg.hasAttribute(Attribute::Range)) {
      const ConstantRange &Range =
          Arg.getAttribute(Attribute::Range).getValueAsConstantRange();
      Load->setMetadata(LLVMContext::MD_range,
                        MDB.createRange(Range.getLower(), Range.getUpper()));
    }

    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
            LLVMContext::MD_dereferenceable,
            MDNode::get(Ctx,
                        MDB.createConstant(
                            ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
            LLVMContext::MD_dereferenceable_or_null,
            MDNode::get(Ctx,
                        MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                            DerefOrNullBytes))));
      }

      if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
        Load->setMetadata(
            LLVMContext::MD_align,
            MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
                                 Builder.getInt64Ty(), ParamAlign->value()))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  KernArgSegment->addRetAttr(
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  auto &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  return lowerKernelArguments(F, TM);
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}

PreservedAnalyses
AMDGPULowerKernelArgumentsPass::run(Function &F, FunctionAnalysisManager &AM) {
  bool Changed = lowerKernelArguments(F, TM);
  if (Changed) {
    // TODO: Preserves a lot more.
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }

  return PreservedAnalyses::all();
}