xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (revision b1879975794772ee51f0b4865753364c7d7626c3)
1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalAlias.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Alignment.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Endian.h"
79 #include "llvm/Support/ErrorHandling.h"
80 #include "llvm/Support/NativeFormatting.h"
81 #include "llvm/Support/Path.h"
82 #include "llvm/Support/raw_ostream.h"
83 #include "llvm/Target/TargetLoweringObjectFile.h"
84 #include "llvm/Target/TargetMachine.h"
85 #include "llvm/TargetParser/Triple.h"
86 #include "llvm/Transforms/Utils/UnrollLoop.h"
87 #include <cassert>
88 #include <cstdint>
89 #include <cstring>
90 #include <new>
91 #include <string>
92 #include <utility>
93 #include <vector>
94 
95 using namespace llvm;
96 
97 static cl::opt<bool>
98     LowerCtorDtor("nvptx-lower-global-ctor-dtor",
99                   cl::desc("Lower GPU ctor / dtors to globals on the device."),
100                   cl::init(false), cl::Hidden);
101 
102 #define DEPOTNAME "__local_depot"
103 
104 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
105 /// depends.
106 static void
107 DiscoverDependentGlobals(const Value *V,
108                          DenseSet<const GlobalVariable *> &Globals) {
109   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
110     Globals.insert(GV);
111   else {
112     if (const User *U = dyn_cast<User>(V)) {
113       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
114         DiscoverDependentGlobals(U->getOperand(i), Globals);
115       }
116     }
117   }
118 }
119 
120 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
121 /// instances to be emitted, but only after any dependents have been added
122 /// first.s
123 static void
124 VisitGlobalVariableForEmission(const GlobalVariable *GV,
125                                SmallVectorImpl<const GlobalVariable *> &Order,
126                                DenseSet<const GlobalVariable *> &Visited,
127                                DenseSet<const GlobalVariable *> &Visiting) {
128   // Have we already visited this one?
129   if (Visited.count(GV))
130     return;
131 
132   // Do we have a circular dependency?
133   if (!Visiting.insert(GV).second)
134     report_fatal_error("Circular dependency found in global variable set");
135 
136   // Make sure we visit all dependents first
137   DenseSet<const GlobalVariable *> Others;
138   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
139     DiscoverDependentGlobals(GV->getOperand(i), Others);
140 
141   for (const GlobalVariable *GV : Others)
142     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
143 
144   // Now we can visit ourself
145   Order.push_back(GV);
146   Visited.insert(GV);
147   Visiting.erase(GV);
148 }
149 
150 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
151   NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
152                                         getSubtargetInfo().getFeatureBits());
153 
154   MCInst Inst;
155   lowerToMCInst(MI, Inst);
156   EmitToStreamer(*OutStreamer, Inst);
157 }
158 
159 // Handle symbol backtracking for targets that do not support image handles
160 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
161                                            unsigned OpNo, MCOperand &MCOp) {
162   const MachineOperand &MO = MI->getOperand(OpNo);
163   const MCInstrDesc &MCID = MI->getDesc();
164 
165   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
166     // This is a texture fetch, so operand 4 is a texref and operand 5 is
167     // a samplerref
168     if (OpNo == 4 && MO.isImm()) {
169       lowerImageHandleSymbol(MO.getImm(), MCOp);
170       return true;
171     }
172     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
173       lowerImageHandleSymbol(MO.getImm(), MCOp);
174       return true;
175     }
176 
177     return false;
178   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
179     unsigned VecSize =
180       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
181 
182     // For a surface load of vector size N, the Nth operand will be the surfref
183     if (OpNo == VecSize && MO.isImm()) {
184       lowerImageHandleSymbol(MO.getImm(), MCOp);
185       return true;
186     }
187 
188     return false;
189   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
190     // This is a surface store, so operand 0 is a surfref
191     if (OpNo == 0 && MO.isImm()) {
192       lowerImageHandleSymbol(MO.getImm(), MCOp);
193       return true;
194     }
195 
196     return false;
197   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
198     // This is a query, so operand 1 is a surfref/texref
199     if (OpNo == 1 && MO.isImm()) {
200       lowerImageHandleSymbol(MO.getImm(), MCOp);
201       return true;
202     }
203 
204     return false;
205   }
206 
207   return false;
208 }
209 
210 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
211   // Ewwww
212   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
213   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
214   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
215   const char *Sym = MFI->getImageHandleSymbol(Index);
216   StringRef SymName = nvTM.getStrPool().save(Sym);
217   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
218 }
219 
220 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
221   OutMI.setOpcode(MI->getOpcode());
222   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
223   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
224     const MachineOperand &MO = MI->getOperand(0);
225     OutMI.addOperand(GetSymbolRef(
226       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
227     return;
228   }
229 
230   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
231   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
232     const MachineOperand &MO = MI->getOperand(i);
233 
234     MCOperand MCOp;
235     if (!STI.hasImageHandles()) {
236       if (lowerImageHandleOperand(MI, i, MCOp)) {
237         OutMI.addOperand(MCOp);
238         continue;
239       }
240     }
241 
242     if (lowerOperand(MO, MCOp))
243       OutMI.addOperand(MCOp);
244   }
245 }
246 
247 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
248                                    MCOperand &MCOp) {
249   switch (MO.getType()) {
250   default: llvm_unreachable("unknown operand type");
251   case MachineOperand::MO_Register:
252     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
253     break;
254   case MachineOperand::MO_Immediate:
255     MCOp = MCOperand::createImm(MO.getImm());
256     break;
257   case MachineOperand::MO_MachineBasicBlock:
258     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
259         MO.getMBB()->getSymbol(), OutContext));
260     break;
261   case MachineOperand::MO_ExternalSymbol:
262     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
263     break;
264   case MachineOperand::MO_GlobalAddress:
265     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
266     break;
267   case MachineOperand::MO_FPImmediate: {
268     const ConstantFP *Cnt = MO.getFPImm();
269     const APFloat &Val = Cnt->getValueAPF();
270 
271     switch (Cnt->getType()->getTypeID()) {
272     default: report_fatal_error("Unsupported FP type"); break;
273     case Type::HalfTyID:
274       MCOp = MCOperand::createExpr(
275         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
276       break;
277     case Type::BFloatTyID:
278       MCOp = MCOperand::createExpr(
279           NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
280       break;
281     case Type::FloatTyID:
282       MCOp = MCOperand::createExpr(
283         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
284       break;
285     case Type::DoubleTyID:
286       MCOp = MCOperand::createExpr(
287         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
288       break;
289     }
290     break;
291   }
292   }
293   return true;
294 }
295 
296 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
297   if (Register::isVirtualRegister(Reg)) {
298     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
299 
300     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
301     unsigned RegNum = RegMap[Reg];
302 
303     // Encode the register class in the upper 4 bits
304     // Must be kept in sync with NVPTXInstPrinter::printRegName
305     unsigned Ret = 0;
306     if (RC == &NVPTX::Int1RegsRegClass) {
307       Ret = (1 << 28);
308     } else if (RC == &NVPTX::Int16RegsRegClass) {
309       Ret = (2 << 28);
310     } else if (RC == &NVPTX::Int32RegsRegClass) {
311       Ret = (3 << 28);
312     } else if (RC == &NVPTX::Int64RegsRegClass) {
313       Ret = (4 << 28);
314     } else if (RC == &NVPTX::Float32RegsRegClass) {
315       Ret = (5 << 28);
316     } else if (RC == &NVPTX::Float64RegsRegClass) {
317       Ret = (6 << 28);
318     } else if (RC == &NVPTX::Int128RegsRegClass) {
319       Ret = (7 << 28);
320     } else {
321       report_fatal_error("Bad register class");
322     }
323 
324     // Insert the vreg number
325     Ret |= (RegNum & 0x0FFFFFFF);
326     return Ret;
327   } else {
328     // Some special-use registers are actually physical registers.
329     // Encode this as the register class ID of 0 and the real register ID.
330     return Reg & 0x0FFFFFFF;
331   }
332 }
333 
334 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
335   const MCExpr *Expr;
336   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
337                                  OutContext);
338   return MCOperand::createExpr(Expr);
339 }
340 
341 static bool ShouldPassAsArray(Type *Ty) {
342   return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
343          Ty->isHalfTy() || Ty->isBFloatTy();
344 }
345 
346 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
347   const DataLayout &DL = getDataLayout();
348   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
349   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
350 
351   Type *Ty = F->getReturnType();
352 
353   bool isABI = (STI.getSmVersion() >= 20);
354 
355   if (Ty->getTypeID() == Type::VoidTyID)
356     return;
357   O << " (";
358 
359   if (isABI) {
360     if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
361         !ShouldPassAsArray(Ty)) {
362       unsigned size = 0;
363       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
364         size = ITy->getBitWidth();
365       } else {
366         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
367         size = Ty->getPrimitiveSizeInBits();
368       }
369       size = promoteScalarArgumentSize(size);
370       O << ".param .b" << size << " func_retval0";
371     } else if (isa<PointerType>(Ty)) {
372       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
373         << " func_retval0";
374     } else if (ShouldPassAsArray(Ty)) {
375       unsigned totalsz = DL.getTypeAllocSize(Ty);
376       Align RetAlignment = TLI->getFunctionArgumentAlignment(
377           F, Ty, AttributeList::ReturnIndex, DL);
378       O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
379         << totalsz << "]";
380     } else
381       llvm_unreachable("Unknown return type");
382   } else {
383     SmallVector<EVT, 16> vtparts;
384     ComputeValueVTs(*TLI, DL, Ty, vtparts);
385     unsigned idx = 0;
386     for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
387       unsigned elems = 1;
388       EVT elemtype = vtparts[i];
389       if (vtparts[i].isVector()) {
390         elems = vtparts[i].getVectorNumElements();
391         elemtype = vtparts[i].getVectorElementType();
392       }
393 
394       for (unsigned j = 0, je = elems; j != je; ++j) {
395         unsigned sz = elemtype.getSizeInBits();
396         if (elemtype.isInteger())
397           sz = promoteScalarArgumentSize(sz);
398         O << ".reg .b" << sz << " func_retval" << idx;
399         if (j < je - 1)
400           O << ", ";
401         ++idx;
402       }
403       if (i < e - 1)
404         O << ", ";
405     }
406   }
407   O << ") ";
408 }
409 
410 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
411                                         raw_ostream &O) {
412   const Function &F = MF.getFunction();
413   printReturnValStr(&F, O);
414 }
415 
416 // Return true if MBB is the header of a loop marked with
417 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
418 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
419     const MachineBasicBlock &MBB) const {
420   MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
421   // We insert .pragma "nounroll" only to the loop header.
422   if (!LI.isLoopHeader(&MBB))
423     return false;
424 
425   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
426   // we iterate through each back edge of the loop with header MBB, and check
427   // whether its metadata contains llvm.loop.unroll.disable.
428   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
429     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
430       // Edges from other loops to MBB are not back edges.
431       continue;
432     }
433     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
434       if (MDNode *LoopID =
435               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
436         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
437           return true;
438         if (MDNode *UnrollCountMD =
439                 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
440           if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
441                   ->isOne())
442             return true;
443         }
444       }
445     }
446   }
447   return false;
448 }
449 
450 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
451   AsmPrinter::emitBasicBlockStart(MBB);
452   if (isLoopHeaderOfNoUnroll(MBB))
453     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
454 }
455 
456 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
457   SmallString<128> Str;
458   raw_svector_ostream O(Str);
459 
460   if (!GlobalsEmitted) {
461     emitGlobals(*MF->getFunction().getParent());
462     GlobalsEmitted = true;
463   }
464 
465   // Set up
466   MRI = &MF->getRegInfo();
467   F = &MF->getFunction();
468   emitLinkageDirective(F, O);
469   if (isKernelFunction(*F))
470     O << ".entry ";
471   else {
472     O << ".func ";
473     printReturnValStr(*MF, O);
474   }
475 
476   CurrentFnSym->print(O, MAI);
477 
478   emitFunctionParamList(F, O);
479   O << "\n";
480 
481   if (isKernelFunction(*F))
482     emitKernelFunctionDirectives(*F, O);
483 
484   if (shouldEmitPTXNoReturn(F, TM))
485     O << ".noreturn";
486 
487   OutStreamer->emitRawText(O.str());
488 
489   VRegMapping.clear();
490   // Emit open brace for function body.
491   OutStreamer->emitRawText(StringRef("{\n"));
492   setAndEmitFunctionVirtualRegisters(*MF);
493   // Emit initial .loc debug directive for correct relocation symbol data.
494   if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
495     assert(SP->getUnit());
496     if (!SP->getUnit()->isDebugDirectivesOnly() && MMI && MMI->hasDebugInfo())
497       emitInitialRawDwarfLocDirective(*MF);
498   }
499 }
500 
501 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
502   bool Result = AsmPrinter::runOnMachineFunction(F);
503   // Emit closing brace for the body of function F.
504   // The closing brace must be emitted here because we need to emit additional
505   // debug labels/data after the last basic block.
506   // We need to emit the closing brace here because we don't have function that
507   // finished emission of the function body.
508   OutStreamer->emitRawText(StringRef("}\n"));
509   return Result;
510 }
511 
512 void NVPTXAsmPrinter::emitFunctionBodyStart() {
513   SmallString<128> Str;
514   raw_svector_ostream O(Str);
515   emitDemotedVars(&MF->getFunction(), O);
516   OutStreamer->emitRawText(O.str());
517 }
518 
519 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
520   VRegMapping.clear();
521 }
522 
523 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
524     SmallString<128> Str;
525     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
526     return OutContext.getOrCreateSymbol(Str);
527 }
528 
529 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
530   Register RegNo = MI->getOperand(0).getReg();
531   if (RegNo.isVirtual()) {
532     OutStreamer->AddComment(Twine("implicit-def: ") +
533                             getVirtualRegisterName(RegNo));
534   } else {
535     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
536     OutStreamer->AddComment(Twine("implicit-def: ") +
537                             STI.getRegisterInfo()->getName(RegNo));
538   }
539   OutStreamer->addBlankLine();
540 }
541 
542 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
543                                                    raw_ostream &O) const {
544   // If the NVVM IR has some of reqntid* specified, then output
545   // the reqntid directive, and set the unspecified ones to 1.
546   // If none of Reqntid* is specified, don't output reqntid directive.
547   std::optional<unsigned> Reqntidx = getReqNTIDx(F);
548   std::optional<unsigned> Reqntidy = getReqNTIDy(F);
549   std::optional<unsigned> Reqntidz = getReqNTIDz(F);
550 
551   if (Reqntidx || Reqntidy || Reqntidz)
552     O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
553       << ", " << Reqntidz.value_or(1) << "\n";
554 
555   // If the NVVM IR has some of maxntid* specified, then output
556   // the maxntid directive, and set the unspecified ones to 1.
557   // If none of maxntid* is specified, don't output maxntid directive.
558   std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
559   std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
560   std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
561 
562   if (Maxntidx || Maxntidy || Maxntidz)
563     O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
564       << ", " << Maxntidz.value_or(1) << "\n";
565 
566   unsigned Mincta = 0;
567   if (getMinCTASm(F, Mincta))
568     O << ".minnctapersm " << Mincta << "\n";
569 
570   unsigned Maxnreg = 0;
571   if (getMaxNReg(F, Maxnreg))
572     O << ".maxnreg " << Maxnreg << "\n";
573 
574   // .maxclusterrank directive requires SM_90 or higher, make sure that we
575   // filter it out for lower SM versions, as it causes a hard ptxas crash.
576   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
577   const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
578   unsigned Maxclusterrank = 0;
579   if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
580     O << ".maxclusterrank " << Maxclusterrank << "\n";
581 }
582 
583 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
584   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
585 
586   std::string Name;
587   raw_string_ostream NameStr(Name);
588 
589   VRegRCMap::const_iterator I = VRegMapping.find(RC);
590   assert(I != VRegMapping.end() && "Bad register class");
591   const DenseMap<unsigned, unsigned> &RegMap = I->second;
592 
593   VRegMap::const_iterator VI = RegMap.find(Reg);
594   assert(VI != RegMap.end() && "Bad virtual register");
595   unsigned MappedVR = VI->second;
596 
597   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
598 
599   NameStr.flush();
600   return Name;
601 }
602 
603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
604                                           raw_ostream &O) {
605   O << getVirtualRegisterName(vr);
606 }
607 
608 void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
609                                            raw_ostream &O) {
610   const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
611   if (!F || isKernelFunction(*F) || F->isDeclaration())
612     report_fatal_error(
613         "NVPTX aliasee must be a non-kernel function definition");
614 
615   if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
616       GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
617     report_fatal_error("NVPTX aliasee must not be '.weak'");
618 
619   emitDeclarationWithName(F, getSymbol(GA), O);
620 }
621 
622 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
623   emitDeclarationWithName(F, getSymbol(F), O);
624 }
625 
626 void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
627                                               raw_ostream &O) {
628   emitLinkageDirective(F, O);
629   if (isKernelFunction(*F))
630     O << ".entry ";
631   else
632     O << ".func ";
633   printReturnValStr(F, O);
634   S->print(O, MAI);
635   O << "\n";
636   emitFunctionParamList(F, O);
637   O << "\n";
638   if (shouldEmitPTXNoReturn(F, TM))
639     O << ".noreturn";
640   O << ";\n";
641 }
642 
643 static bool usedInGlobalVarDef(const Constant *C) {
644   if (!C)
645     return false;
646 
647   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
648     return GV->getName() != "llvm.used";
649   }
650 
651   for (const User *U : C->users())
652     if (const Constant *C = dyn_cast<Constant>(U))
653       if (usedInGlobalVarDef(C))
654         return true;
655 
656   return false;
657 }
658 
659 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
660   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
661     if (othergv->getName() == "llvm.used")
662       return true;
663   }
664 
665   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
666     if (instr->getParent() && instr->getParent()->getParent()) {
667       const Function *curFunc = instr->getParent()->getParent();
668       if (oneFunc && (curFunc != oneFunc))
669         return false;
670       oneFunc = curFunc;
671       return true;
672     } else
673       return false;
674   }
675 
676   for (const User *UU : U->users())
677     if (!usedInOneFunc(UU, oneFunc))
678       return false;
679 
680   return true;
681 }
682 
683 /* Find out if a global variable can be demoted to local scope.
684  * Currently, this is valid for CUDA shared variables, which have local
685  * scope and global lifetime. So the conditions to check are :
686  * 1. Is the global variable in shared address space?
687  * 2. Does it have local linkage?
688  * 3. Is the global variable referenced only in one function?
689  */
690 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
691   if (!gv->hasLocalLinkage())
692     return false;
693   PointerType *Pty = gv->getType();
694   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
695     return false;
696 
697   const Function *oneFunc = nullptr;
698 
699   bool flag = usedInOneFunc(gv, oneFunc);
700   if (!flag)
701     return false;
702   if (!oneFunc)
703     return false;
704   f = oneFunc;
705   return true;
706 }
707 
708 static bool useFuncSeen(const Constant *C,
709                         DenseMap<const Function *, bool> &seenMap) {
710   for (const User *U : C->users()) {
711     if (const Constant *cu = dyn_cast<Constant>(U)) {
712       if (useFuncSeen(cu, seenMap))
713         return true;
714     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
715       const BasicBlock *bb = I->getParent();
716       if (!bb)
717         continue;
718       const Function *caller = bb->getParent();
719       if (!caller)
720         continue;
721       if (seenMap.contains(caller))
722         return true;
723     }
724   }
725   return false;
726 }
727 
728 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
729   DenseMap<const Function *, bool> seenMap;
730   for (const Function &F : M) {
731     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
732       emitDeclaration(&F, O);
733       continue;
734     }
735 
736     if (F.isDeclaration()) {
737       if (F.use_empty())
738         continue;
739       if (F.getIntrinsicID())
740         continue;
741       emitDeclaration(&F, O);
742       continue;
743     }
744     for (const User *U : F.users()) {
745       if (const Constant *C = dyn_cast<Constant>(U)) {
746         if (usedInGlobalVarDef(C)) {
747           // The use is in the initialization of a global variable
748           // that is a function pointer, so print a declaration
749           // for the original function
750           emitDeclaration(&F, O);
751           break;
752         }
753         // Emit a declaration of this function if the function that
754         // uses this constant expr has already been seen.
755         if (useFuncSeen(C, seenMap)) {
756           emitDeclaration(&F, O);
757           break;
758         }
759       }
760 
761       if (!isa<Instruction>(U))
762         continue;
763       const Instruction *instr = cast<Instruction>(U);
764       const BasicBlock *bb = instr->getParent();
765       if (!bb)
766         continue;
767       const Function *caller = bb->getParent();
768       if (!caller)
769         continue;
770 
771       // If a caller has already been seen, then the caller is
772       // appearing in the module before the callee. so print out
773       // a declaration for the callee.
774       if (seenMap.contains(caller)) {
775         emitDeclaration(&F, O);
776         break;
777       }
778     }
779     seenMap[&F] = true;
780   }
781   for (const GlobalAlias &GA : M.aliases())
782     emitAliasDeclaration(&GA, O);
783 }
784 
785 static bool isEmptyXXStructor(GlobalVariable *GV) {
786   if (!GV) return true;
787   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
788   if (!InitList) return true;  // Not an array; we don't know how to parse.
789   return InitList->getNumOperands() == 0;
790 }
791 
792 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
793   // Construct a default subtarget off of the TargetMachine defaults. The
794   // rest of NVPTX isn't friendly to change subtargets per function and
795   // so the default TargetMachine will have all of the options.
796   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
797   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
798   SmallString<128> Str1;
799   raw_svector_ostream OS1(Str1);
800 
801   // Emit header before any dwarf directives are emitted below.
802   emitHeader(M, OS1, *STI);
803   OutStreamer->emitRawText(OS1.str());
804 }
805 
806 bool NVPTXAsmPrinter::doInitialization(Module &M) {
807   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
808   const NVPTXSubtarget &STI =
809       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
810   if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
811     report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
812 
813   // OpenMP supports NVPTX global constructors and destructors.
814   bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;
815 
816   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
817       !LowerCtorDtor && !IsOpenMP) {
818     report_fatal_error(
819         "Module has a nontrivial global ctor, which NVPTX does not support.");
820     return true;  // error
821   }
822   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
823       !LowerCtorDtor && !IsOpenMP) {
824     report_fatal_error(
825         "Module has a nontrivial global dtor, which NVPTX does not support.");
826     return true;  // error
827   }
828 
829   // We need to call the parent's one explicitly.
830   bool Result = AsmPrinter::doInitialization(M);
831 
832   GlobalsEmitted = false;
833 
834   return Result;
835 }
836 
837 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
838   SmallString<128> Str2;
839   raw_svector_ostream OS2(Str2);
840 
841   emitDeclarations(M, OS2);
842 
843   // As ptxas does not support forward references of globals, we need to first
844   // sort the list of module-level globals in def-use order. We visit each
845   // global variable in order, and ensure that we emit it *after* its dependent
846   // globals. We use a little extra memory maintaining both a set and a list to
847   // have fast searches while maintaining a strict ordering.
848   SmallVector<const GlobalVariable *, 8> Globals;
849   DenseSet<const GlobalVariable *> GVVisited;
850   DenseSet<const GlobalVariable *> GVVisiting;
851 
852   // Visit each global variable, in order
853   for (const GlobalVariable &I : M.globals())
854     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
855 
856   assert(GVVisited.size() == M.global_size() && "Missed a global variable");
857   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
858 
859   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
860   const NVPTXSubtarget &STI =
861       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
862 
863   // Print out module-level global variables in proper order
864   for (const GlobalVariable *GV : Globals)
865     printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
866 
867   OS2 << '\n';
868 
869   OutStreamer->emitRawText(OS2.str());
870 }
871 
872 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
873   SmallString<128> Str;
874   raw_svector_ostream OS(Str);
875 
876   MCSymbol *Name = getSymbol(&GA);
877 
878   OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
879      << ";\n";
880 
881   OutStreamer->emitRawText(OS.str());
882 }
883 
884 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
885                                  const NVPTXSubtarget &STI) {
886   O << "//\n";
887   O << "// Generated by LLVM NVPTX Back-End\n";
888   O << "//\n";
889   O << "\n";
890 
891   unsigned PTXVersion = STI.getPTXVersion();
892   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
893 
894   O << ".target ";
895   O << STI.getTargetName();
896 
897   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
898   if (NTM.getDrvInterface() == NVPTX::NVCL)
899     O << ", texmode_independent";
900 
901   bool HasFullDebugInfo = false;
902   for (DICompileUnit *CU : M.debug_compile_units()) {
903     switch(CU->getEmissionKind()) {
904     case DICompileUnit::NoDebug:
905     case DICompileUnit::DebugDirectivesOnly:
906       break;
907     case DICompileUnit::LineTablesOnly:
908     case DICompileUnit::FullDebug:
909       HasFullDebugInfo = true;
910       break;
911     }
912     if (HasFullDebugInfo)
913       break;
914   }
915   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
916     O << ", debug";
917 
918   O << "\n";
919 
920   O << ".address_size ";
921   if (NTM.is64Bit())
922     O << "64";
923   else
924     O << "32";
925   O << "\n";
926 
927   O << "\n";
928 }
929 
930 bool NVPTXAsmPrinter::doFinalization(Module &M) {
931   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
932 
933   // If we did not emit any functions, then the global declarations have not
934   // yet been emitted.
935   if (!GlobalsEmitted) {
936     emitGlobals(M);
937     GlobalsEmitted = true;
938   }
939 
940   // call doFinalization
941   bool ret = AsmPrinter::doFinalization(M);
942 
943   clearAnnotationCache(&M);
944 
945   auto *TS =
946       static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
947   // Close the last emitted section
948   if (HasDebugInfo) {
949     TS->closeLastSection();
950     // Emit empty .debug_loc section for better support of the empty files.
951     OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
952   }
953 
954   // Output last DWARF .file directives, if any.
955   TS->outputDwarfFileDirectives();
956 
957   return ret;
958 }
959 
960 // This function emits appropriate linkage directives for
961 // functions and global variables.
962 //
963 // extern function declaration            -> .extern
964 // extern function definition             -> .visible
965 // external global variable with init     -> .visible
966 // external without init                  -> .extern
967 // appending                              -> not allowed, assert.
968 // for any linkage other than
969 // internal, private, linker_private,
970 // linker_private_weak, linker_private_weak_def_auto,
971 // we emit                                -> .weak.
972 
973 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
974                                            raw_ostream &O) {
975   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
976     if (V->hasExternalLinkage()) {
977       if (isa<GlobalVariable>(V)) {
978         const GlobalVariable *GVar = cast<GlobalVariable>(V);
979         if (GVar) {
980           if (GVar->hasInitializer())
981             O << ".visible ";
982           else
983             O << ".extern ";
984         }
985       } else if (V->isDeclaration())
986         O << ".extern ";
987       else
988         O << ".visible ";
989     } else if (V->hasAppendingLinkage()) {
990       std::string msg;
991       msg.append("Error: ");
992       msg.append("Symbol ");
993       if (V->hasName())
994         msg.append(std::string(V->getName()));
995       msg.append("has unsupported appending linkage type");
996       llvm_unreachable(msg.c_str());
997     } else if (!V->hasInternalLinkage() &&
998                !V->hasPrivateLinkage()) {
999       O << ".weak ";
1000     }
1001   }
1002 }
1003 
1004 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1005                                          raw_ostream &O, bool processDemoted,
1006                                          const NVPTXSubtarget &STI) {
1007   // Skip meta data
1008   if (GVar->hasSection()) {
1009     if (GVar->getSection() == "llvm.metadata")
1010       return;
1011   }
1012 
1013   // Skip LLVM intrinsic global variables
1014   if (GVar->getName().starts_with("llvm.") ||
1015       GVar->getName().starts_with("nvvm."))
1016     return;
1017 
1018   const DataLayout &DL = getDataLayout();
1019 
1020   // GlobalVariables are always constant pointers themselves.
1021   Type *ETy = GVar->getValueType();
1022 
1023   if (GVar->hasExternalLinkage()) {
1024     if (GVar->hasInitializer())
1025       O << ".visible ";
1026     else
1027       O << ".extern ";
1028   } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
1029              GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
1030     O << ".common ";
1031   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1032              GVar->hasAvailableExternallyLinkage() ||
1033              GVar->hasCommonLinkage()) {
1034     O << ".weak ";
1035   }
1036 
1037   if (isTexture(*GVar)) {
1038     O << ".global .texref " << getTextureName(*GVar) << ";\n";
1039     return;
1040   }
1041 
1042   if (isSurface(*GVar)) {
1043     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1044     return;
1045   }
1046 
1047   if (GVar->isDeclaration()) {
1048     // (extern) declarations, no definition or initializer
1049     // Currently the only known declaration is for an automatic __local
1050     // (.shared) promoted to global.
1051     emitPTXGlobalVariable(GVar, O, STI);
1052     O << ";\n";
1053     return;
1054   }
1055 
1056   if (isSampler(*GVar)) {
1057     O << ".global .samplerref " << getSamplerName(*GVar);
1058 
1059     const Constant *Initializer = nullptr;
1060     if (GVar->hasInitializer())
1061       Initializer = GVar->getInitializer();
1062     const ConstantInt *CI = nullptr;
1063     if (Initializer)
1064       CI = dyn_cast<ConstantInt>(Initializer);
1065     if (CI) {
1066       unsigned sample = CI->getZExtValue();
1067 
1068       O << " = { ";
1069 
1070       for (int i = 0,
1071                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1072            i < 3; i++) {
1073         O << "addr_mode_" << i << " = ";
1074         switch (addr) {
1075         case 0:
1076           O << "wrap";
1077           break;
1078         case 1:
1079           O << "clamp_to_border";
1080           break;
1081         case 2:
1082           O << "clamp_to_edge";
1083           break;
1084         case 3:
1085           O << "wrap";
1086           break;
1087         case 4:
1088           O << "mirror";
1089           break;
1090         }
1091         O << ", ";
1092       }
1093       O << "filter_mode = ";
1094       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1095       case 0:
1096         O << "nearest";
1097         break;
1098       case 1:
1099         O << "linear";
1100         break;
1101       case 2:
1102         llvm_unreachable("Anisotropic filtering is not supported");
1103       default:
1104         O << "nearest";
1105         break;
1106       }
1107       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1108         O << ", force_unnormalized_coords = 1";
1109       }
1110       O << " }";
1111     }
1112 
1113     O << ";\n";
1114     return;
1115   }
1116 
1117   if (GVar->hasPrivateLinkage()) {
1118     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1119       return;
1120 
1121     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1122     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1123       return;
1124     if (GVar->use_empty())
1125       return;
1126   }
1127 
1128   const Function *demotedFunc = nullptr;
1129   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1130     O << "// " << GVar->getName() << " has been demoted\n";
1131     if (localDecls.find(demotedFunc) != localDecls.end())
1132       localDecls[demotedFunc].push_back(GVar);
1133     else {
1134       std::vector<const GlobalVariable *> temp;
1135       temp.push_back(GVar);
1136       localDecls[demotedFunc] = temp;
1137     }
1138     return;
1139   }
1140 
1141   O << ".";
1142   emitPTXAddressSpace(GVar->getAddressSpace(), O);
1143 
1144   if (isManaged(*GVar)) {
1145     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1146       report_fatal_error(
1147           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1148     }
1149     O << " .attribute(.managed)";
1150   }
1151 
1152   if (MaybeAlign A = GVar->getAlign())
1153     O << " .align " << A->value();
1154   else
1155     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1156 
1157   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1158       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1159     O << " .";
1160     // Special case: ABI requires that we use .u8 for predicates
1161     if (ETy->isIntegerTy(1))
1162       O << "u8";
1163     else
1164       O << getPTXFundamentalTypeStr(ETy, false);
1165     O << " ";
1166     getSymbol(GVar)->print(O, MAI);
1167 
1168     // Ptx allows variable initilization only for constant and global state
1169     // spaces.
1170     if (GVar->hasInitializer()) {
1171       if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1172           (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1173         const Constant *Initializer = GVar->getInitializer();
1174         // 'undef' is treated as there is no value specified.
1175         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1176           O << " = ";
1177           printScalarConstant(Initializer, O);
1178         }
1179       } else {
1180         // The frontend adds zero-initializer to device and constant variables
1181         // that don't have an initial value, and UndefValue to shared
1182         // variables, so skip warning for this case.
1183         if (!GVar->getInitializer()->isNullValue() &&
1184             !isa<UndefValue>(GVar->getInitializer())) {
1185           report_fatal_error("initial value of '" + GVar->getName() +
1186                              "' is not allowed in addrspace(" +
1187                              Twine(GVar->getAddressSpace()) + ")");
1188         }
1189       }
1190     }
1191   } else {
1192     uint64_t ElementSize = 0;
1193 
1194     // Although PTX has direct support for struct type and array type and
1195     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1196     // targets that support these high level field accesses. Structs, arrays
1197     // and vectors are lowered into arrays of bytes.
1198     switch (ETy->getTypeID()) {
1199     case Type::IntegerTyID: // Integers larger than 64 bits
1200     case Type::StructTyID:
1201     case Type::ArrayTyID:
1202     case Type::FixedVectorTyID:
1203       ElementSize = DL.getTypeStoreSize(ETy);
1204       // Ptx allows variable initilization only for constant and
1205       // global state spaces.
1206       if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1207            (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1208           GVar->hasInitializer()) {
1209         const Constant *Initializer = GVar->getInitializer();
1210         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1211           AggBuffer aggBuffer(ElementSize, *this);
1212           bufferAggregateConstant(Initializer, &aggBuffer);
1213           if (aggBuffer.numSymbols()) {
1214             unsigned int ptrSize = MAI->getCodePointerSize();
1215             if (ElementSize % ptrSize ||
1216                 !aggBuffer.allSymbolsAligned(ptrSize)) {
1217               // Print in bytes and use the mask() operator for pointers.
1218               if (!STI.hasMaskOperator())
1219                 report_fatal_error(
1220                     "initialized packed aggregate with pointers '" +
1221                     GVar->getName() +
1222                     "' requires at least PTX ISA version 7.1");
1223               O << " .u8 ";
1224               getSymbol(GVar)->print(O, MAI);
1225               O << "[" << ElementSize << "] = {";
1226               aggBuffer.printBytes(O);
1227               O << "}";
1228             } else {
1229               O << " .u" << ptrSize * 8 << " ";
1230               getSymbol(GVar)->print(O, MAI);
1231               O << "[" << ElementSize / ptrSize << "] = {";
1232               aggBuffer.printWords(O);
1233               O << "}";
1234             }
1235           } else {
1236             O << " .b8 ";
1237             getSymbol(GVar)->print(O, MAI);
1238             O << "[" << ElementSize << "] = {";
1239             aggBuffer.printBytes(O);
1240             O << "}";
1241           }
1242         } else {
1243           O << " .b8 ";
1244           getSymbol(GVar)->print(O, MAI);
1245           if (ElementSize) {
1246             O << "[";
1247             O << ElementSize;
1248             O << "]";
1249           }
1250         }
1251       } else {
1252         O << " .b8 ";
1253         getSymbol(GVar)->print(O, MAI);
1254         if (ElementSize) {
1255           O << "[";
1256           O << ElementSize;
1257           O << "]";
1258         }
1259       }
1260       break;
1261     default:
1262       llvm_unreachable("type not supported yet");
1263     }
1264   }
1265   O << ";\n";
1266 }
1267 
1268 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1269   const Value *v = Symbols[nSym];
1270   const Value *v0 = SymbolsBeforeStripping[nSym];
1271   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1272     MCSymbol *Name = AP.getSymbol(GVar);
1273     PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1274     // Is v0 a generic pointer?
1275     bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1276     if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1277       os << "generic(";
1278       Name->print(os, AP.MAI);
1279       os << ")";
1280     } else {
1281       Name->print(os, AP.MAI);
1282     }
1283   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1284     const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1285     AP.printMCExpr(*Expr, os);
1286   } else
1287     llvm_unreachable("symbol type unknown");
1288 }
1289 
1290 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1291   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1292   // Do not emit trailing zero initializers. They will be zero-initialized by
1293   // ptxas. This saves on both space requirements for the generated PTX and on
1294   // memory use by ptxas. (See:
1295   // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1296   unsigned int InitializerCount = size;
1297   // TODO: symbols make this harder, but it would still be good to trim trailing
1298   // 0s for aggs with symbols as well.
1299   if (numSymbols() == 0)
1300     while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
1301       InitializerCount--;
1302 
1303   symbolPosInBuffer.push_back(InitializerCount);
1304   unsigned int nSym = 0;
1305   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1306   for (unsigned int pos = 0; pos < InitializerCount;) {
1307     if (pos)
1308       os << ", ";
1309     if (pos != nextSymbolPos) {
1310       os << (unsigned int)buffer[pos];
1311       ++pos;
1312       continue;
1313     }
1314     // Generate a per-byte mask() operator for the symbol, which looks like:
1315     //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1316     // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1317     std::string symText;
1318     llvm::raw_string_ostream oss(symText);
1319     printSymbol(nSym, oss);
1320     for (unsigned i = 0; i < ptrSize; ++i) {
1321       if (i)
1322         os << ", ";
1323       llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1324       os << "(" << symText << ")";
1325     }
1326     pos += ptrSize;
1327     nextSymbolPos = symbolPosInBuffer[++nSym];
1328     assert(nextSymbolPos >= pos);
1329   }
1330 }
1331 
1332 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1333   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1334   symbolPosInBuffer.push_back(size);
1335   unsigned int nSym = 0;
1336   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1337   assert(nextSymbolPos % ptrSize == 0);
1338   for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1339     if (pos)
1340       os << ", ";
1341     if (pos == nextSymbolPos) {
1342       printSymbol(nSym, os);
1343       nextSymbolPos = symbolPosInBuffer[++nSym];
1344       assert(nextSymbolPos % ptrSize == 0);
1345       assert(nextSymbolPos >= pos + ptrSize);
1346     } else if (ptrSize == 4)
1347       os << support::endian::read32le(&buffer[pos]);
1348     else
1349       os << support::endian::read64le(&buffer[pos]);
1350   }
1351 }
1352 
1353 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1354   if (localDecls.find(f) == localDecls.end())
1355     return;
1356 
1357   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1358 
1359   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1360   const NVPTXSubtarget &STI =
1361       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1362 
1363   for (const GlobalVariable *GV : gvars) {
1364     O << "\t// demoted variable\n\t";
1365     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1366   }
1367 }
1368 
1369 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1370                                           raw_ostream &O) const {
1371   switch (AddressSpace) {
1372   case ADDRESS_SPACE_LOCAL:
1373     O << "local";
1374     break;
1375   case ADDRESS_SPACE_GLOBAL:
1376     O << "global";
1377     break;
1378   case ADDRESS_SPACE_CONST:
1379     O << "const";
1380     break;
1381   case ADDRESS_SPACE_SHARED:
1382     O << "shared";
1383     break;
1384   default:
1385     report_fatal_error("Bad address space found while emitting PTX: " +
1386                        llvm::Twine(AddressSpace));
1387     break;
1388   }
1389 }
1390 
1391 std::string
1392 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1393   switch (Ty->getTypeID()) {
1394   case Type::IntegerTyID: {
1395     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1396     if (NumBits == 1)
1397       return "pred";
1398     else if (NumBits <= 64) {
1399       std::string name = "u";
1400       return name + utostr(NumBits);
1401     } else {
1402       llvm_unreachable("Integer too large");
1403       break;
1404     }
1405     break;
1406   }
1407   case Type::BFloatTyID:
1408   case Type::HalfTyID:
1409     // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1410     // PTX assembly.
1411     return "b16";
1412   case Type::FloatTyID:
1413     return "f32";
1414   case Type::DoubleTyID:
1415     return "f64";
1416   case Type::PointerTyID: {
1417     unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1418     assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1419 
1420     if (PtrSize == 64)
1421       if (useB4PTR)
1422         return "b64";
1423       else
1424         return "u64";
1425     else if (useB4PTR)
1426       return "b32";
1427     else
1428       return "u32";
1429   }
1430   default:
1431     break;
1432   }
1433   llvm_unreachable("unexpected type");
1434 }
1435 
1436 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1437                                             raw_ostream &O,
1438                                             const NVPTXSubtarget &STI) {
1439   const DataLayout &DL = getDataLayout();
1440 
1441   // GlobalVariables are always constant pointers themselves.
1442   Type *ETy = GVar->getValueType();
1443 
1444   O << ".";
1445   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1446   if (isManaged(*GVar)) {
1447     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1448       report_fatal_error(
1449           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1450     }
1451     O << " .attribute(.managed)";
1452   }
1453   if (MaybeAlign A = GVar->getAlign())
1454     O << " .align " << A->value();
1455   else
1456     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1457 
1458   // Special case for i128
1459   if (ETy->isIntegerTy(128)) {
1460     O << " .b8 ";
1461     getSymbol(GVar)->print(O, MAI);
1462     O << "[16]";
1463     return;
1464   }
1465 
1466   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1467     O << " .";
1468     O << getPTXFundamentalTypeStr(ETy);
1469     O << " ";
1470     getSymbol(GVar)->print(O, MAI);
1471     return;
1472   }
1473 
1474   int64_t ElementSize = 0;
1475 
1476   // Although PTX has direct support for struct type and array type and LLVM IR
1477   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1478   // support these high level field accesses. Structs and arrays are lowered
1479   // into arrays of bytes.
1480   switch (ETy->getTypeID()) {
1481   case Type::StructTyID:
1482   case Type::ArrayTyID:
1483   case Type::FixedVectorTyID:
1484     ElementSize = DL.getTypeStoreSize(ETy);
1485     O << " .b8 ";
1486     getSymbol(GVar)->print(O, MAI);
1487     O << "[";
1488     if (ElementSize) {
1489       O << ElementSize;
1490     }
1491     O << "]";
1492     break;
1493   default:
1494     llvm_unreachable("type not supported yet");
1495   }
1496 }
1497 
1498 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1499   const DataLayout &DL = getDataLayout();
1500   const AttributeList &PAL = F->getAttributes();
1501   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1502   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1503 
1504   Function::const_arg_iterator I, E;
1505   unsigned paramIndex = 0;
1506   bool first = true;
1507   bool isKernelFunc = isKernelFunction(*F);
1508   bool isABI = (STI.getSmVersion() >= 20);
1509   bool hasImageHandles = STI.hasImageHandles();
1510 
1511   if (F->arg_empty() && !F->isVarArg()) {
1512     O << "()";
1513     return;
1514   }
1515 
1516   O << "(\n";
1517 
1518   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1519     Type *Ty = I->getType();
1520 
1521     if (!first)
1522       O << ",\n";
1523 
1524     first = false;
1525 
1526     // Handle image/sampler parameters
1527     if (isKernelFunction(*F)) {
1528       if (isSampler(*I) || isImage(*I)) {
1529         if (isImage(*I)) {
1530           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1531             if (hasImageHandles)
1532               O << "\t.param .u64 .ptr .surfref ";
1533             else
1534               O << "\t.param .surfref ";
1535             O << TLI->getParamName(F, paramIndex);
1536           }
1537           else { // Default image is read_only
1538             if (hasImageHandles)
1539               O << "\t.param .u64 .ptr .texref ";
1540             else
1541               O << "\t.param .texref ";
1542             O << TLI->getParamName(F, paramIndex);
1543           }
1544         } else {
1545           if (hasImageHandles)
1546             O << "\t.param .u64 .ptr .samplerref ";
1547           else
1548             O << "\t.param .samplerref ";
1549           O << TLI->getParamName(F, paramIndex);
1550         }
1551         continue;
1552       }
1553     }
1554 
1555     auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1556                                     paramIndex](Type *Ty) -> Align {
1557       if (MaybeAlign StackAlign =
1558               getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
1559         return StackAlign.value();
1560 
1561       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1562       MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1563       return std::max(TypeAlign, ParamAlign.valueOrOne());
1564     };
1565 
1566     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1567       if (ShouldPassAsArray(Ty)) {
1568         // Just print .param .align <a> .b8 .param[size];
1569         // <a>  = optimal alignment for the element type; always multiple of
1570         //        PAL.getParamAlignment
1571         // size = typeallocsize of element type
1572         Align OptimalAlign = getOptimalAlignForParam(Ty);
1573 
1574         O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1575         O << TLI->getParamName(F, paramIndex);
1576         O << "[" << DL.getTypeAllocSize(Ty) << "]";
1577 
1578         continue;
1579       }
1580       // Just a scalar
1581       auto *PTy = dyn_cast<PointerType>(Ty);
1582       unsigned PTySizeInBits = 0;
1583       if (PTy) {
1584         PTySizeInBits =
1585             TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1586         assert(PTySizeInBits && "Invalid pointer size");
1587       }
1588 
1589       if (isKernelFunc) {
1590         if (PTy) {
1591           // Special handling for pointer arguments to kernel
1592           O << "\t.param .u" << PTySizeInBits << " ";
1593 
1594           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1595               NVPTX::CUDA) {
1596             int addrSpace = PTy->getAddressSpace();
1597             switch (addrSpace) {
1598             default:
1599               O << ".ptr ";
1600               break;
1601             case ADDRESS_SPACE_CONST:
1602               O << ".ptr .const ";
1603               break;
1604             case ADDRESS_SPACE_SHARED:
1605               O << ".ptr .shared ";
1606               break;
1607             case ADDRESS_SPACE_GLOBAL:
1608               O << ".ptr .global ";
1609               break;
1610             }
1611             Align ParamAlign = I->getParamAlign().valueOrOne();
1612             O << ".align " << ParamAlign.value() << " ";
1613           }
1614           O << TLI->getParamName(F, paramIndex);
1615           continue;
1616         }
1617 
1618         // non-pointer scalar to kernel func
1619         O << "\t.param .";
1620         // Special case: predicate operands become .u8 types
1621         if (Ty->isIntegerTy(1))
1622           O << "u8";
1623         else
1624           O << getPTXFundamentalTypeStr(Ty);
1625         O << " ";
1626         O << TLI->getParamName(F, paramIndex);
1627         continue;
1628       }
1629       // Non-kernel function, just print .param .b<size> for ABI
1630       // and .reg .b<size> for non-ABI
1631       unsigned sz = 0;
1632       if (isa<IntegerType>(Ty)) {
1633         sz = cast<IntegerType>(Ty)->getBitWidth();
1634         sz = promoteScalarArgumentSize(sz);
1635       } else if (PTy) {
1636         assert(PTySizeInBits && "Invalid pointer size");
1637         sz = PTySizeInBits;
1638       } else
1639         sz = Ty->getPrimitiveSizeInBits();
1640       if (isABI)
1641         O << "\t.param .b" << sz << " ";
1642       else
1643         O << "\t.reg .b" << sz << " ";
1644       O << TLI->getParamName(F, paramIndex);
1645       continue;
1646     }
1647 
1648     // param has byVal attribute.
1649     Type *ETy = PAL.getParamByValType(paramIndex);
1650     assert(ETy && "Param should have byval type");
1651 
1652     if (isABI || isKernelFunc) {
1653       // Just print .param .align <a> .b8 .param[size];
1654       // <a>  = optimal alignment for the element type; always multiple of
1655       //        PAL.getParamAlignment
1656       // size = typeallocsize of element type
1657       Align OptimalAlign =
1658           isKernelFunc
1659               ? getOptimalAlignForParam(ETy)
1660               : TLI->getFunctionByValParamAlign(
1661                     F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1662 
1663       unsigned sz = DL.getTypeAllocSize(ETy);
1664       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1665       O << TLI->getParamName(F, paramIndex);
1666       O << "[" << sz << "]";
1667       continue;
1668     } else {
1669       // Split the ETy into constituent parts and
1670       // print .param .b<size> <name> for each part.
1671       // Further, if a part is vector, print the above for
1672       // each vector element.
1673       SmallVector<EVT, 16> vtparts;
1674       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1675       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1676         unsigned elems = 1;
1677         EVT elemtype = vtparts[i];
1678         if (vtparts[i].isVector()) {
1679           elems = vtparts[i].getVectorNumElements();
1680           elemtype = vtparts[i].getVectorElementType();
1681         }
1682 
1683         for (unsigned j = 0, je = elems; j != je; ++j) {
1684           unsigned sz = elemtype.getSizeInBits();
1685           if (elemtype.isInteger())
1686             sz = promoteScalarArgumentSize(sz);
1687           O << "\t.reg .b" << sz << " ";
1688           O << TLI->getParamName(F, paramIndex);
1689           if (j < je - 1)
1690             O << ",\n";
1691           ++paramIndex;
1692         }
1693         if (i < e - 1)
1694           O << ",\n";
1695       }
1696       --paramIndex;
1697       continue;
1698     }
1699   }
1700 
1701   if (F->isVarArg()) {
1702     if (!first)
1703       O << ",\n";
1704     O << "\t.param .align " << STI.getMaxRequiredAlignment();
1705     O << " .b8 ";
1706     O << TLI->getParamName(F, /* vararg */ -1) << "[]";
1707   }
1708 
1709   O << "\n)";
1710 }
1711 
1712 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1713     const MachineFunction &MF) {
1714   SmallString<128> Str;
1715   raw_svector_ostream O(Str);
1716 
1717   // Map the global virtual register number to a register class specific
1718   // virtual register number starting from 1 with that class.
1719   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1720   //unsigned numRegClasses = TRI->getNumRegClasses();
1721 
1722   // Emit the Fake Stack Object
1723   const MachineFrameInfo &MFI = MF.getFrameInfo();
1724   int64_t NumBytes = MFI.getStackSize();
1725   if (NumBytes) {
1726     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1727       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1728     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1729       O << "\t.reg .b64 \t%SP;\n";
1730       O << "\t.reg .b64 \t%SPL;\n";
1731     } else {
1732       O << "\t.reg .b32 \t%SP;\n";
1733       O << "\t.reg .b32 \t%SPL;\n";
1734     }
1735   }
1736 
1737   // Go through all virtual registers to establish the mapping between the
1738   // global virtual
1739   // register number and the per class virtual register number.
1740   // We use the per class virtual register number in the ptx output.
1741   unsigned int numVRs = MRI->getNumVirtRegs();
1742   for (unsigned i = 0; i < numVRs; i++) {
1743     Register vr = Register::index2VirtReg(i);
1744     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1745     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1746     int n = regmap.size();
1747     regmap.insert(std::make_pair(vr, n + 1));
1748   }
1749 
1750   // Emit register declarations
1751   // @TODO: Extract out the real register usage
1752   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1753   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1754   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1755   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1756   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1757   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1758   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1759 
1760   // Emit declaration of the virtual registers or 'physical' registers for
1761   // each register class
1762   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1763     const TargetRegisterClass *RC = TRI->getRegClass(i);
1764     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1765     std::string rcname = getNVPTXRegClassName(RC);
1766     std::string rcStr = getNVPTXRegClassStr(RC);
1767     int n = regmap.size();
1768 
1769     // Only declare those registers that may be used.
1770     if (n) {
1771        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1772          << ">;\n";
1773     }
1774   }
1775 
1776   OutStreamer->emitRawText(O.str());
1777 }
1778 
1779 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1780   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1781   bool ignored;
1782   unsigned int numHex;
1783   const char *lead;
1784 
1785   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1786     numHex = 8;
1787     lead = "0f";
1788     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1789   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1790     numHex = 16;
1791     lead = "0d";
1792     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1793   } else
1794     llvm_unreachable("unsupported fp type");
1795 
1796   APInt API = APF.bitcastToAPInt();
1797   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1798 }
1799 
1800 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1801   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1802     O << CI->getValue();
1803     return;
1804   }
1805   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1806     printFPConstant(CFP, O);
1807     return;
1808   }
1809   if (isa<ConstantPointerNull>(CPV)) {
1810     O << "0";
1811     return;
1812   }
1813   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1814     bool IsNonGenericPointer = false;
1815     if (GVar->getType()->getAddressSpace() != 0) {
1816       IsNonGenericPointer = true;
1817     }
1818     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1819       O << "generic(";
1820       getSymbol(GVar)->print(O, MAI);
1821       O << ")";
1822     } else {
1823       getSymbol(GVar)->print(O, MAI);
1824     }
1825     return;
1826   }
1827   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1828     const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1829     printMCExpr(*E, O);
1830     return;
1831   }
1832   llvm_unreachable("Not scalar type found in printScalarConstant()");
1833 }
1834 
1835 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1836                                    AggBuffer *AggBuffer) {
1837   const DataLayout &DL = getDataLayout();
1838   int AllocSize = DL.getTypeAllocSize(CPV->getType());
1839   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1840     // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1841     // only the space allocated by CPV.
1842     AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1843     return;
1844   }
1845 
1846   // Helper for filling AggBuffer with APInts.
1847   auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1848     size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1849     SmallVector<unsigned char, 16> Buf(NumBytes);
1850     // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1851     // input's bit width, and i1 arrays may not have a length that is a multuple
1852     // of 8. We handle the last byte separately, so we never request out of
1853     // bounds bits.
1854     for (unsigned I = 0; I < NumBytes - 1; ++I) {
1855       Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1856     }
1857     size_t LastBytePosition = (NumBytes - 1) * 8;
1858     size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
1859     Buf[NumBytes - 1] =
1860         Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);
1861     AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1862   };
1863 
1864   switch (CPV->getType()->getTypeID()) {
1865   case Type::IntegerTyID:
1866     if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1867       AddIntToBuffer(CI->getValue());
1868       break;
1869     }
1870     if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1871       if (const auto *CI =
1872               dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1873         AddIntToBuffer(CI->getValue());
1874         break;
1875       }
1876       if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1877         Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1878         AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1879         AggBuffer->addZeros(AllocSize);
1880         break;
1881       }
1882     }
1883     llvm_unreachable("unsupported integer const type");
1884     break;
1885 
1886   case Type::HalfTyID:
1887   case Type::BFloatTyID:
1888   case Type::FloatTyID:
1889   case Type::DoubleTyID:
1890     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1891     break;
1892 
1893   case Type::PointerTyID: {
1894     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1895       AggBuffer->addSymbol(GVar, GVar);
1896     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1897       const Value *v = Cexpr->stripPointerCasts();
1898       AggBuffer->addSymbol(v, Cexpr);
1899     }
1900     AggBuffer->addZeros(AllocSize);
1901     break;
1902   }
1903 
1904   case Type::ArrayTyID:
1905   case Type::FixedVectorTyID:
1906   case Type::StructTyID: {
1907     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1908       bufferAggregateConstant(CPV, AggBuffer);
1909       if (Bytes > AllocSize)
1910         AggBuffer->addZeros(Bytes - AllocSize);
1911     } else if (isa<ConstantAggregateZero>(CPV))
1912       AggBuffer->addZeros(Bytes);
1913     else
1914       llvm_unreachable("Unexpected Constant type");
1915     break;
1916   }
1917 
1918   default:
1919     llvm_unreachable("unsupported type");
1920   }
1921 }
1922 
1923 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1924                                               AggBuffer *aggBuffer) {
1925   const DataLayout &DL = getDataLayout();
1926   int Bytes;
1927 
1928   // Integers of arbitrary width
1929   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1930     APInt Val = CI->getValue();
1931     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1932       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1933       aggBuffer->addBytes(&Byte, 1, 1);
1934       Val.lshrInPlace(8);
1935     }
1936     return;
1937   }
1938 
1939   // Old constants
1940   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1941     if (CPV->getNumOperands())
1942       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1943         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1944     return;
1945   }
1946 
1947   if (const ConstantDataSequential *CDS =
1948           dyn_cast<ConstantDataSequential>(CPV)) {
1949     if (CDS->getNumElements())
1950       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1951         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1952                      aggBuffer);
1953     return;
1954   }
1955 
1956   if (isa<ConstantStruct>(CPV)) {
1957     if (CPV->getNumOperands()) {
1958       StructType *ST = cast<StructType>(CPV->getType());
1959       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1960         if (i == (e - 1))
1961           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1962                   DL.getTypeAllocSize(ST) -
1963                   DL.getStructLayout(ST)->getElementOffset(i);
1964         else
1965           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1966                   DL.getStructLayout(ST)->getElementOffset(i);
1967         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1968       }
1969     }
1970     return;
1971   }
1972   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1973 }
1974 
1975 /// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1976 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1977 /// expressions that are representable in PTX and create
1978 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1979 const MCExpr *
1980 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1981   MCContext &Ctx = OutContext;
1982 
1983   if (CV->isNullValue() || isa<UndefValue>(CV))
1984     return MCConstantExpr::create(0, Ctx);
1985 
1986   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1987     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1988 
1989   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1990     const MCSymbolRefExpr *Expr =
1991       MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1992     if (ProcessingGeneric) {
1993       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1994     } else {
1995       return Expr;
1996     }
1997   }
1998 
1999   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
2000   if (!CE) {
2001     llvm_unreachable("Unknown constant value to lower!");
2002   }
2003 
2004   switch (CE->getOpcode()) {
2005   default:
2006     break; // Error
2007 
2008   case Instruction::AddrSpaceCast: {
2009     // Strip the addrspacecast and pass along the operand
2010     PointerType *DstTy = cast<PointerType>(CE->getType());
2011     if (DstTy->getAddressSpace() == 0)
2012       return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2013 
2014     break; // Error
2015   }
2016 
2017   case Instruction::GetElementPtr: {
2018     const DataLayout &DL = getDataLayout();
2019 
2020     // Generate a symbolic expression for the byte address
2021     APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2022     cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2023 
2024     const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2025                                             ProcessingGeneric);
2026     if (!OffsetAI)
2027       return Base;
2028 
2029     int64_t Offset = OffsetAI.getSExtValue();
2030     return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2031                                    Ctx);
2032   }
2033 
2034   case Instruction::Trunc:
2035     // We emit the value and depend on the assembler to truncate the generated
2036     // expression properly.  This is important for differences between
2037     // blockaddress labels.  Since the two labels are in the same function, it
2038     // is reasonable to treat their delta as a 32-bit value.
2039     [[fallthrough]];
2040   case Instruction::BitCast:
2041     return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2042 
2043   case Instruction::IntToPtr: {
2044     const DataLayout &DL = getDataLayout();
2045 
2046     // Handle casts to pointers by changing them into casts to the appropriate
2047     // integer type.  This promotes constant folding and simplifies this code.
2048     Constant *Op = CE->getOperand(0);
2049     Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2050                                  /*IsSigned*/ false, DL);
2051     if (Op)
2052       return lowerConstantForGV(Op, ProcessingGeneric);
2053 
2054     break; // Error
2055   }
2056 
2057   case Instruction::PtrToInt: {
2058     const DataLayout &DL = getDataLayout();
2059 
2060     // Support only foldable casts to/from pointers that can be eliminated by
2061     // changing the pointer to the appropriately sized integer type.
2062     Constant *Op = CE->getOperand(0);
2063     Type *Ty = CE->getType();
2064 
2065     const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2066 
2067     // We can emit the pointer value into this slot if the slot is an
2068     // integer slot equal to the size of the pointer.
2069     if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2070       return OpExpr;
2071 
2072     // Otherwise the pointer is smaller than the resultant integer, mask off
2073     // the high bits so we are sure to get a proper truncation if the input is
2074     // a constant expr.
2075     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2076     const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2077     return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2078   }
2079 
2080   // The MC library also has a right-shift operator, but it isn't consistently
2081   // signed or unsigned between different targets.
2082   case Instruction::Add: {
2083     const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2084     const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2085     switch (CE->getOpcode()) {
2086     default: llvm_unreachable("Unknown binary operator constant cast expr");
2087     case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2088     }
2089   }
2090   }
2091 
2092   // If the code isn't optimized, there may be outstanding folding
2093   // opportunities. Attempt to fold the expression using DataLayout as a
2094   // last resort before giving up.
2095   Constant *C = ConstantFoldConstant(CE, getDataLayout());
2096   if (C != CE)
2097     return lowerConstantForGV(C, ProcessingGeneric);
2098 
2099   // Otherwise report the problem to the user.
2100   std::string S;
2101   raw_string_ostream OS(S);
2102   OS << "Unsupported expression in static initializer: ";
2103   CE->printAsOperand(OS, /*PrintType=*/false,
2104                  !MF ? nullptr : MF->getFunction().getParent());
2105   report_fatal_error(Twine(OS.str()));
2106 }
2107 
2108 // Copy of MCExpr::print customized for NVPTX
2109 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2110   switch (Expr.getKind()) {
2111   case MCExpr::Target:
2112     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2113   case MCExpr::Constant:
2114     OS << cast<MCConstantExpr>(Expr).getValue();
2115     return;
2116 
2117   case MCExpr::SymbolRef: {
2118     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2119     const MCSymbol &Sym = SRE.getSymbol();
2120     Sym.print(OS, MAI);
2121     return;
2122   }
2123 
2124   case MCExpr::Unary: {
2125     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2126     switch (UE.getOpcode()) {
2127     case MCUnaryExpr::LNot:  OS << '!'; break;
2128     case MCUnaryExpr::Minus: OS << '-'; break;
2129     case MCUnaryExpr::Not:   OS << '~'; break;
2130     case MCUnaryExpr::Plus:  OS << '+'; break;
2131     }
2132     printMCExpr(*UE.getSubExpr(), OS);
2133     return;
2134   }
2135 
2136   case MCExpr::Binary: {
2137     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2138 
2139     // Only print parens around the LHS if it is non-trivial.
2140     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2141         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2142       printMCExpr(*BE.getLHS(), OS);
2143     } else {
2144       OS << '(';
2145       printMCExpr(*BE.getLHS(), OS);
2146       OS<< ')';
2147     }
2148 
2149     switch (BE.getOpcode()) {
2150     case MCBinaryExpr::Add:
2151       // Print "X-42" instead of "X+-42".
2152       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2153         if (RHSC->getValue() < 0) {
2154           OS << RHSC->getValue();
2155           return;
2156         }
2157       }
2158 
2159       OS <<  '+';
2160       break;
2161     default: llvm_unreachable("Unhandled binary operator");
2162     }
2163 
2164     // Only print parens around the LHS if it is non-trivial.
2165     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2166       printMCExpr(*BE.getRHS(), OS);
2167     } else {
2168       OS << '(';
2169       printMCExpr(*BE.getRHS(), OS);
2170       OS << ')';
2171     }
2172     return;
2173   }
2174   }
2175 
2176   llvm_unreachable("Invalid expression kind!");
2177 }
2178 
2179 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2180 ///
2181 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2182                                       const char *ExtraCode, raw_ostream &O) {
2183   if (ExtraCode && ExtraCode[0]) {
2184     if (ExtraCode[1] != 0)
2185       return true; // Unknown modifier.
2186 
2187     switch (ExtraCode[0]) {
2188     default:
2189       // See if this is a generic print operand
2190       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2191     case 'r':
2192       break;
2193     }
2194   }
2195 
2196   printOperand(MI, OpNo, O);
2197 
2198   return false;
2199 }
2200 
2201 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2202                                             unsigned OpNo,
2203                                             const char *ExtraCode,
2204                                             raw_ostream &O) {
2205   if (ExtraCode && ExtraCode[0])
2206     return true; // Unknown modifier
2207 
2208   O << '[';
2209   printMemOperand(MI, OpNo, O);
2210   O << ']';
2211 
2212   return false;
2213 }
2214 
2215 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2216                                    raw_ostream &O) {
2217   const MachineOperand &MO = MI->getOperand(OpNum);
2218   switch (MO.getType()) {
2219   case MachineOperand::MO_Register:
2220     if (MO.getReg().isPhysical()) {
2221       if (MO.getReg() == NVPTX::VRDepot)
2222         O << DEPOTNAME << getFunctionNumber();
2223       else
2224         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2225     } else {
2226       emitVirtualRegister(MO.getReg(), O);
2227     }
2228     break;
2229 
2230   case MachineOperand::MO_Immediate:
2231     O << MO.getImm();
2232     break;
2233 
2234   case MachineOperand::MO_FPImmediate:
2235     printFPConstant(MO.getFPImm(), O);
2236     break;
2237 
2238   case MachineOperand::MO_GlobalAddress:
2239     PrintSymbolOperand(MO, O);
2240     break;
2241 
2242   case MachineOperand::MO_MachineBasicBlock:
2243     MO.getMBB()->getSymbol()->print(O, MAI);
2244     break;
2245 
2246   default:
2247     llvm_unreachable("Operand type not supported.");
2248   }
2249 }
2250 
2251 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2252                                       raw_ostream &O, const char *Modifier) {
2253   printOperand(MI, OpNum, O);
2254 
2255   if (Modifier && strcmp(Modifier, "add") == 0) {
2256     O << ", ";
2257     printOperand(MI, OpNum + 1, O);
2258   } else {
2259     if (MI->getOperand(OpNum + 1).isImm() &&
2260         MI->getOperand(OpNum + 1).getImm() == 0)
2261       return; // don't print ',0' or '+0'
2262     O << "+";
2263     printOperand(MI, OpNum + 1, O);
2264   }
2265 }
2266 
2267 // Force static initialization.
2268 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2269   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2270   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2271 }
2272