xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp (revision 0b57cec536236d46e3dba9bd041533462f33dbb7)
1*0b57cec5SDimitry Andric //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
2*0b57cec5SDimitry Andric //
3*0b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0b57cec5SDimitry Andric //
7*0b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
8*0b57cec5SDimitry Andric //
9*0b57cec5SDimitry Andric /// \file This pass replaces accesses to kernel arguments with loads from
10*0b57cec5SDimitry Andric /// offsets from the kernarg base pointer.
11*0b57cec5SDimitry Andric //
12*0b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
13*0b57cec5SDimitry Andric 
14*0b57cec5SDimitry Andric #include "AMDGPU.h"
15*0b57cec5SDimitry Andric #include "AMDGPUSubtarget.h"
16*0b57cec5SDimitry Andric #include "AMDGPUTargetMachine.h"
17*0b57cec5SDimitry Andric #include "llvm/ADT/StringRef.h"
18*0b57cec5SDimitry Andric #include "llvm/Analysis/Loads.h"
19*0b57cec5SDimitry Andric #include "llvm/CodeGen/Passes.h"
20*0b57cec5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
21*0b57cec5SDimitry Andric #include "llvm/IR/Attributes.h"
22*0b57cec5SDimitry Andric #include "llvm/IR/BasicBlock.h"
23*0b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
24*0b57cec5SDimitry Andric #include "llvm/IR/DerivedTypes.h"
25*0b57cec5SDimitry Andric #include "llvm/IR/Function.h"
26*0b57cec5SDimitry Andric #include "llvm/IR/IRBuilder.h"
27*0b57cec5SDimitry Andric #include "llvm/IR/InstrTypes.h"
28*0b57cec5SDimitry Andric #include "llvm/IR/Instruction.h"
29*0b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
30*0b57cec5SDimitry Andric #include "llvm/IR/LLVMContext.h"
31*0b57cec5SDimitry Andric #include "llvm/IR/MDBuilder.h"
32*0b57cec5SDimitry Andric #include "llvm/IR/Metadata.h"
33*0b57cec5SDimitry Andric #include "llvm/IR/Operator.h"
34*0b57cec5SDimitry Andric #include "llvm/IR/Type.h"
35*0b57cec5SDimitry Andric #include "llvm/IR/Value.h"
36*0b57cec5SDimitry Andric #include "llvm/Pass.h"
37*0b57cec5SDimitry Andric #include "llvm/Support/Casting.h"
38*0b57cec5SDimitry Andric 
39*0b57cec5SDimitry Andric #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
40*0b57cec5SDimitry Andric 
41*0b57cec5SDimitry Andric using namespace llvm;
42*0b57cec5SDimitry Andric 
43*0b57cec5SDimitry Andric namespace {
44*0b57cec5SDimitry Andric 
45*0b57cec5SDimitry Andric class AMDGPULowerKernelArguments : public FunctionPass{
46*0b57cec5SDimitry Andric public:
47*0b57cec5SDimitry Andric   static char ID;
48*0b57cec5SDimitry Andric 
49*0b57cec5SDimitry Andric   AMDGPULowerKernelArguments() : FunctionPass(ID) {}
50*0b57cec5SDimitry Andric 
51*0b57cec5SDimitry Andric   bool runOnFunction(Function &F) override;
52*0b57cec5SDimitry Andric 
53*0b57cec5SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
54*0b57cec5SDimitry Andric     AU.addRequired<TargetPassConfig>();
55*0b57cec5SDimitry Andric     AU.setPreservesAll();
56*0b57cec5SDimitry Andric  }
57*0b57cec5SDimitry Andric };
58*0b57cec5SDimitry Andric 
59*0b57cec5SDimitry Andric } // end anonymous namespace
60*0b57cec5SDimitry Andric 
61*0b57cec5SDimitry Andric bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
62*0b57cec5SDimitry Andric   CallingConv::ID CC = F.getCallingConv();
63*0b57cec5SDimitry Andric   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
64*0b57cec5SDimitry Andric     return false;
65*0b57cec5SDimitry Andric 
66*0b57cec5SDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
67*0b57cec5SDimitry Andric 
68*0b57cec5SDimitry Andric   const TargetMachine &TM = TPC.getTM<TargetMachine>();
69*0b57cec5SDimitry Andric   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
70*0b57cec5SDimitry Andric   LLVMContext &Ctx = F.getParent()->getContext();
71*0b57cec5SDimitry Andric   const DataLayout &DL = F.getParent()->getDataLayout();
72*0b57cec5SDimitry Andric   BasicBlock &EntryBlock = *F.begin();
73*0b57cec5SDimitry Andric   IRBuilder<> Builder(&*EntryBlock.begin());
74*0b57cec5SDimitry Andric 
75*0b57cec5SDimitry Andric   const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
76*0b57cec5SDimitry Andric   const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
77*0b57cec5SDimitry Andric 
78*0b57cec5SDimitry Andric   unsigned MaxAlign;
79*0b57cec5SDimitry Andric   // FIXME: Alignment is broken broken with explicit arg offset.;
80*0b57cec5SDimitry Andric   const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
81*0b57cec5SDimitry Andric   if (TotalKernArgSize == 0)
82*0b57cec5SDimitry Andric     return false;
83*0b57cec5SDimitry Andric 
84*0b57cec5SDimitry Andric   CallInst *KernArgSegment =
85*0b57cec5SDimitry Andric       Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
86*0b57cec5SDimitry Andric                               nullptr, F.getName() + ".kernarg.segment");
87*0b57cec5SDimitry Andric 
88*0b57cec5SDimitry Andric   KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
89*0b57cec5SDimitry Andric   KernArgSegment->addAttribute(AttributeList::ReturnIndex,
90*0b57cec5SDimitry Andric     Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
91*0b57cec5SDimitry Andric 
92*0b57cec5SDimitry Andric   unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
93*0b57cec5SDimitry Andric   uint64_t ExplicitArgOffset = 0;
94*0b57cec5SDimitry Andric 
95*0b57cec5SDimitry Andric   for (Argument &Arg : F.args()) {
96*0b57cec5SDimitry Andric     Type *ArgTy = Arg.getType();
97*0b57cec5SDimitry Andric     unsigned Align = DL.getABITypeAlignment(ArgTy);
98*0b57cec5SDimitry Andric     unsigned Size = DL.getTypeSizeInBits(ArgTy);
99*0b57cec5SDimitry Andric     unsigned AllocSize = DL.getTypeAllocSize(ArgTy);
100*0b57cec5SDimitry Andric 
101*0b57cec5SDimitry Andric     uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
102*0b57cec5SDimitry Andric     ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
103*0b57cec5SDimitry Andric 
104*0b57cec5SDimitry Andric     if (Arg.use_empty())
105*0b57cec5SDimitry Andric       continue;
106*0b57cec5SDimitry Andric 
107*0b57cec5SDimitry Andric     if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
108*0b57cec5SDimitry Andric       // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
109*0b57cec5SDimitry Andric       // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
110*0b57cec5SDimitry Andric       // can't represent this with range metadata because it's only allowed for
111*0b57cec5SDimitry Andric       // integer types.
112*0b57cec5SDimitry Andric       if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
113*0b57cec5SDimitry Andric            PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
114*0b57cec5SDimitry Andric           !ST.hasUsableDSOffset())
115*0b57cec5SDimitry Andric         continue;
116*0b57cec5SDimitry Andric 
117*0b57cec5SDimitry Andric       // FIXME: We can replace this with equivalent alias.scope/noalias
118*0b57cec5SDimitry Andric       // metadata, but this appears to be a lot of work.
119*0b57cec5SDimitry Andric       if (Arg.hasNoAliasAttr())
120*0b57cec5SDimitry Andric         continue;
121*0b57cec5SDimitry Andric     }
122*0b57cec5SDimitry Andric 
123*0b57cec5SDimitry Andric     VectorType *VT = dyn_cast<VectorType>(ArgTy);
124*0b57cec5SDimitry Andric     bool IsV3 = VT && VT->getNumElements() == 3;
125*0b57cec5SDimitry Andric     bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
126*0b57cec5SDimitry Andric 
127*0b57cec5SDimitry Andric     VectorType *V4Ty = nullptr;
128*0b57cec5SDimitry Andric 
129*0b57cec5SDimitry Andric     int64_t AlignDownOffset = alignDown(EltOffset, 4);
130*0b57cec5SDimitry Andric     int64_t OffsetDiff = EltOffset - AlignDownOffset;
131*0b57cec5SDimitry Andric     unsigned AdjustedAlign = MinAlign(DoShiftOpt ? AlignDownOffset : EltOffset,
132*0b57cec5SDimitry Andric                                       KernArgBaseAlign);
133*0b57cec5SDimitry Andric 
134*0b57cec5SDimitry Andric     Value *ArgPtr;
135*0b57cec5SDimitry Andric     Type *AdjustedArgTy;
136*0b57cec5SDimitry Andric     if (DoShiftOpt) { // FIXME: Handle aggregate types
137*0b57cec5SDimitry Andric       // Since we don't have sub-dword scalar loads, avoid doing an extload by
138*0b57cec5SDimitry Andric       // loading earlier than the argument address, and extracting the relevant
139*0b57cec5SDimitry Andric       // bits.
140*0b57cec5SDimitry Andric       //
141*0b57cec5SDimitry Andric       // Additionally widen any sub-dword load to i32 even if suitably aligned,
142*0b57cec5SDimitry Andric       // so that CSE between different argument loads works easily.
143*0b57cec5SDimitry Andric       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
144*0b57cec5SDimitry Andric           Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
145*0b57cec5SDimitry Andric           Arg.getName() + ".kernarg.offset.align.down");
146*0b57cec5SDimitry Andric       AdjustedArgTy = Builder.getInt32Ty();
147*0b57cec5SDimitry Andric     } else {
148*0b57cec5SDimitry Andric       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
149*0b57cec5SDimitry Andric           Builder.getInt8Ty(), KernArgSegment, EltOffset,
150*0b57cec5SDimitry Andric           Arg.getName() + ".kernarg.offset");
151*0b57cec5SDimitry Andric       AdjustedArgTy = ArgTy;
152*0b57cec5SDimitry Andric     }
153*0b57cec5SDimitry Andric 
154*0b57cec5SDimitry Andric     if (IsV3 && Size >= 32) {
155*0b57cec5SDimitry Andric       V4Ty = VectorType::get(VT->getVectorElementType(), 4);
156*0b57cec5SDimitry Andric       // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
157*0b57cec5SDimitry Andric       AdjustedArgTy = V4Ty;
158*0b57cec5SDimitry Andric     }
159*0b57cec5SDimitry Andric 
160*0b57cec5SDimitry Andric     ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
161*0b57cec5SDimitry Andric                                    ArgPtr->getName() + ".cast");
162*0b57cec5SDimitry Andric     LoadInst *Load =
163*0b57cec5SDimitry Andric         Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
164*0b57cec5SDimitry Andric     Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
165*0b57cec5SDimitry Andric 
166*0b57cec5SDimitry Andric     MDBuilder MDB(Ctx);
167*0b57cec5SDimitry Andric 
168*0b57cec5SDimitry Andric     if (isa<PointerType>(ArgTy)) {
169*0b57cec5SDimitry Andric       if (Arg.hasNonNullAttr())
170*0b57cec5SDimitry Andric         Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
171*0b57cec5SDimitry Andric 
172*0b57cec5SDimitry Andric       uint64_t DerefBytes = Arg.getDereferenceableBytes();
173*0b57cec5SDimitry Andric       if (DerefBytes != 0) {
174*0b57cec5SDimitry Andric         Load->setMetadata(
175*0b57cec5SDimitry Andric           LLVMContext::MD_dereferenceable,
176*0b57cec5SDimitry Andric           MDNode::get(Ctx,
177*0b57cec5SDimitry Andric                       MDB.createConstant(
178*0b57cec5SDimitry Andric                         ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
179*0b57cec5SDimitry Andric       }
180*0b57cec5SDimitry Andric 
181*0b57cec5SDimitry Andric       uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
182*0b57cec5SDimitry Andric       if (DerefOrNullBytes != 0) {
183*0b57cec5SDimitry Andric         Load->setMetadata(
184*0b57cec5SDimitry Andric           LLVMContext::MD_dereferenceable_or_null,
185*0b57cec5SDimitry Andric           MDNode::get(Ctx,
186*0b57cec5SDimitry Andric                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
187*0b57cec5SDimitry Andric                                                           DerefOrNullBytes))));
188*0b57cec5SDimitry Andric       }
189*0b57cec5SDimitry Andric 
190*0b57cec5SDimitry Andric       unsigned ParamAlign = Arg.getParamAlignment();
191*0b57cec5SDimitry Andric       if (ParamAlign != 0) {
192*0b57cec5SDimitry Andric         Load->setMetadata(
193*0b57cec5SDimitry Andric           LLVMContext::MD_align,
194*0b57cec5SDimitry Andric           MDNode::get(Ctx,
195*0b57cec5SDimitry Andric                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
196*0b57cec5SDimitry Andric                                                           ParamAlign))));
197*0b57cec5SDimitry Andric       }
198*0b57cec5SDimitry Andric     }
199*0b57cec5SDimitry Andric 
200*0b57cec5SDimitry Andric     // TODO: Convert noalias arg to !noalias
201*0b57cec5SDimitry Andric 
202*0b57cec5SDimitry Andric     if (DoShiftOpt) {
203*0b57cec5SDimitry Andric       Value *ExtractBits = OffsetDiff == 0 ?
204*0b57cec5SDimitry Andric         Load : Builder.CreateLShr(Load, OffsetDiff * 8);
205*0b57cec5SDimitry Andric 
206*0b57cec5SDimitry Andric       IntegerType *ArgIntTy = Builder.getIntNTy(Size);
207*0b57cec5SDimitry Andric       Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
208*0b57cec5SDimitry Andric       Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
209*0b57cec5SDimitry Andric                                             Arg.getName() + ".load");
210*0b57cec5SDimitry Andric       Arg.replaceAllUsesWith(NewVal);
211*0b57cec5SDimitry Andric     } else if (IsV3) {
212*0b57cec5SDimitry Andric       Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
213*0b57cec5SDimitry Andric                                                 {0, 1, 2},
214*0b57cec5SDimitry Andric                                                 Arg.getName() + ".load");
215*0b57cec5SDimitry Andric       Arg.replaceAllUsesWith(Shuf);
216*0b57cec5SDimitry Andric     } else {
217*0b57cec5SDimitry Andric       Load->setName(Arg.getName() + ".load");
218*0b57cec5SDimitry Andric       Arg.replaceAllUsesWith(Load);
219*0b57cec5SDimitry Andric     }
220*0b57cec5SDimitry Andric   }
221*0b57cec5SDimitry Andric 
222*0b57cec5SDimitry Andric   KernArgSegment->addAttribute(
223*0b57cec5SDimitry Andric     AttributeList::ReturnIndex,
224*0b57cec5SDimitry Andric     Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
225*0b57cec5SDimitry Andric 
226*0b57cec5SDimitry Andric   return true;
227*0b57cec5SDimitry Andric }
228*0b57cec5SDimitry Andric 
229*0b57cec5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
230*0b57cec5SDimitry Andric                       "AMDGPU Lower Kernel Arguments", false, false)
231*0b57cec5SDimitry Andric INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
232*0b57cec5SDimitry Andric                     false, false)
233*0b57cec5SDimitry Andric 
234*0b57cec5SDimitry Andric char AMDGPULowerKernelArguments::ID = 0;
235*0b57cec5SDimitry Andric 
236*0b57cec5SDimitry Andric FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
237*0b57cec5SDimitry Andric   return new AMDGPULowerKernelArguments();
238*0b57cec5SDimitry Andric }
239