xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (revision 357378bbdedf24ce2b90e9bd831af4a9db3ec70a)
1 //===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //
10 // Arguments to kernel and device functions are passed via param space,
11 // which imposes certain restrictions:
12 // http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
13 //
14 // Kernel parameters are read-only and accessible only via ld.param
15 // instruction, directly or via a pointer. Pointers to kernel
16 // arguments can't be converted to generic address space.
17 //
18 // Device function parameters are directly accessible via
19 // ld.param/st.param, but taking the address of one returns a pointer
20 // to a copy created in local space which *can't* be used with
21 // ld.param/st.param.
22 //
23 // Copying a byval struct into local memory in IR allows us to enforce
24 // the param space restrictions, gives the rest of IR a pointer w/o
25 // param space restrictions, and gives us an opportunity to eliminate
26 // the copy.
27 //
28 // Pointer arguments to kernel functions need more work to be lowered:
29 //
30 // 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
31 //    global address space. This allows later optimizations to emit
32 //    ld.global.*/st.global.* for accessing these pointer arguments. For
33 //    example,
34 //
35 //    define void @foo(float* %input) {
36 //      %v = load float, float* %input, align 4
37 //      ...
38 //    }
39 //
40 //    becomes
41 //
42 //    define void @foo(float* %input) {
43 //      %input2 = addrspacecast float* %input to float addrspace(1)*
44 //      %input3 = addrspacecast float addrspace(1)* %input2 to float*
45 //      %v = load float, float* %input3, align 4
46 //      ...
47 //    }
48 //
49 //    Later, NVPTXInferAddressSpaces will optimize it to
50 //
51 //    define void @foo(float* %input) {
52 //      %input2 = addrspacecast float* %input to float addrspace(1)*
53 //      %v = load float, float addrspace(1)* %input2, align 4
54 //      ...
55 //    }
56 //
57 // 2. Convert pointers in a byval kernel parameter to pointers in the global
58 //    address space. As in #1, it allows NVPTX to emit more ld/st.global. E.g.,
59 //
60 //    struct S {
61 //      int *x;
62 //      int *y;
63 //    };
64 //    __global__ void foo(S s) {
65 //      int *b = s.y;
66 //      // use b
67 //    }
68 //
69 //    "b" points to the global address space. In the IR level,
70 //
71 //    define void @foo({i32*, i32*}* byval %input) {
72 //      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
73 //      %b = load i32*, i32** %b_ptr
74 //      ; use %b
75 //    }
76 //
77 //    becomes
78 //
79 //    define void @foo({i32*, i32*}* byval %input) {
80 //      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
81 //      %b = load i32*, i32** %b_ptr
82 //      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
83 //      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
84 //      ; use %b_generic
85 //    }
86 //
87 // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
88 // cancel the addrspacecast pair this pass emits.
89 //===----------------------------------------------------------------------===//
90 
91 #include "MCTargetDesc/NVPTXBaseInfo.h"
92 #include "NVPTX.h"
93 #include "NVPTXTargetMachine.h"
94 #include "NVPTXUtilities.h"
95 #include "llvm/Analysis/ValueTracking.h"
96 #include "llvm/CodeGen/TargetPassConfig.h"
97 #include "llvm/IR/Function.h"
98 #include "llvm/IR/Instructions.h"
99 #include "llvm/IR/Module.h"
100 #include "llvm/IR/Type.h"
101 #include "llvm/InitializePasses.h"
102 #include "llvm/Pass.h"
103 #include <numeric>
104 #include <queue>
105 
106 #define DEBUG_TYPE "nvptx-lower-args"
107 
108 using namespace llvm;
109 
110 namespace llvm {
111 void initializeNVPTXLowerArgsPass(PassRegistry &);
112 }
113 
114 namespace {
115 class NVPTXLowerArgs : public FunctionPass {
116   bool runOnFunction(Function &F) override;
117 
118   bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F);
119   bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F);
120 
121   // handle byval parameters
122   void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg);
123   // Knowing Ptr must point to the global address space, this function
124   // addrspacecasts Ptr to global and then back to generic. This allows
125   // NVPTXInferAddressSpaces to fold the global-to-generic cast into
126   // loads/stores that appear later.
127   void markPointerAsGlobal(Value *Ptr);
128 
129 public:
130   static char ID; // Pass identification, replacement for typeid
131   NVPTXLowerArgs() : FunctionPass(ID) {}
132   StringRef getPassName() const override {
133     return "Lower pointer arguments of CUDA kernels";
134   }
135   void getAnalysisUsage(AnalysisUsage &AU) const override {
136     AU.addRequired<TargetPassConfig>();
137   }
138 };
139 } // namespace
140 
141 char NVPTXLowerArgs::ID = 1;
142 
143 INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
144                       "Lower arguments (NVPTX)", false, false)
145 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
146 INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
147                     "Lower arguments (NVPTX)", false, false)
148 
149 // =============================================================================
150 // If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
151 // and we can't guarantee that the only accesses are loads,
152 // then add the following instructions to the first basic block:
153 //
154 // %temp = alloca %struct.x, align 8
155 // %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
156 // %tv = load %struct.x addrspace(101)* %tempd
157 // store %struct.x %tv, %struct.x* %temp, align 8
158 //
159 // The above code allocates some space in the stack and copies the incoming
160 // struct from param space to local space.
161 // Then replace all occurrences of %d by %temp.
162 //
163 // In case we know that all users are GEPs or Loads, replace them with the same
164 // ones in parameter AS, so we can access them using ld.param.
165 // =============================================================================
166 
167 // Replaces the \p OldUser instruction with the same in parameter AS.
168 // Only Load and GEP are supported.
169 static void convertToParamAS(Value *OldUser, Value *Param) {
170   Instruction *I = dyn_cast<Instruction>(OldUser);
171   assert(I && "OldUser must be an instruction");
172   struct IP {
173     Instruction *OldInstruction;
174     Value *NewParam;
175   };
176   SmallVector<IP> ItemsToConvert = {{I, Param}};
177   SmallVector<Instruction *> InstructionsToDelete;
178 
179   auto CloneInstInParamAS = [](const IP &I) -> Value * {
180     if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
181       LI->setOperand(0, I.NewParam);
182       return LI;
183     }
184     if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
185       SmallVector<Value *, 4> Indices(GEP->indices());
186       auto *NewGEP = GetElementPtrInst::Create(GEP->getSourceElementType(),
187                                                I.NewParam, Indices,
188                                                GEP->getName(), GEP);
189       NewGEP->setIsInBounds(GEP->isInBounds());
190       return NewGEP;
191     }
192     if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
193       auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
194       return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
195                                  BC->getName(), BC);
196     }
197     if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
198       assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
199       (void)ASC;
200       // Just pass through the argument, the old ASC is no longer needed.
201       return I.NewParam;
202     }
203     llvm_unreachable("Unsupported instruction");
204   };
205 
206   while (!ItemsToConvert.empty()) {
207     IP I = ItemsToConvert.pop_back_val();
208     Value *NewInst = CloneInstInParamAS(I);
209 
210     if (NewInst && NewInst != I.OldInstruction) {
211       // We've created a new instruction. Queue users of the old instruction to
212       // be converted and the instruction itself to be deleted. We can't delete
213       // the old instruction yet, because it's still in use by a load somewhere.
214       for (Value *V : I.OldInstruction->users())
215         ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
216 
217       InstructionsToDelete.push_back(I.OldInstruction);
218     }
219   }
220 
221   // Now we know that all argument loads are using addresses in parameter space
222   // and we can finally remove the old instructions in generic AS.  Instructions
223   // scheduled for removal should be processed in reverse order so the ones
224   // closest to the load are deleted first. Otherwise they may still be in use.
225   // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
226   // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
227   // the BitCast.
228   for (Instruction *I : llvm::reverse(InstructionsToDelete))
229     I->eraseFromParent();
230 }
231 
232 // Adjust alignment of arguments passed byval in .param address space. We can
233 // increase alignment of such arguments in a way that ensures that we can
234 // effectively vectorize their loads. We should also traverse all loads from
235 // byval pointer and adjust their alignment, if those were using known offset.
236 // Such alignment changes must be conformed with parameter store and load in
237 // NVPTXTargetLowering::LowerCall.
238 static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
239                                     const NVPTXTargetLowering *TLI) {
240   Function *Func = Arg->getParent();
241   Type *StructType = Arg->getParamByValType();
242   const DataLayout DL(Func->getParent());
243 
244   uint64_t NewArgAlign =
245       TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
246   uint64_t CurArgAlign =
247       Arg->getAttribute(Attribute::Alignment).getValueAsInt();
248 
249   if (CurArgAlign >= NewArgAlign)
250     return;
251 
252   LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
253                     << CurArgAlign << " for " << *Arg << '\n');
254 
255   auto NewAlignAttr =
256       Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
257   Arg->removeAttr(Attribute::Alignment);
258   Arg->addAttr(NewAlignAttr);
259 
260   struct Load {
261     LoadInst *Inst;
262     uint64_t Offset;
263   };
264 
265   struct LoadContext {
266     Value *InitialVal;
267     uint64_t Offset;
268   };
269 
270   SmallVector<Load> Loads;
271   std::queue<LoadContext> Worklist;
272   Worklist.push({ArgInParamAS, 0});
273 
274   while (!Worklist.empty()) {
275     LoadContext Ctx = Worklist.front();
276     Worklist.pop();
277 
278     for (User *CurUser : Ctx.InitialVal->users()) {
279       if (auto *I = dyn_cast<LoadInst>(CurUser)) {
280         Loads.push_back({I, Ctx.Offset});
281         continue;
282       }
283 
284       if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
285         Worklist.push({I, Ctx.Offset});
286         continue;
287       }
288 
289       if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
290         APInt OffsetAccumulated =
291             APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
292 
293         if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
294           continue;
295 
296         uint64_t OffsetLimit = -1;
297         uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
298         assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
299 
300         Worklist.push({I, Ctx.Offset + Offset});
301         continue;
302       }
303 
304       llvm_unreachable("All users must be one of: load, "
305                        "bitcast, getelementptr.");
306     }
307   }
308 
309   for (Load &CurLoad : Loads) {
310     Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
311     Align CurLoadAlign(CurLoad.Inst->getAlign());
312     CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
313   }
314 }
315 
316 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
317                                       Argument *Arg) {
318   Function *Func = Arg->getParent();
319   Instruction *FirstInst = &(Func->getEntryBlock().front());
320   Type *StructType = Arg->getParamByValType();
321   assert(StructType && "Missing byval type");
322 
323   auto IsALoadChain = [&](Value *Start) {
324     SmallVector<Value *, 16> ValuesToCheck = {Start};
325     auto IsALoadChainInstr = [](Value *V) -> bool {
326       if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
327         return true;
328       // ASC to param space are OK, too -- we'll just strip them.
329       if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
330         if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
331           return true;
332       }
333       return false;
334     };
335 
336     while (!ValuesToCheck.empty()) {
337       Value *V = ValuesToCheck.pop_back_val();
338       if (!IsALoadChainInstr(V)) {
339         LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
340                           << "\n");
341         (void)Arg;
342         return false;
343       }
344       if (!isa<LoadInst>(V))
345         llvm::append_range(ValuesToCheck, V->users());
346     }
347     return true;
348   };
349 
350   if (llvm::all_of(Arg->users(), IsALoadChain)) {
351     // Convert all loads and intermediate operations to use parameter AS and
352     // skip creation of a local copy of the argument.
353     SmallVector<User *, 16> UsersToUpdate(Arg->users());
354     Value *ArgInParamAS = new AddrSpaceCastInst(
355         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
356         FirstInst);
357     for (Value *V : UsersToUpdate)
358       convertToParamAS(V, ArgInParamAS);
359     LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
360 
361     const auto *TLI =
362         cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
363 
364     adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
365 
366     return;
367   }
368 
369   // Otherwise we have to create a temporary copy.
370   const DataLayout &DL = Func->getParent()->getDataLayout();
371   unsigned AS = DL.getAllocaAddrSpace();
372   AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
373   // Set the alignment to alignment of the byval parameter. This is because,
374   // later load/stores assume that alignment, and we are going to replace
375   // the use of the byval parameter with this alloca instruction.
376   AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
377                            .value_or(DL.getPrefTypeAlign(StructType)));
378   Arg->replaceAllUsesWith(AllocA);
379 
380   Value *ArgInParam = new AddrSpaceCastInst(
381       Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
382       FirstInst);
383   // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
384   // addrspacecast preserves alignment.  Since params are constant, this load is
385   // definitely not volatile.
386   LoadInst *LI =
387       new LoadInst(StructType, ArgInParam, Arg->getName(),
388                    /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
389   new StoreInst(LI, AllocA, FirstInst);
390 }
391 
392 void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
393   if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
394     return;
395 
396   // Deciding where to emit the addrspacecast pair.
397   BasicBlock::iterator InsertPt;
398   if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
399     // Insert at the functon entry if Ptr is an argument.
400     InsertPt = Arg->getParent()->getEntryBlock().begin();
401   } else {
402     // Insert right after Ptr if Ptr is an instruction.
403     InsertPt = ++cast<Instruction>(Ptr)->getIterator();
404     assert(InsertPt != InsertPt->getParent()->end() &&
405            "We don't call this function with Ptr being a terminator.");
406   }
407 
408   Instruction *PtrInGlobal = new AddrSpaceCastInst(
409       Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
410       Ptr->getName(), &*InsertPt);
411   Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
412                                               Ptr->getName(), &*InsertPt);
413   // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
414   Ptr->replaceAllUsesWith(PtrInGeneric);
415   PtrInGlobal->setOperand(0, Ptr);
416 }
417 
418 // =============================================================================
419 // Main function for this pass.
420 // =============================================================================
421 bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
422                                          Function &F) {
423   // Copying of byval aggregates + SROA may result in pointers being loaded as
424   // integers, followed by intotoptr. We may want to mark those as global, too,
425   // but only if the loaded integer is used exclusively for conversion to a
426   // pointer with inttoptr.
427   auto HandleIntToPtr = [this](Value &V) {
428     if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
429       SmallVector<User *, 16> UsersToUpdate(V.users());
430       for (User *U : UsersToUpdate)
431         markPointerAsGlobal(U);
432     }
433   };
434   if (TM.getDrvInterface() == NVPTX::CUDA) {
435     // Mark pointers in byval structs as global.
436     for (auto &B : F) {
437       for (auto &I : B) {
438         if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
439           if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
440             Value *UO = getUnderlyingObject(LI->getPointerOperand());
441             if (Argument *Arg = dyn_cast<Argument>(UO)) {
442               if (Arg->hasByValAttr()) {
443                 // LI is a load from a pointer within a byval kernel parameter.
444                 if (LI->getType()->isPointerTy())
445                   markPointerAsGlobal(LI);
446                 else
447                   HandleIntToPtr(*LI);
448               }
449             }
450           }
451         }
452       }
453     }
454   }
455 
456   LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
457   for (Argument &Arg : F.args()) {
458     if (Arg.getType()->isPointerTy()) {
459       if (Arg.hasByValAttr())
460         handleByValParam(TM, &Arg);
461       else if (TM.getDrvInterface() == NVPTX::CUDA)
462         markPointerAsGlobal(&Arg);
463     } else if (Arg.getType()->isIntegerTy() &&
464                TM.getDrvInterface() == NVPTX::CUDA) {
465       HandleIntToPtr(Arg);
466     }
467   }
468   return true;
469 }
470 
471 // Device functions only need to copy byval args into local memory.
472 bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
473                                          Function &F) {
474   LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
475   for (Argument &Arg : F.args())
476     if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
477       handleByValParam(TM, &Arg);
478   return true;
479 }
480 
481 bool NVPTXLowerArgs::runOnFunction(Function &F) {
482   auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
483 
484   return isKernelFunction(F) ? runOnKernelFunction(TM, F)
485                              : runOnDeviceFunction(TM, F);
486 }
487 
488 FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }
489