1 //===- NVPTXUtilities.cpp - Utility Functions -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains miscellaneous utility functions
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "NVPTXUtilities.h"
14 #include "NVPTX.h"
15 #include "NVPTXTargetMachine.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/GlobalVariable.h"
19 #include "llvm/IR/InstIterator.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/IR/Operator.h"
22 #include "llvm/Support/Alignment.h"
23 #include "llvm/Support/Mutex.h"
24 #include <algorithm>
25 #include <cstring>
26 #include <map>
27 #include <mutex>
28 #include <optional>
29 #include <string>
30 #include <vector>
31 
32 namespace llvm {
33 
34 namespace {
35 typedef std::map<std::string, std::vector<unsigned> > key_val_pair_t;
36 typedef std::map<const GlobalValue *, key_val_pair_t> global_val_annot_t;
37 
38 struct AnnotationCache {
39   sys::Mutex Lock;
40   std::map<const Module *, global_val_annot_t> Cache;
41 };
42 
getAnnotationCache()43 AnnotationCache &getAnnotationCache() {
44   static AnnotationCache AC;
45   return AC;
46 }
47 } // anonymous namespace
48 
clearAnnotationCache(const Module * Mod)49 void clearAnnotationCache(const Module *Mod) {
50   auto &AC = getAnnotationCache();
51   std::lock_guard<sys::Mutex> Guard(AC.Lock);
52   AC.Cache.erase(Mod);
53 }
54 
readIntVecFromMDNode(const MDNode * MetadataNode,std::vector<unsigned> & Vec)55 static void readIntVecFromMDNode(const MDNode *MetadataNode,
56                                  std::vector<unsigned> &Vec) {
57   for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
58     ConstantInt *Val =
59         mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
60     Vec.push_back(Val->getZExtValue());
61   }
62 }
63 
cacheAnnotationFromMD(const MDNode * MetadataNode,key_val_pair_t & retval)64 static void cacheAnnotationFromMD(const MDNode *MetadataNode,
65                                   key_val_pair_t &retval) {
66   auto &AC = getAnnotationCache();
67   std::lock_guard<sys::Mutex> Guard(AC.Lock);
68   assert(MetadataNode && "Invalid mdnode for annotation");
69   assert((MetadataNode->getNumOperands() % 2) == 1 &&
70          "Invalid number of operands");
71   // start index = 1, to skip the global variable key
72   // increment = 2, to skip the value for each property-value pairs
73   for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
74     // property
75     const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
76     assert(prop && "Annotation property not a string");
77     std::string Key = prop->getString().str();
78 
79     // value
80     if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
81             MetadataNode->getOperand(i + 1))) {
82       retval[Key].push_back(Val->getZExtValue());
83     } else if (MDNode *VecMd =
84                    dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
85       // note: only "grid_constant" annotations support vector MDNodes.
86       // assert: there can only exist one unique key value pair of
87       // the form (string key, MDNode node). Operands of such a node
88       // shall always be unsigned ints.
89       if (retval.find(Key) == retval.end()) {
90         readIntVecFromMDNode(VecMd, retval[Key]);
91         continue;
92       }
93     } else {
94       llvm_unreachable("Value operand not a constant int or an mdnode");
95     }
96   }
97 }
98 
cacheAnnotationFromMD(const Module * m,const GlobalValue * gv)99 static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
100   auto &AC = getAnnotationCache();
101   std::lock_guard<sys::Mutex> Guard(AC.Lock);
102   NamedMDNode *NMD = m->getNamedMetadata("nvvm.annotations");
103   if (!NMD)
104     return;
105   key_val_pair_t tmp;
106   for (unsigned i = 0, e = NMD->getNumOperands(); i != e; ++i) {
107     const MDNode *elem = NMD->getOperand(i);
108 
109     GlobalValue *entity =
110         mdconst::dyn_extract_or_null<GlobalValue>(elem->getOperand(0));
111     // entity may be null due to DCE
112     if (!entity)
113       continue;
114     if (entity != gv)
115       continue;
116 
117     // accumulate annotations for entity in tmp
118     cacheAnnotationFromMD(elem, tmp);
119   }
120 
121   if (tmp.empty()) // no annotations for this gv
122     return;
123 
124   if (AC.Cache.find(m) != AC.Cache.end())
125     AC.Cache[m][gv] = std::move(tmp);
126   else {
127     global_val_annot_t tmp1;
128     tmp1[gv] = std::move(tmp);
129     AC.Cache[m] = std::move(tmp1);
130   }
131 }
132 
findOneNVVMAnnotation(const GlobalValue * gv,const std::string & prop,unsigned & retval)133 bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
134                            unsigned &retval) {
135   auto &AC = getAnnotationCache();
136   std::lock_guard<sys::Mutex> Guard(AC.Lock);
137   const Module *m = gv->getParent();
138   if (AC.Cache.find(m) == AC.Cache.end())
139     cacheAnnotationFromMD(m, gv);
140   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
141     cacheAnnotationFromMD(m, gv);
142   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
143     return false;
144   retval = AC.Cache[m][gv][prop][0];
145   return true;
146 }
147 
148 static std::optional<unsigned>
findOneNVVMAnnotation(const GlobalValue & GV,const std::string & PropName)149 findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
150   unsigned RetVal;
151   if (findOneNVVMAnnotation(&GV, PropName, RetVal))
152     return RetVal;
153   return std::nullopt;
154 }
155 
findAllNVVMAnnotation(const GlobalValue * gv,const std::string & prop,std::vector<unsigned> & retval)156 bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
157                            std::vector<unsigned> &retval) {
158   auto &AC = getAnnotationCache();
159   std::lock_guard<sys::Mutex> Guard(AC.Lock);
160   const Module *m = gv->getParent();
161   if (AC.Cache.find(m) == AC.Cache.end())
162     cacheAnnotationFromMD(m, gv);
163   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
164     cacheAnnotationFromMD(m, gv);
165   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
166     return false;
167   retval = AC.Cache[m][gv][prop];
168   return true;
169 }
170 
isTexture(const Value & val)171 bool isTexture(const Value &val) {
172   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
173     unsigned Annot;
174     if (findOneNVVMAnnotation(gv, "texture", Annot)) {
175       assert((Annot == 1) && "Unexpected annotation on a texture symbol");
176       return true;
177     }
178   }
179   return false;
180 }
181 
isSurface(const Value & val)182 bool isSurface(const Value &val) {
183   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
184     unsigned Annot;
185     if (findOneNVVMAnnotation(gv, "surface", Annot)) {
186       assert((Annot == 1) && "Unexpected annotation on a surface symbol");
187       return true;
188     }
189   }
190   return false;
191 }
192 
argHasNVVMAnnotation(const Value & Val,const std::string & Annotation,const bool StartArgIndexAtOne=false)193 static bool argHasNVVMAnnotation(const Value &Val,
194                                  const std::string &Annotation,
195                                  const bool StartArgIndexAtOne = false) {
196   if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
197     const Function *Func = Arg->getParent();
198     std::vector<unsigned> Annot;
199     if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
200       const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
201       if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
202         return true;
203       }
204     }
205   }
206   return false;
207 }
208 
isParamGridConstant(const Value & V)209 bool isParamGridConstant(const Value &V) {
210   if (const Argument *Arg = dyn_cast<Argument>(&V)) {
211     // "grid_constant" counts argument indices starting from 1
212     if (Arg->hasByValAttr() &&
213         argHasNVVMAnnotation(*Arg, "grid_constant",
214                              /*StartArgIndexAtOne*/ true)) {
215       assert(isKernelFunction(*Arg->getParent()) &&
216              "only kernel arguments can be grid_constant");
217       return true;
218     }
219   }
220   return false;
221 }
222 
isSampler(const Value & val)223 bool isSampler(const Value &val) {
224   const char *AnnotationName = "sampler";
225 
226   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
227     unsigned Annot;
228     if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
229       assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
230       return true;
231     }
232   }
233   return argHasNVVMAnnotation(val, AnnotationName);
234 }
235 
isImageReadOnly(const Value & val)236 bool isImageReadOnly(const Value &val) {
237   return argHasNVVMAnnotation(val, "rdoimage");
238 }
239 
isImageWriteOnly(const Value & val)240 bool isImageWriteOnly(const Value &val) {
241   return argHasNVVMAnnotation(val, "wroimage");
242 }
243 
isImageReadWrite(const Value & val)244 bool isImageReadWrite(const Value &val) {
245   return argHasNVVMAnnotation(val, "rdwrimage");
246 }
247 
isImage(const Value & val)248 bool isImage(const Value &val) {
249   return isImageReadOnly(val) || isImageWriteOnly(val) || isImageReadWrite(val);
250 }
251 
isManaged(const Value & val)252 bool isManaged(const Value &val) {
253   if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
254     unsigned Annot;
255     if (findOneNVVMAnnotation(gv, "managed", Annot)) {
256       assert((Annot == 1) && "Unexpected annotation on a managed symbol");
257       return true;
258     }
259   }
260   return false;
261 }
262 
getTextureName(const Value & val)263 std::string getTextureName(const Value &val) {
264   assert(val.hasName() && "Found texture variable with no name");
265   return std::string(val.getName());
266 }
267 
getSurfaceName(const Value & val)268 std::string getSurfaceName(const Value &val) {
269   assert(val.hasName() && "Found surface variable with no name");
270   return std::string(val.getName());
271 }
272 
getSamplerName(const Value & val)273 std::string getSamplerName(const Value &val) {
274   assert(val.hasName() && "Found sampler variable with no name");
275   return std::string(val.getName());
276 }
277 
getMaxNTIDx(const Function & F)278 std::optional<unsigned> getMaxNTIDx(const Function &F) {
279   return findOneNVVMAnnotation(F, "maxntidx");
280 }
281 
getMaxNTIDy(const Function & F)282 std::optional<unsigned> getMaxNTIDy(const Function &F) {
283   return findOneNVVMAnnotation(F, "maxntidy");
284 }
285 
getMaxNTIDz(const Function & F)286 std::optional<unsigned> getMaxNTIDz(const Function &F) {
287   return findOneNVVMAnnotation(F, "maxntidz");
288 }
289 
getMaxNTID(const Function & F)290 std::optional<unsigned> getMaxNTID(const Function &F) {
291   // Note: The semantics here are a bit strange. The PTX ISA states the
292   // following (11.4.2. Performance-Tuning Directives: .maxntid):
293   //
294   //  Note that this directive guarantees that the total number of threads does
295   //  not exceed the maximum, but does not guarantee that the limit in any
296   //  particular dimension is not exceeded.
297   std::optional<unsigned> MaxNTIDx = getMaxNTIDx(F);
298   std::optional<unsigned> MaxNTIDy = getMaxNTIDy(F);
299   std::optional<unsigned> MaxNTIDz = getMaxNTIDz(F);
300   if (MaxNTIDx || MaxNTIDy || MaxNTIDz)
301     return MaxNTIDx.value_or(1) * MaxNTIDy.value_or(1) * MaxNTIDz.value_or(1);
302   return std::nullopt;
303 }
304 
getMaxClusterRank(const Function & F,unsigned & x)305 bool getMaxClusterRank(const Function &F, unsigned &x) {
306   return findOneNVVMAnnotation(&F, "maxclusterrank", x);
307 }
308 
getReqNTIDx(const Function & F)309 std::optional<unsigned> getReqNTIDx(const Function &F) {
310   return findOneNVVMAnnotation(F, "reqntidx");
311 }
312 
getReqNTIDy(const Function & F)313 std::optional<unsigned> getReqNTIDy(const Function &F) {
314   return findOneNVVMAnnotation(F, "reqntidy");
315 }
316 
getReqNTIDz(const Function & F)317 std::optional<unsigned> getReqNTIDz(const Function &F) {
318   return findOneNVVMAnnotation(F, "reqntidz");
319 }
320 
getReqNTID(const Function & F)321 std::optional<unsigned> getReqNTID(const Function &F) {
322   // Note: The semantics here are a bit strange. See getMaxNTID.
323   std::optional<unsigned> ReqNTIDx = getReqNTIDx(F);
324   std::optional<unsigned> ReqNTIDy = getReqNTIDy(F);
325   std::optional<unsigned> ReqNTIDz = getReqNTIDz(F);
326   if (ReqNTIDx || ReqNTIDy || ReqNTIDz)
327     return ReqNTIDx.value_or(1) * ReqNTIDy.value_or(1) * ReqNTIDz.value_or(1);
328   return std::nullopt;
329 }
330 
getMinCTASm(const Function & F,unsigned & x)331 bool getMinCTASm(const Function &F, unsigned &x) {
332   return findOneNVVMAnnotation(&F, "minctasm", x);
333 }
334 
getMaxNReg(const Function & F,unsigned & x)335 bool getMaxNReg(const Function &F, unsigned &x) {
336   return findOneNVVMAnnotation(&F, "maxnreg", x);
337 }
338 
isKernelFunction(const Function & F)339 bool isKernelFunction(const Function &F) {
340   unsigned x = 0;
341   if (!findOneNVVMAnnotation(&F, "kernel", x)) {
342     // There is no NVVM metadata, check the calling convention
343     return F.getCallingConv() == CallingConv::PTX_Kernel;
344   }
345   return (x == 1);
346 }
347 
getAlign(const Function & F,unsigned Index)348 MaybeAlign getAlign(const Function &F, unsigned Index) {
349   // First check the alignstack metadata
350   if (MaybeAlign StackAlign =
351           F.getAttributes().getAttributes(Index).getStackAlignment())
352     return StackAlign;
353 
354   // If that is missing, check the legacy nvvm metadata
355   std::vector<unsigned> Vs;
356   bool retval = findAllNVVMAnnotation(&F, "align", Vs);
357   if (!retval)
358     return std::nullopt;
359   for (unsigned V : Vs)
360     if ((V >> 16) == Index)
361       return Align(V & 0xFFFF);
362 
363   return std::nullopt;
364 }
365 
getAlign(const CallInst & I,unsigned Index)366 MaybeAlign getAlign(const CallInst &I, unsigned Index) {
367   // First check the alignstack metadata
368   if (MaybeAlign StackAlign =
369           I.getAttributes().getAttributes(Index).getStackAlignment())
370     return StackAlign;
371 
372   // If that is missing, check the legacy nvvm metadata
373   if (MDNode *alignNode = I.getMetadata("callalign")) {
374     for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
375       if (const ConstantInt *CI =
376               mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
377         unsigned V = CI->getZExtValue();
378         if ((V >> 16) == Index)
379           return Align(V & 0xFFFF);
380         if ((V >> 16) > Index)
381           return std::nullopt;
382       }
383     }
384   }
385   return std::nullopt;
386 }
387 
getMaybeBitcastedCallee(const CallBase * CB)388 Function *getMaybeBitcastedCallee(const CallBase *CB) {
389   return dyn_cast<Function>(CB->getCalledOperand()->stripPointerCasts());
390 }
391 
shouldEmitPTXNoReturn(const Value * V,const TargetMachine & TM)392 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
393   const auto &ST =
394       *static_cast<const NVPTXTargetMachine &>(TM).getSubtargetImpl();
395   if (!ST.hasNoReturn())
396     return false;
397 
398   assert((isa<Function>(V) || isa<CallInst>(V)) &&
399          "Expect either a call instruction or a function");
400 
401   if (const CallInst *CallI = dyn_cast<CallInst>(V))
402     return CallI->doesNotReturn() &&
403            CallI->getFunctionType()->getReturnType()->isVoidTy();
404 
405   const Function *F = cast<Function>(V);
406   return F->doesNotReturn() &&
407          F->getFunctionType()->getReturnType()->isVoidTy() &&
408          !isKernelFunction(*F);
409 }
410 
Isv2x16VT(EVT VT)411 bool Isv2x16VT(EVT VT) {
412   return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
413 }
414 
415 } // namespace llvm
416