1*0b57cec5SDimitry Andric //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===// 2*0b57cec5SDimitry Andric // 3*0b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*0b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*0b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*0b57cec5SDimitry Andric // 7*0b57cec5SDimitry Andric //===----------------------------------------------------------------------===// 8*0b57cec5SDimitry Andric // 9*0b57cec5SDimitry Andric /// \file This pass replaces accesses to kernel arguments with loads from 10*0b57cec5SDimitry Andric /// offsets from the kernarg base pointer. 11*0b57cec5SDimitry Andric // 12*0b57cec5SDimitry Andric //===----------------------------------------------------------------------===// 13*0b57cec5SDimitry Andric 14*0b57cec5SDimitry Andric #include "AMDGPU.h" 15*0b57cec5SDimitry Andric #include "AMDGPUSubtarget.h" 16*0b57cec5SDimitry Andric #include "AMDGPUTargetMachine.h" 17*0b57cec5SDimitry Andric #include "llvm/ADT/StringRef.h" 18*0b57cec5SDimitry Andric #include "llvm/Analysis/Loads.h" 19*0b57cec5SDimitry Andric #include "llvm/CodeGen/Passes.h" 20*0b57cec5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 21*0b57cec5SDimitry Andric #include "llvm/IR/Attributes.h" 22*0b57cec5SDimitry Andric #include "llvm/IR/BasicBlock.h" 23*0b57cec5SDimitry Andric #include "llvm/IR/Constants.h" 24*0b57cec5SDimitry Andric #include "llvm/IR/DerivedTypes.h" 25*0b57cec5SDimitry Andric #include "llvm/IR/Function.h" 26*0b57cec5SDimitry Andric #include "llvm/IR/IRBuilder.h" 27*0b57cec5SDimitry Andric #include "llvm/IR/InstrTypes.h" 28*0b57cec5SDimitry Andric #include "llvm/IR/Instruction.h" 29*0b57cec5SDimitry Andric #include "llvm/IR/Instructions.h" 30*0b57cec5SDimitry Andric #include "llvm/IR/LLVMContext.h" 31*0b57cec5SDimitry Andric #include "llvm/IR/MDBuilder.h" 32*0b57cec5SDimitry Andric #include "llvm/IR/Metadata.h" 33*0b57cec5SDimitry Andric #include "llvm/IR/Operator.h" 34*0b57cec5SDimitry Andric #include "llvm/IR/Type.h" 35*0b57cec5SDimitry Andric #include "llvm/IR/Value.h" 36*0b57cec5SDimitry Andric #include "llvm/Pass.h" 37*0b57cec5SDimitry Andric #include "llvm/Support/Casting.h" 38*0b57cec5SDimitry Andric 39*0b57cec5SDimitry Andric #define DEBUG_TYPE "amdgpu-lower-kernel-arguments" 40*0b57cec5SDimitry Andric 41*0b57cec5SDimitry Andric using namespace llvm; 42*0b57cec5SDimitry Andric 43*0b57cec5SDimitry Andric namespace { 44*0b57cec5SDimitry Andric 45*0b57cec5SDimitry Andric class AMDGPULowerKernelArguments : public FunctionPass{ 46*0b57cec5SDimitry Andric public: 47*0b57cec5SDimitry Andric static char ID; 48*0b57cec5SDimitry Andric 49*0b57cec5SDimitry Andric AMDGPULowerKernelArguments() : FunctionPass(ID) {} 50*0b57cec5SDimitry Andric 51*0b57cec5SDimitry Andric bool runOnFunction(Function &F) override; 52*0b57cec5SDimitry Andric 53*0b57cec5SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 54*0b57cec5SDimitry Andric AU.addRequired<TargetPassConfig>(); 55*0b57cec5SDimitry Andric AU.setPreservesAll(); 56*0b57cec5SDimitry Andric } 57*0b57cec5SDimitry Andric }; 58*0b57cec5SDimitry Andric 59*0b57cec5SDimitry Andric } // end anonymous namespace 60*0b57cec5SDimitry Andric 61*0b57cec5SDimitry Andric bool AMDGPULowerKernelArguments::runOnFunction(Function &F) { 62*0b57cec5SDimitry Andric CallingConv::ID CC = F.getCallingConv(); 63*0b57cec5SDimitry Andric if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty()) 64*0b57cec5SDimitry Andric return false; 65*0b57cec5SDimitry Andric 66*0b57cec5SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 67*0b57cec5SDimitry Andric 68*0b57cec5SDimitry Andric const TargetMachine &TM = TPC.getTM<TargetMachine>(); 69*0b57cec5SDimitry Andric const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F); 70*0b57cec5SDimitry Andric LLVMContext &Ctx = F.getParent()->getContext(); 71*0b57cec5SDimitry Andric const DataLayout &DL = F.getParent()->getDataLayout(); 72*0b57cec5SDimitry Andric BasicBlock &EntryBlock = *F.begin(); 73*0b57cec5SDimitry Andric IRBuilder<> Builder(&*EntryBlock.begin()); 74*0b57cec5SDimitry Andric 75*0b57cec5SDimitry Andric const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary 76*0b57cec5SDimitry Andric const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F); 77*0b57cec5SDimitry Andric 78*0b57cec5SDimitry Andric unsigned MaxAlign; 79*0b57cec5SDimitry Andric // FIXME: Alignment is broken broken with explicit arg offset.; 80*0b57cec5SDimitry Andric const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign); 81*0b57cec5SDimitry Andric if (TotalKernArgSize == 0) 82*0b57cec5SDimitry Andric return false; 83*0b57cec5SDimitry Andric 84*0b57cec5SDimitry Andric CallInst *KernArgSegment = 85*0b57cec5SDimitry Andric Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {}, 86*0b57cec5SDimitry Andric nullptr, F.getName() + ".kernarg.segment"); 87*0b57cec5SDimitry Andric 88*0b57cec5SDimitry Andric KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); 89*0b57cec5SDimitry Andric KernArgSegment->addAttribute(AttributeList::ReturnIndex, 90*0b57cec5SDimitry Andric Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize)); 91*0b57cec5SDimitry Andric 92*0b57cec5SDimitry Andric unsigned AS = KernArgSegment->getType()->getPointerAddressSpace(); 93*0b57cec5SDimitry Andric uint64_t ExplicitArgOffset = 0; 94*0b57cec5SDimitry Andric 95*0b57cec5SDimitry Andric for (Argument &Arg : F.args()) { 96*0b57cec5SDimitry Andric Type *ArgTy = Arg.getType(); 97*0b57cec5SDimitry Andric unsigned Align = DL.getABITypeAlignment(ArgTy); 98*0b57cec5SDimitry Andric unsigned Size = DL.getTypeSizeInBits(ArgTy); 99*0b57cec5SDimitry Andric unsigned AllocSize = DL.getTypeAllocSize(ArgTy); 100*0b57cec5SDimitry Andric 101*0b57cec5SDimitry Andric uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset; 102*0b57cec5SDimitry Andric ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize; 103*0b57cec5SDimitry Andric 104*0b57cec5SDimitry Andric if (Arg.use_empty()) 105*0b57cec5SDimitry Andric continue; 106*0b57cec5SDimitry Andric 107*0b57cec5SDimitry Andric if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) { 108*0b57cec5SDimitry Andric // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing 109*0b57cec5SDimitry Andric // modes on SI to know the high bits are 0 so pointer adds don't wrap. We 110*0b57cec5SDimitry Andric // can't represent this with range metadata because it's only allowed for 111*0b57cec5SDimitry Andric // integer types. 112*0b57cec5SDimitry Andric if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS || 113*0b57cec5SDimitry Andric PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) && 114*0b57cec5SDimitry Andric !ST.hasUsableDSOffset()) 115*0b57cec5SDimitry Andric continue; 116*0b57cec5SDimitry Andric 117*0b57cec5SDimitry Andric // FIXME: We can replace this with equivalent alias.scope/noalias 118*0b57cec5SDimitry Andric // metadata, but this appears to be a lot of work. 119*0b57cec5SDimitry Andric if (Arg.hasNoAliasAttr()) 120*0b57cec5SDimitry Andric continue; 121*0b57cec5SDimitry Andric } 122*0b57cec5SDimitry Andric 123*0b57cec5SDimitry Andric VectorType *VT = dyn_cast<VectorType>(ArgTy); 124*0b57cec5SDimitry Andric bool IsV3 = VT && VT->getNumElements() == 3; 125*0b57cec5SDimitry Andric bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType(); 126*0b57cec5SDimitry Andric 127*0b57cec5SDimitry Andric VectorType *V4Ty = nullptr; 128*0b57cec5SDimitry Andric 129*0b57cec5SDimitry Andric int64_t AlignDownOffset = alignDown(EltOffset, 4); 130*0b57cec5SDimitry Andric int64_t OffsetDiff = EltOffset - AlignDownOffset; 131*0b57cec5SDimitry Andric unsigned AdjustedAlign = MinAlign(DoShiftOpt ? AlignDownOffset : EltOffset, 132*0b57cec5SDimitry Andric KernArgBaseAlign); 133*0b57cec5SDimitry Andric 134*0b57cec5SDimitry Andric Value *ArgPtr; 135*0b57cec5SDimitry Andric Type *AdjustedArgTy; 136*0b57cec5SDimitry Andric if (DoShiftOpt) { // FIXME: Handle aggregate types 137*0b57cec5SDimitry Andric // Since we don't have sub-dword scalar loads, avoid doing an extload by 138*0b57cec5SDimitry Andric // loading earlier than the argument address, and extracting the relevant 139*0b57cec5SDimitry Andric // bits. 140*0b57cec5SDimitry Andric // 141*0b57cec5SDimitry Andric // Additionally widen any sub-dword load to i32 even if suitably aligned, 142*0b57cec5SDimitry Andric // so that CSE between different argument loads works easily. 143*0b57cec5SDimitry Andric ArgPtr = Builder.CreateConstInBoundsGEP1_64( 144*0b57cec5SDimitry Andric Builder.getInt8Ty(), KernArgSegment, AlignDownOffset, 145*0b57cec5SDimitry Andric Arg.getName() + ".kernarg.offset.align.down"); 146*0b57cec5SDimitry Andric AdjustedArgTy = Builder.getInt32Ty(); 147*0b57cec5SDimitry Andric } else { 148*0b57cec5SDimitry Andric ArgPtr = Builder.CreateConstInBoundsGEP1_64( 149*0b57cec5SDimitry Andric Builder.getInt8Ty(), KernArgSegment, EltOffset, 150*0b57cec5SDimitry Andric Arg.getName() + ".kernarg.offset"); 151*0b57cec5SDimitry Andric AdjustedArgTy = ArgTy; 152*0b57cec5SDimitry Andric } 153*0b57cec5SDimitry Andric 154*0b57cec5SDimitry Andric if (IsV3 && Size >= 32) { 155*0b57cec5SDimitry Andric V4Ty = VectorType::get(VT->getVectorElementType(), 4); 156*0b57cec5SDimitry Andric // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads 157*0b57cec5SDimitry Andric AdjustedArgTy = V4Ty; 158*0b57cec5SDimitry Andric } 159*0b57cec5SDimitry Andric 160*0b57cec5SDimitry Andric ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS), 161*0b57cec5SDimitry Andric ArgPtr->getName() + ".cast"); 162*0b57cec5SDimitry Andric LoadInst *Load = 163*0b57cec5SDimitry Andric Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign); 164*0b57cec5SDimitry Andric Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {})); 165*0b57cec5SDimitry Andric 166*0b57cec5SDimitry Andric MDBuilder MDB(Ctx); 167*0b57cec5SDimitry Andric 168*0b57cec5SDimitry Andric if (isa<PointerType>(ArgTy)) { 169*0b57cec5SDimitry Andric if (Arg.hasNonNullAttr()) 170*0b57cec5SDimitry Andric Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {})); 171*0b57cec5SDimitry Andric 172*0b57cec5SDimitry Andric uint64_t DerefBytes = Arg.getDereferenceableBytes(); 173*0b57cec5SDimitry Andric if (DerefBytes != 0) { 174*0b57cec5SDimitry Andric Load->setMetadata( 175*0b57cec5SDimitry Andric LLVMContext::MD_dereferenceable, 176*0b57cec5SDimitry Andric MDNode::get(Ctx, 177*0b57cec5SDimitry Andric MDB.createConstant( 178*0b57cec5SDimitry Andric ConstantInt::get(Builder.getInt64Ty(), DerefBytes)))); 179*0b57cec5SDimitry Andric } 180*0b57cec5SDimitry Andric 181*0b57cec5SDimitry Andric uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes(); 182*0b57cec5SDimitry Andric if (DerefOrNullBytes != 0) { 183*0b57cec5SDimitry Andric Load->setMetadata( 184*0b57cec5SDimitry Andric LLVMContext::MD_dereferenceable_or_null, 185*0b57cec5SDimitry Andric MDNode::get(Ctx, 186*0b57cec5SDimitry Andric MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 187*0b57cec5SDimitry Andric DerefOrNullBytes)))); 188*0b57cec5SDimitry Andric } 189*0b57cec5SDimitry Andric 190*0b57cec5SDimitry Andric unsigned ParamAlign = Arg.getParamAlignment(); 191*0b57cec5SDimitry Andric if (ParamAlign != 0) { 192*0b57cec5SDimitry Andric Load->setMetadata( 193*0b57cec5SDimitry Andric LLVMContext::MD_align, 194*0b57cec5SDimitry Andric MDNode::get(Ctx, 195*0b57cec5SDimitry Andric MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 196*0b57cec5SDimitry Andric ParamAlign)))); 197*0b57cec5SDimitry Andric } 198*0b57cec5SDimitry Andric } 199*0b57cec5SDimitry Andric 200*0b57cec5SDimitry Andric // TODO: Convert noalias arg to !noalias 201*0b57cec5SDimitry Andric 202*0b57cec5SDimitry Andric if (DoShiftOpt) { 203*0b57cec5SDimitry Andric Value *ExtractBits = OffsetDiff == 0 ? 204*0b57cec5SDimitry Andric Load : Builder.CreateLShr(Load, OffsetDiff * 8); 205*0b57cec5SDimitry Andric 206*0b57cec5SDimitry Andric IntegerType *ArgIntTy = Builder.getIntNTy(Size); 207*0b57cec5SDimitry Andric Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy); 208*0b57cec5SDimitry Andric Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy, 209*0b57cec5SDimitry Andric Arg.getName() + ".load"); 210*0b57cec5SDimitry Andric Arg.replaceAllUsesWith(NewVal); 211*0b57cec5SDimitry Andric } else if (IsV3) { 212*0b57cec5SDimitry Andric Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty), 213*0b57cec5SDimitry Andric {0, 1, 2}, 214*0b57cec5SDimitry Andric Arg.getName() + ".load"); 215*0b57cec5SDimitry Andric Arg.replaceAllUsesWith(Shuf); 216*0b57cec5SDimitry Andric } else { 217*0b57cec5SDimitry Andric Load->setName(Arg.getName() + ".load"); 218*0b57cec5SDimitry Andric Arg.replaceAllUsesWith(Load); 219*0b57cec5SDimitry Andric } 220*0b57cec5SDimitry Andric } 221*0b57cec5SDimitry Andric 222*0b57cec5SDimitry Andric KernArgSegment->addAttribute( 223*0b57cec5SDimitry Andric AttributeList::ReturnIndex, 224*0b57cec5SDimitry Andric Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign))); 225*0b57cec5SDimitry Andric 226*0b57cec5SDimitry Andric return true; 227*0b57cec5SDimitry Andric } 228*0b57cec5SDimitry Andric 229*0b57cec5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE, 230*0b57cec5SDimitry Andric "AMDGPU Lower Kernel Arguments", false, false) 231*0b57cec5SDimitry Andric INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments", 232*0b57cec5SDimitry Andric false, false) 233*0b57cec5SDimitry Andric 234*0b57cec5SDimitry Andric char AMDGPULowerKernelArguments::ID = 0; 235*0b57cec5SDimitry Andric 236*0b57cec5SDimitry Andric FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() { 237*0b57cec5SDimitry Andric return new AMDGPULowerKernelArguments(); 238*0b57cec5SDimitry Andric } 239