xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (revision 2a0c0aea42092f89c2a5345991e6e3ce4cbef99a)
1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Triple.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Analysis/ConstantFolding.h"
39 #include "llvm/CodeGen/Analysis.h"
40 #include "llvm/CodeGen/MachineBasicBlock.h"
41 #include "llvm/CodeGen/MachineFrameInfo.h"
42 #include "llvm/CodeGen/MachineFunction.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineLoopInfo.h"
45 #include "llvm/CodeGen/MachineModuleInfo.h"
46 #include "llvm/CodeGen/MachineOperand.h"
47 #include "llvm/CodeGen/MachineRegisterInfo.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/MachineValueType.h"
79 #include "llvm/Support/NativeFormatting.h"
80 #include "llvm/Support/Path.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <new>
89 #include <string>
90 #include <utility>
91 #include <vector>
92 
93 using namespace llvm;
94 
95 #define DEPOTNAME "__local_depot"
96 
97 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
98 /// depends.
99 static void
100 DiscoverDependentGlobals(const Value *V,
101                          DenseSet<const GlobalVariable *> &Globals) {
102   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
103     Globals.insert(GV);
104   else {
105     if (const User *U = dyn_cast<User>(V)) {
106       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
107         DiscoverDependentGlobals(U->getOperand(i), Globals);
108       }
109     }
110   }
111 }
112 
113 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
114 /// instances to be emitted, but only after any dependents have been added
115 /// first.s
116 static void
117 VisitGlobalVariableForEmission(const GlobalVariable *GV,
118                                SmallVectorImpl<const GlobalVariable *> &Order,
119                                DenseSet<const GlobalVariable *> &Visited,
120                                DenseSet<const GlobalVariable *> &Visiting) {
121   // Have we already visited this one?
122   if (Visited.count(GV))
123     return;
124 
125   // Do we have a circular dependency?
126   if (!Visiting.insert(GV).second)
127     report_fatal_error("Circular dependency found in global variable set");
128 
129   // Make sure we visit all dependents first
130   DenseSet<const GlobalVariable *> Others;
131   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
132     DiscoverDependentGlobals(GV->getOperand(i), Others);
133 
134   for (const GlobalVariable *GV : Others)
135     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
136 
137   // Now we can visit ourself
138   Order.push_back(GV);
139   Visited.insert(GV);
140   Visiting.erase(GV);
141 }
142 
143 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
144   NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
145                                         getSubtargetInfo().getFeatureBits());
146 
147   MCInst Inst;
148   lowerToMCInst(MI, Inst);
149   EmitToStreamer(*OutStreamer, Inst);
150 }
151 
152 // Handle symbol backtracking for targets that do not support image handles
153 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154                                            unsigned OpNo, MCOperand &MCOp) {
155   const MachineOperand &MO = MI->getOperand(OpNo);
156   const MCInstrDesc &MCID = MI->getDesc();
157 
158   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159     // This is a texture fetch, so operand 4 is a texref and operand 5 is
160     // a samplerref
161     if (OpNo == 4 && MO.isImm()) {
162       lowerImageHandleSymbol(MO.getImm(), MCOp);
163       return true;
164     }
165     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166       lowerImageHandleSymbol(MO.getImm(), MCOp);
167       return true;
168     }
169 
170     return false;
171   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
172     unsigned VecSize =
173       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174 
175     // For a surface load of vector size N, the Nth operand will be the surfref
176     if (OpNo == VecSize && MO.isImm()) {
177       lowerImageHandleSymbol(MO.getImm(), MCOp);
178       return true;
179     }
180 
181     return false;
182   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183     // This is a surface store, so operand 0 is a surfref
184     if (OpNo == 0 && MO.isImm()) {
185       lowerImageHandleSymbol(MO.getImm(), MCOp);
186       return true;
187     }
188 
189     return false;
190   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191     // This is a query, so operand 1 is a surfref/texref
192     if (OpNo == 1 && MO.isImm()) {
193       lowerImageHandleSymbol(MO.getImm(), MCOp);
194       return true;
195     }
196 
197     return false;
198   }
199 
200   return false;
201 }
202 
203 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204   // Ewwww
205   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
206   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
207   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208   const char *Sym = MFI->getImageHandleSymbol(Index);
209   StringRef SymName = nvTM.getStrPool().save(Sym);
210   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
211 }
212 
213 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
214   OutMI.setOpcode(MI->getOpcode());
215   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
216   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
217     const MachineOperand &MO = MI->getOperand(0);
218     OutMI.addOperand(GetSymbolRef(
219       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
220     return;
221   }
222 
223   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
224   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
225     const MachineOperand &MO = MI->getOperand(i);
226 
227     MCOperand MCOp;
228     if (!STI.hasImageHandles()) {
229       if (lowerImageHandleOperand(MI, i, MCOp)) {
230         OutMI.addOperand(MCOp);
231         continue;
232       }
233     }
234 
235     if (lowerOperand(MO, MCOp))
236       OutMI.addOperand(MCOp);
237   }
238 }
239 
240 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
241                                    MCOperand &MCOp) {
242   switch (MO.getType()) {
243   default: llvm_unreachable("unknown operand type");
244   case MachineOperand::MO_Register:
245     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
246     break;
247   case MachineOperand::MO_Immediate:
248     MCOp = MCOperand::createImm(MO.getImm());
249     break;
250   case MachineOperand::MO_MachineBasicBlock:
251     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
252         MO.getMBB()->getSymbol(), OutContext));
253     break;
254   case MachineOperand::MO_ExternalSymbol:
255     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
256     break;
257   case MachineOperand::MO_GlobalAddress:
258     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
259     break;
260   case MachineOperand::MO_FPImmediate: {
261     const ConstantFP *Cnt = MO.getFPImm();
262     const APFloat &Val = Cnt->getValueAPF();
263 
264     switch (Cnt->getType()->getTypeID()) {
265     default: report_fatal_error("Unsupported FP type"); break;
266     case Type::HalfTyID:
267       MCOp = MCOperand::createExpr(
268         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
269       break;
270     case Type::FloatTyID:
271       MCOp = MCOperand::createExpr(
272         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
273       break;
274     case Type::DoubleTyID:
275       MCOp = MCOperand::createExpr(
276         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
277       break;
278     }
279     break;
280   }
281   }
282   return true;
283 }
284 
285 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
286   if (Register::isVirtualRegister(Reg)) {
287     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
288 
289     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
290     unsigned RegNum = RegMap[Reg];
291 
292     // Encode the register class in the upper 4 bits
293     // Must be kept in sync with NVPTXInstPrinter::printRegName
294     unsigned Ret = 0;
295     if (RC == &NVPTX::Int1RegsRegClass) {
296       Ret = (1 << 28);
297     } else if (RC == &NVPTX::Int16RegsRegClass) {
298       Ret = (2 << 28);
299     } else if (RC == &NVPTX::Int32RegsRegClass) {
300       Ret = (3 << 28);
301     } else if (RC == &NVPTX::Int64RegsRegClass) {
302       Ret = (4 << 28);
303     } else if (RC == &NVPTX::Float32RegsRegClass) {
304       Ret = (5 << 28);
305     } else if (RC == &NVPTX::Float64RegsRegClass) {
306       Ret = (6 << 28);
307     } else if (RC == &NVPTX::Float16RegsRegClass) {
308       Ret = (7 << 28);
309     } else if (RC == &NVPTX::Float16x2RegsRegClass) {
310       Ret = (8 << 28);
311     } else {
312       report_fatal_error("Bad register class");
313     }
314 
315     // Insert the vreg number
316     Ret |= (RegNum & 0x0FFFFFFF);
317     return Ret;
318   } else {
319     // Some special-use registers are actually physical registers.
320     // Encode this as the register class ID of 0 and the real register ID.
321     return Reg & 0x0FFFFFFF;
322   }
323 }
324 
325 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
326   const MCExpr *Expr;
327   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
328                                  OutContext);
329   return MCOperand::createExpr(Expr);
330 }
331 
332 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
333   const DataLayout &DL = getDataLayout();
334   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
335   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
336 
337   Type *Ty = F->getReturnType();
338 
339   bool isABI = (STI.getSmVersion() >= 20);
340 
341   if (Ty->getTypeID() == Type::VoidTyID)
342     return;
343 
344   O << " (";
345 
346   if (isABI) {
347     if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
348       unsigned size = 0;
349       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
350         size = ITy->getBitWidth();
351       } else {
352         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
353         size = Ty->getPrimitiveSizeInBits();
354       }
355       // PTX ABI requires all scalar return values to be at least 32
356       // bits in size.  fp16 normally uses .b16 as its storage type in
357       // PTX, so its size must be adjusted here, too.
358       size = promoteScalarArgumentSize(size);
359 
360       O << ".param .b" << size << " func_retval0";
361     } else if (isa<PointerType>(Ty)) {
362       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
363         << " func_retval0";
364     } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
365       unsigned totalsz = DL.getTypeAllocSize(Ty);
366       unsigned retAlignment = 0;
367       if (!getAlign(*F, 0, retAlignment))
368         retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
369       O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
370         << "]";
371     } else
372       llvm_unreachable("Unknown return type");
373   } else {
374     SmallVector<EVT, 16> vtparts;
375     ComputeValueVTs(*TLI, DL, Ty, vtparts);
376     unsigned idx = 0;
377     for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
378       unsigned elems = 1;
379       EVT elemtype = vtparts[i];
380       if (vtparts[i].isVector()) {
381         elems = vtparts[i].getVectorNumElements();
382         elemtype = vtparts[i].getVectorElementType();
383       }
384 
385       for (unsigned j = 0, je = elems; j != je; ++j) {
386         unsigned sz = elemtype.getSizeInBits();
387         if (elemtype.isInteger())
388           sz = promoteScalarArgumentSize(sz);
389         O << ".reg .b" << sz << " func_retval" << idx;
390         if (j < je - 1)
391           O << ", ";
392         ++idx;
393       }
394       if (i < e - 1)
395         O << ", ";
396     }
397   }
398   O << ") ";
399 }
400 
401 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
402                                         raw_ostream &O) {
403   const Function &F = MF.getFunction();
404   printReturnValStr(&F, O);
405 }
406 
407 // Return true if MBB is the header of a loop marked with
408 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
409 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
410     const MachineBasicBlock &MBB) const {
411   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
412   // We insert .pragma "nounroll" only to the loop header.
413   if (!LI.isLoopHeader(&MBB))
414     return false;
415 
416   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
417   // we iterate through each back edge of the loop with header MBB, and check
418   // whether its metadata contains llvm.loop.unroll.disable.
419   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
420     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
421       // Edges from other loops to MBB are not back edges.
422       continue;
423     }
424     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
425       if (MDNode *LoopID =
426               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
427         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
428           return true;
429         if (MDNode *UnrollCountMD =
430                 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
431           if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
432                   ->getZExtValue() == 1)
433             return true;
434         }
435       }
436     }
437   }
438   return false;
439 }
440 
441 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
442   AsmPrinter::emitBasicBlockStart(MBB);
443   if (isLoopHeaderOfNoUnroll(MBB))
444     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
445 }
446 
447 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
448   SmallString<128> Str;
449   raw_svector_ostream O(Str);
450 
451   if (!GlobalsEmitted) {
452     emitGlobals(*MF->getFunction().getParent());
453     GlobalsEmitted = true;
454   }
455 
456   // Set up
457   MRI = &MF->getRegInfo();
458   F = &MF->getFunction();
459   emitLinkageDirective(F, O);
460   if (isKernelFunction(*F))
461     O << ".entry ";
462   else {
463     O << ".func ";
464     printReturnValStr(*MF, O);
465   }
466 
467   CurrentFnSym->print(O, MAI);
468 
469   emitFunctionParamList(*MF, O);
470 
471   if (isKernelFunction(*F))
472     emitKernelFunctionDirectives(*F, O);
473 
474   if (shouldEmitPTXNoReturn(F, TM))
475     O << ".noreturn";
476 
477   OutStreamer->emitRawText(O.str());
478 
479   VRegMapping.clear();
480   // Emit open brace for function body.
481   OutStreamer->emitRawText(StringRef("{\n"));
482   setAndEmitFunctionVirtualRegisters(*MF);
483   // Emit initial .loc debug directive for correct relocation symbol data.
484   if (MMI && MMI->hasDebugInfo())
485     emitInitialRawDwarfLocDirective(*MF);
486 }
487 
488 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
489   bool Result = AsmPrinter::runOnMachineFunction(F);
490   // Emit closing brace for the body of function F.
491   // The closing brace must be emitted here because we need to emit additional
492   // debug labels/data after the last basic block.
493   // We need to emit the closing brace here because we don't have function that
494   // finished emission of the function body.
495   OutStreamer->emitRawText(StringRef("}\n"));
496   return Result;
497 }
498 
499 void NVPTXAsmPrinter::emitFunctionBodyStart() {
500   SmallString<128> Str;
501   raw_svector_ostream O(Str);
502   emitDemotedVars(&MF->getFunction(), O);
503   OutStreamer->emitRawText(O.str());
504 }
505 
506 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
507   VRegMapping.clear();
508 }
509 
510 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
511     SmallString<128> Str;
512     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
513     return OutContext.getOrCreateSymbol(Str);
514 }
515 
516 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
517   Register RegNo = MI->getOperand(0).getReg();
518   if (RegNo.isVirtual()) {
519     OutStreamer->AddComment(Twine("implicit-def: ") +
520                             getVirtualRegisterName(RegNo));
521   } else {
522     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
523     OutStreamer->AddComment(Twine("implicit-def: ") +
524                             STI.getRegisterInfo()->getName(RegNo));
525   }
526   OutStreamer->addBlankLine();
527 }
528 
529 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
530                                                    raw_ostream &O) const {
531   // If the NVVM IR has some of reqntid* specified, then output
532   // the reqntid directive, and set the unspecified ones to 1.
533   // If none of reqntid* is specified, don't output reqntid directive.
534   unsigned reqntidx, reqntidy, reqntidz;
535   bool specified = false;
536   if (!getReqNTIDx(F, reqntidx))
537     reqntidx = 1;
538   else
539     specified = true;
540   if (!getReqNTIDy(F, reqntidy))
541     reqntidy = 1;
542   else
543     specified = true;
544   if (!getReqNTIDz(F, reqntidz))
545     reqntidz = 1;
546   else
547     specified = true;
548 
549   if (specified)
550     O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
551       << "\n";
552 
553   // If the NVVM IR has some of maxntid* specified, then output
554   // the maxntid directive, and set the unspecified ones to 1.
555   // If none of maxntid* is specified, don't output maxntid directive.
556   unsigned maxntidx, maxntidy, maxntidz;
557   specified = false;
558   if (!getMaxNTIDx(F, maxntidx))
559     maxntidx = 1;
560   else
561     specified = true;
562   if (!getMaxNTIDy(F, maxntidy))
563     maxntidy = 1;
564   else
565     specified = true;
566   if (!getMaxNTIDz(F, maxntidz))
567     maxntidz = 1;
568   else
569     specified = true;
570 
571   if (specified)
572     O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
573       << "\n";
574 
575   unsigned mincta;
576   if (getMinCTASm(F, mincta))
577     O << ".minnctapersm " << mincta << "\n";
578 
579   unsigned maxnreg;
580   if (getMaxNReg(F, maxnreg))
581     O << ".maxnreg " << maxnreg << "\n";
582 }
583 
584 std::string
585 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
586   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
587 
588   std::string Name;
589   raw_string_ostream NameStr(Name);
590 
591   VRegRCMap::const_iterator I = VRegMapping.find(RC);
592   assert(I != VRegMapping.end() && "Bad register class");
593   const DenseMap<unsigned, unsigned> &RegMap = I->second;
594 
595   VRegMap::const_iterator VI = RegMap.find(Reg);
596   assert(VI != RegMap.end() && "Bad virtual register");
597   unsigned MappedVR = VI->second;
598 
599   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
600 
601   NameStr.flush();
602   return Name;
603 }
604 
605 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
606                                           raw_ostream &O) {
607   O << getVirtualRegisterName(vr);
608 }
609 
610 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
611   emitLinkageDirective(F, O);
612   if (isKernelFunction(*F))
613     O << ".entry ";
614   else
615     O << ".func ";
616   printReturnValStr(F, O);
617   getSymbol(F)->print(O, MAI);
618   O << "\n";
619   emitFunctionParamList(F, O);
620   if (shouldEmitPTXNoReturn(F, TM))
621     O << ".noreturn";
622   O << ";\n";
623 }
624 
625 static bool usedInGlobalVarDef(const Constant *C) {
626   if (!C)
627     return false;
628 
629   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
630     return GV->getName() != "llvm.used";
631   }
632 
633   for (const User *U : C->users())
634     if (const Constant *C = dyn_cast<Constant>(U))
635       if (usedInGlobalVarDef(C))
636         return true;
637 
638   return false;
639 }
640 
641 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
642   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
643     if (othergv->getName() == "llvm.used")
644       return true;
645   }
646 
647   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
648     if (instr->getParent() && instr->getParent()->getParent()) {
649       const Function *curFunc = instr->getParent()->getParent();
650       if (oneFunc && (curFunc != oneFunc))
651         return false;
652       oneFunc = curFunc;
653       return true;
654     } else
655       return false;
656   }
657 
658   for (const User *UU : U->users())
659     if (!usedInOneFunc(UU, oneFunc))
660       return false;
661 
662   return true;
663 }
664 
665 /* Find out if a global variable can be demoted to local scope.
666  * Currently, this is valid for CUDA shared variables, which have local
667  * scope and global lifetime. So the conditions to check are :
668  * 1. Is the global variable in shared address space?
669  * 2. Does it have internal linkage?
670  * 3. Is the global variable referenced only in one function?
671  */
672 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
673   if (!gv->hasInternalLinkage())
674     return false;
675   PointerType *Pty = gv->getType();
676   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
677     return false;
678 
679   const Function *oneFunc = nullptr;
680 
681   bool flag = usedInOneFunc(gv, oneFunc);
682   if (!flag)
683     return false;
684   if (!oneFunc)
685     return false;
686   f = oneFunc;
687   return true;
688 }
689 
690 static bool useFuncSeen(const Constant *C,
691                         DenseMap<const Function *, bool> &seenMap) {
692   for (const User *U : C->users()) {
693     if (const Constant *cu = dyn_cast<Constant>(U)) {
694       if (useFuncSeen(cu, seenMap))
695         return true;
696     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
697       const BasicBlock *bb = I->getParent();
698       if (!bb)
699         continue;
700       const Function *caller = bb->getParent();
701       if (!caller)
702         continue;
703       if (seenMap.find(caller) != seenMap.end())
704         return true;
705     }
706   }
707   return false;
708 }
709 
710 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
711   DenseMap<const Function *, bool> seenMap;
712   for (const Function &F : M) {
713     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
714       emitDeclaration(&F, O);
715       continue;
716     }
717 
718     if (F.isDeclaration()) {
719       if (F.use_empty())
720         continue;
721       if (F.getIntrinsicID())
722         continue;
723       emitDeclaration(&F, O);
724       continue;
725     }
726     for (const User *U : F.users()) {
727       if (const Constant *C = dyn_cast<Constant>(U)) {
728         if (usedInGlobalVarDef(C)) {
729           // The use is in the initialization of a global variable
730           // that is a function pointer, so print a declaration
731           // for the original function
732           emitDeclaration(&F, O);
733           break;
734         }
735         // Emit a declaration of this function if the function that
736         // uses this constant expr has already been seen.
737         if (useFuncSeen(C, seenMap)) {
738           emitDeclaration(&F, O);
739           break;
740         }
741       }
742 
743       if (!isa<Instruction>(U))
744         continue;
745       const Instruction *instr = cast<Instruction>(U);
746       const BasicBlock *bb = instr->getParent();
747       if (!bb)
748         continue;
749       const Function *caller = bb->getParent();
750       if (!caller)
751         continue;
752 
753       // If a caller has already been seen, then the caller is
754       // appearing in the module before the callee. so print out
755       // a declaration for the callee.
756       if (seenMap.find(caller) != seenMap.end()) {
757         emitDeclaration(&F, O);
758         break;
759       }
760     }
761     seenMap[&F] = true;
762   }
763 }
764 
765 static bool isEmptyXXStructor(GlobalVariable *GV) {
766   if (!GV) return true;
767   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
768   if (!InitList) return true;  // Not an array; we don't know how to parse.
769   return InitList->getNumOperands() == 0;
770 }
771 
772 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
773   // Construct a default subtarget off of the TargetMachine defaults. The
774   // rest of NVPTX isn't friendly to change subtargets per function and
775   // so the default TargetMachine will have all of the options.
776   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
777   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
778   SmallString<128> Str1;
779   raw_svector_ostream OS1(Str1);
780 
781   // Emit header before any dwarf directives are emitted below.
782   emitHeader(M, OS1, *STI);
783   OutStreamer->emitRawText(OS1.str());
784 }
785 
786 bool NVPTXAsmPrinter::doInitialization(Module &M) {
787   if (M.alias_size()) {
788     report_fatal_error("Module has aliases, which NVPTX does not support.");
789     return true; // error
790   }
791   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
792     report_fatal_error(
793         "Module has a nontrivial global ctor, which NVPTX does not support.");
794     return true;  // error
795   }
796   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
797     report_fatal_error(
798         "Module has a nontrivial global dtor, which NVPTX does not support.");
799     return true;  // error
800   }
801 
802   // We need to call the parent's one explicitly.
803   bool Result = AsmPrinter::doInitialization(M);
804 
805   GlobalsEmitted = false;
806 
807   return Result;
808 }
809 
810 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
811   SmallString<128> Str2;
812   raw_svector_ostream OS2(Str2);
813 
814   emitDeclarations(M, OS2);
815 
816   // As ptxas does not support forward references of globals, we need to first
817   // sort the list of module-level globals in def-use order. We visit each
818   // global variable in order, and ensure that we emit it *after* its dependent
819   // globals. We use a little extra memory maintaining both a set and a list to
820   // have fast searches while maintaining a strict ordering.
821   SmallVector<const GlobalVariable *, 8> Globals;
822   DenseSet<const GlobalVariable *> GVVisited;
823   DenseSet<const GlobalVariable *> GVVisiting;
824 
825   // Visit each global variable, in order
826   for (const GlobalVariable &I : M.globals())
827     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
828 
829   assert(GVVisited.size() == M.getGlobalList().size() &&
830          "Missed a global variable");
831   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
832 
833   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
834   const NVPTXSubtarget &STI =
835       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
836 
837   // Print out module-level global variables in proper order
838   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
839     printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
840 
841   OS2 << '\n';
842 
843   OutStreamer->emitRawText(OS2.str());
844 }
845 
846 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
847                                  const NVPTXSubtarget &STI) {
848   O << "//\n";
849   O << "// Generated by LLVM NVPTX Back-End\n";
850   O << "//\n";
851   O << "\n";
852 
853   unsigned PTXVersion = STI.getPTXVersion();
854   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
855 
856   O << ".target ";
857   O << STI.getTargetName();
858 
859   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
860   if (NTM.getDrvInterface() == NVPTX::NVCL)
861     O << ", texmode_independent";
862 
863   bool HasFullDebugInfo = false;
864   for (DICompileUnit *CU : M.debug_compile_units()) {
865     switch(CU->getEmissionKind()) {
866     case DICompileUnit::NoDebug:
867     case DICompileUnit::DebugDirectivesOnly:
868       break;
869     case DICompileUnit::LineTablesOnly:
870     case DICompileUnit::FullDebug:
871       HasFullDebugInfo = true;
872       break;
873     }
874     if (HasFullDebugInfo)
875       break;
876   }
877   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
878     O << ", debug";
879 
880   O << "\n";
881 
882   O << ".address_size ";
883   if (NTM.is64Bit())
884     O << "64";
885   else
886     O << "32";
887   O << "\n";
888 
889   O << "\n";
890 }
891 
892 bool NVPTXAsmPrinter::doFinalization(Module &M) {
893   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
894 
895   // If we did not emit any functions, then the global declarations have not
896   // yet been emitted.
897   if (!GlobalsEmitted) {
898     emitGlobals(M);
899     GlobalsEmitted = true;
900   }
901 
902   // call doFinalization
903   bool ret = AsmPrinter::doFinalization(M);
904 
905   clearAnnotationCache(&M);
906 
907   auto *TS =
908       static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
909   // Close the last emitted section
910   if (HasDebugInfo) {
911     TS->closeLastSection();
912     // Emit empty .debug_loc section for better support of the empty files.
913     OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
914   }
915 
916   // Output last DWARF .file directives, if any.
917   TS->outputDwarfFileDirectives();
918 
919   return ret;
920 }
921 
922 // This function emits appropriate linkage directives for
923 // functions and global variables.
924 //
925 // extern function declaration            -> .extern
926 // extern function definition             -> .visible
927 // external global variable with init     -> .visible
928 // external without init                  -> .extern
929 // appending                              -> not allowed, assert.
930 // for any linkage other than
931 // internal, private, linker_private,
932 // linker_private_weak, linker_private_weak_def_auto,
933 // we emit                                -> .weak.
934 
935 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
936                                            raw_ostream &O) {
937   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
938     if (V->hasExternalLinkage()) {
939       if (isa<GlobalVariable>(V)) {
940         const GlobalVariable *GVar = cast<GlobalVariable>(V);
941         if (GVar) {
942           if (GVar->hasInitializer())
943             O << ".visible ";
944           else
945             O << ".extern ";
946         }
947       } else if (V->isDeclaration())
948         O << ".extern ";
949       else
950         O << ".visible ";
951     } else if (V->hasAppendingLinkage()) {
952       std::string msg;
953       msg.append("Error: ");
954       msg.append("Symbol ");
955       if (V->hasName())
956         msg.append(std::string(V->getName()));
957       msg.append("has unsupported appending linkage type");
958       llvm_unreachable(msg.c_str());
959     } else if (!V->hasInternalLinkage() &&
960                !V->hasPrivateLinkage()) {
961       O << ".weak ";
962     }
963   }
964 }
965 
966 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
967                                          raw_ostream &O, bool processDemoted,
968                                          const NVPTXSubtarget &STI) {
969   // Skip meta data
970   if (GVar->hasSection()) {
971     if (GVar->getSection() == "llvm.metadata")
972       return;
973   }
974 
975   // Skip LLVM intrinsic global variables
976   if (GVar->getName().startswith("llvm.") ||
977       GVar->getName().startswith("nvvm."))
978     return;
979 
980   const DataLayout &DL = getDataLayout();
981 
982   // GlobalVariables are always constant pointers themselves.
983   PointerType *PTy = GVar->getType();
984   Type *ETy = GVar->getValueType();
985 
986   if (GVar->hasExternalLinkage()) {
987     if (GVar->hasInitializer())
988       O << ".visible ";
989     else
990       O << ".extern ";
991   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
992              GVar->hasAvailableExternallyLinkage() ||
993              GVar->hasCommonLinkage()) {
994     O << ".weak ";
995   }
996 
997   if (isTexture(*GVar)) {
998     O << ".global .texref " << getTextureName(*GVar) << ";\n";
999     return;
1000   }
1001 
1002   if (isSurface(*GVar)) {
1003     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1004     return;
1005   }
1006 
1007   if (GVar->isDeclaration()) {
1008     // (extern) declarations, no definition or initializer
1009     // Currently the only known declaration is for an automatic __local
1010     // (.shared) promoted to global.
1011     emitPTXGlobalVariable(GVar, O, STI);
1012     O << ";\n";
1013     return;
1014   }
1015 
1016   if (isSampler(*GVar)) {
1017     O << ".global .samplerref " << getSamplerName(*GVar);
1018 
1019     const Constant *Initializer = nullptr;
1020     if (GVar->hasInitializer())
1021       Initializer = GVar->getInitializer();
1022     const ConstantInt *CI = nullptr;
1023     if (Initializer)
1024       CI = dyn_cast<ConstantInt>(Initializer);
1025     if (CI) {
1026       unsigned sample = CI->getZExtValue();
1027 
1028       O << " = { ";
1029 
1030       for (int i = 0,
1031                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1032            i < 3; i++) {
1033         O << "addr_mode_" << i << " = ";
1034         switch (addr) {
1035         case 0:
1036           O << "wrap";
1037           break;
1038         case 1:
1039           O << "clamp_to_border";
1040           break;
1041         case 2:
1042           O << "clamp_to_edge";
1043           break;
1044         case 3:
1045           O << "wrap";
1046           break;
1047         case 4:
1048           O << "mirror";
1049           break;
1050         }
1051         O << ", ";
1052       }
1053       O << "filter_mode = ";
1054       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1055       case 0:
1056         O << "nearest";
1057         break;
1058       case 1:
1059         O << "linear";
1060         break;
1061       case 2:
1062         llvm_unreachable("Anisotropic filtering is not supported");
1063       default:
1064         O << "nearest";
1065         break;
1066       }
1067       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1068         O << ", force_unnormalized_coords = 1";
1069       }
1070       O << " }";
1071     }
1072 
1073     O << ";\n";
1074     return;
1075   }
1076 
1077   if (GVar->hasPrivateLinkage()) {
1078     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1079       return;
1080 
1081     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1082     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1083       return;
1084     if (GVar->use_empty())
1085       return;
1086   }
1087 
1088   const Function *demotedFunc = nullptr;
1089   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1090     O << "// " << GVar->getName() << " has been demoted\n";
1091     if (localDecls.find(demotedFunc) != localDecls.end())
1092       localDecls[demotedFunc].push_back(GVar);
1093     else {
1094       std::vector<const GlobalVariable *> temp;
1095       temp.push_back(GVar);
1096       localDecls[demotedFunc] = temp;
1097     }
1098     return;
1099   }
1100 
1101   O << ".";
1102   emitPTXAddressSpace(PTy->getAddressSpace(), O);
1103 
1104   if (isManaged(*GVar)) {
1105     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1106       report_fatal_error(
1107           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1108     }
1109     O << " .attribute(.managed)";
1110   }
1111 
1112   if (MaybeAlign A = GVar->getAlign())
1113     O << " .align " << A->value();
1114   else
1115     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1116 
1117   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1118       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1119     O << " .";
1120     // Special case: ABI requires that we use .u8 for predicates
1121     if (ETy->isIntegerTy(1))
1122       O << "u8";
1123     else
1124       O << getPTXFundamentalTypeStr(ETy, false);
1125     O << " ";
1126     getSymbol(GVar)->print(O, MAI);
1127 
1128     // Ptx allows variable initilization only for constant and global state
1129     // spaces.
1130     if (GVar->hasInitializer()) {
1131       if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1132           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1133         const Constant *Initializer = GVar->getInitializer();
1134         // 'undef' is treated as there is no value specified.
1135         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1136           O << " = ";
1137           printScalarConstant(Initializer, O);
1138         }
1139       } else {
1140         // The frontend adds zero-initializer to device and constant variables
1141         // that don't have an initial value, and UndefValue to shared
1142         // variables, so skip warning for this case.
1143         if (!GVar->getInitializer()->isNullValue() &&
1144             !isa<UndefValue>(GVar->getInitializer())) {
1145           report_fatal_error("initial value of '" + GVar->getName() +
1146                              "' is not allowed in addrspace(" +
1147                              Twine(PTy->getAddressSpace()) + ")");
1148         }
1149       }
1150     }
1151   } else {
1152     unsigned int ElementSize = 0;
1153 
1154     // Although PTX has direct support for struct type and array type and
1155     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1156     // targets that support these high level field accesses. Structs, arrays
1157     // and vectors are lowered into arrays of bytes.
1158     switch (ETy->getTypeID()) {
1159     case Type::IntegerTyID: // Integers larger than 64 bits
1160     case Type::StructTyID:
1161     case Type::ArrayTyID:
1162     case Type::FixedVectorTyID:
1163       ElementSize = DL.getTypeStoreSize(ETy);
1164       // Ptx allows variable initilization only for constant and
1165       // global state spaces.
1166       if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1167            (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1168           GVar->hasInitializer()) {
1169         const Constant *Initializer = GVar->getInitializer();
1170         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1171           AggBuffer aggBuffer(ElementSize, *this);
1172           bufferAggregateConstant(Initializer, &aggBuffer);
1173           if (aggBuffer.numSymbols()) {
1174             unsigned int ptrSize = MAI->getCodePointerSize();
1175             if (ElementSize % ptrSize ||
1176                 !aggBuffer.allSymbolsAligned(ptrSize)) {
1177               // Print in bytes and use the mask() operator for pointers.
1178               if (!STI.hasMaskOperator())
1179                 report_fatal_error(
1180                     "initialized packed aggregate with pointers '" +
1181                     GVar->getName() +
1182                     "' requires at least PTX ISA version 7.1");
1183               O << " .u8 ";
1184               getSymbol(GVar)->print(O, MAI);
1185               O << "[" << ElementSize << "] = {";
1186               aggBuffer.printBytes(O);
1187               O << "}";
1188             } else {
1189               O << " .u" << ptrSize * 8 << " ";
1190               getSymbol(GVar)->print(O, MAI);
1191               O << "[" << ElementSize / ptrSize << "] = {";
1192               aggBuffer.printWords(O);
1193               O << "}";
1194             }
1195           } else {
1196             O << " .b8 ";
1197             getSymbol(GVar)->print(O, MAI);
1198             O << "[" << ElementSize << "] = {";
1199             aggBuffer.printBytes(O);
1200             O << "}";
1201           }
1202         } else {
1203           O << " .b8 ";
1204           getSymbol(GVar)->print(O, MAI);
1205           if (ElementSize) {
1206             O << "[";
1207             O << ElementSize;
1208             O << "]";
1209           }
1210         }
1211       } else {
1212         O << " .b8 ";
1213         getSymbol(GVar)->print(O, MAI);
1214         if (ElementSize) {
1215           O << "[";
1216           O << ElementSize;
1217           O << "]";
1218         }
1219       }
1220       break;
1221     default:
1222       llvm_unreachable("type not supported yet");
1223     }
1224   }
1225   O << ";\n";
1226 }
1227 
1228 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1229   const Value *v = Symbols[nSym];
1230   const Value *v0 = SymbolsBeforeStripping[nSym];
1231   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1232     MCSymbol *Name = AP.getSymbol(GVar);
1233     PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1234     // Is v0 a generic pointer?
1235     bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1236     if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1237       os << "generic(";
1238       Name->print(os, AP.MAI);
1239       os << ")";
1240     } else {
1241       Name->print(os, AP.MAI);
1242     }
1243   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1244     const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1245     AP.printMCExpr(*Expr, os);
1246   } else
1247     llvm_unreachable("symbol type unknown");
1248 }
1249 
1250 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1251   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1252   symbolPosInBuffer.push_back(size);
1253   unsigned int nSym = 0;
1254   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1255   for (unsigned int pos = 0; pos < size;) {
1256     if (pos)
1257       os << ", ";
1258     if (pos != nextSymbolPos) {
1259       os << (unsigned int)buffer[pos];
1260       ++pos;
1261       continue;
1262     }
1263     // Generate a per-byte mask() operator for the symbol, which looks like:
1264     //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1265     // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1266     std::string symText;
1267     llvm::raw_string_ostream oss(symText);
1268     printSymbol(nSym, oss);
1269     for (unsigned i = 0; i < ptrSize; ++i) {
1270       if (i)
1271         os << ", ";
1272       llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1273       os << "(" << symText << ")";
1274     }
1275     pos += ptrSize;
1276     nextSymbolPos = symbolPosInBuffer[++nSym];
1277     assert(nextSymbolPos >= pos);
1278   }
1279 }
1280 
1281 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1282   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1283   symbolPosInBuffer.push_back(size);
1284   unsigned int nSym = 0;
1285   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1286   assert(nextSymbolPos % ptrSize == 0);
1287   for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1288     if (pos)
1289       os << ", ";
1290     if (pos == nextSymbolPos) {
1291       printSymbol(nSym, os);
1292       nextSymbolPos = symbolPosInBuffer[++nSym];
1293       assert(nextSymbolPos % ptrSize == 0);
1294       assert(nextSymbolPos >= pos + ptrSize);
1295     } else if (ptrSize == 4)
1296       os << support::endian::read32le(&buffer[pos]);
1297     else
1298       os << support::endian::read64le(&buffer[pos]);
1299   }
1300 }
1301 
1302 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1303   if (localDecls.find(f) == localDecls.end())
1304     return;
1305 
1306   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1307 
1308   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1309   const NVPTXSubtarget &STI =
1310       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1311 
1312   for (const GlobalVariable *GV : gvars) {
1313     O << "\t// demoted variable\n\t";
1314     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1315   }
1316 }
1317 
1318 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1319                                           raw_ostream &O) const {
1320   switch (AddressSpace) {
1321   case ADDRESS_SPACE_LOCAL:
1322     O << "local";
1323     break;
1324   case ADDRESS_SPACE_GLOBAL:
1325     O << "global";
1326     break;
1327   case ADDRESS_SPACE_CONST:
1328     O << "const";
1329     break;
1330   case ADDRESS_SPACE_SHARED:
1331     O << "shared";
1332     break;
1333   default:
1334     report_fatal_error("Bad address space found while emitting PTX: " +
1335                        llvm::Twine(AddressSpace));
1336     break;
1337   }
1338 }
1339 
1340 std::string
1341 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1342   switch (Ty->getTypeID()) {
1343   case Type::IntegerTyID: {
1344     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1345     if (NumBits == 1)
1346       return "pred";
1347     else if (NumBits <= 64) {
1348       std::string name = "u";
1349       return name + utostr(NumBits);
1350     } else {
1351       llvm_unreachable("Integer too large");
1352       break;
1353     }
1354     break;
1355   }
1356   case Type::HalfTyID:
1357     // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1358     return "b16";
1359   case Type::FloatTyID:
1360     return "f32";
1361   case Type::DoubleTyID:
1362     return "f64";
1363   case Type::PointerTyID: {
1364     unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1365     assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1366 
1367     if (PtrSize == 64)
1368       if (useB4PTR)
1369         return "b64";
1370       else
1371         return "u64";
1372     else if (useB4PTR)
1373       return "b32";
1374     else
1375       return "u32";
1376   }
1377   default:
1378     break;
1379   }
1380   llvm_unreachable("unexpected type");
1381 }
1382 
1383 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1384                                             raw_ostream &O,
1385                                             const NVPTXSubtarget &STI) {
1386   const DataLayout &DL = getDataLayout();
1387 
1388   // GlobalVariables are always constant pointers themselves.
1389   Type *ETy = GVar->getValueType();
1390 
1391   O << ".";
1392   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1393   if (isManaged(*GVar)) {
1394     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1395       report_fatal_error(
1396           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1397     }
1398     O << " .attribute(.managed)";
1399   }
1400   if (MaybeAlign A = GVar->getAlign())
1401     O << " .align " << A->value();
1402   else
1403     O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1404 
1405   // Special case for i128
1406   if (ETy->isIntegerTy(128)) {
1407     O << " .b8 ";
1408     getSymbol(GVar)->print(O, MAI);
1409     O << "[16]";
1410     return;
1411   }
1412 
1413   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1414     O << " .";
1415     O << getPTXFundamentalTypeStr(ETy);
1416     O << " ";
1417     getSymbol(GVar)->print(O, MAI);
1418     return;
1419   }
1420 
1421   int64_t ElementSize = 0;
1422 
1423   // Although PTX has direct support for struct type and array type and LLVM IR
1424   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1425   // support these high level field accesses. Structs and arrays are lowered
1426   // into arrays of bytes.
1427   switch (ETy->getTypeID()) {
1428   case Type::StructTyID:
1429   case Type::ArrayTyID:
1430   case Type::FixedVectorTyID:
1431     ElementSize = DL.getTypeStoreSize(ETy);
1432     O << " .b8 ";
1433     getSymbol(GVar)->print(O, MAI);
1434     O << "[";
1435     if (ElementSize) {
1436       O << ElementSize;
1437     }
1438     O << "]";
1439     break;
1440   default:
1441     llvm_unreachable("type not supported yet");
1442   }
1443 }
1444 
1445 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1446                                      int paramIndex, raw_ostream &O) {
1447   getSymbol(I->getParent())->print(O, MAI);
1448   O << "_param_" << paramIndex;
1449 }
1450 
1451 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1452   const DataLayout &DL = getDataLayout();
1453   const AttributeList &PAL = F->getAttributes();
1454   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1455   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1456 
1457   Function::const_arg_iterator I, E;
1458   unsigned paramIndex = 0;
1459   bool first = true;
1460   bool isKernelFunc = isKernelFunction(*F);
1461   bool isABI = (STI.getSmVersion() >= 20);
1462   bool hasImageHandles = STI.hasImageHandles();
1463 
1464   if (F->arg_empty() && !F->isVarArg()) {
1465     O << "()\n";
1466     return;
1467   }
1468 
1469   O << "(\n";
1470 
1471   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1472     Type *Ty = I->getType();
1473 
1474     if (!first)
1475       O << ",\n";
1476 
1477     first = false;
1478 
1479     // Handle image/sampler parameters
1480     if (isKernelFunction(*F)) {
1481       if (isSampler(*I) || isImage(*I)) {
1482         if (isImage(*I)) {
1483           std::string sname = std::string(I->getName());
1484           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1485             if (hasImageHandles)
1486               O << "\t.param .u64 .ptr .surfref ";
1487             else
1488               O << "\t.param .surfref ";
1489             CurrentFnSym->print(O, MAI);
1490             O << "_param_" << paramIndex;
1491           }
1492           else { // Default image is read_only
1493             if (hasImageHandles)
1494               O << "\t.param .u64 .ptr .texref ";
1495             else
1496               O << "\t.param .texref ";
1497             CurrentFnSym->print(O, MAI);
1498             O << "_param_" << paramIndex;
1499           }
1500         } else {
1501           if (hasImageHandles)
1502             O << "\t.param .u64 .ptr .samplerref ";
1503           else
1504             O << "\t.param .samplerref ";
1505           CurrentFnSym->print(O, MAI);
1506           O << "_param_" << paramIndex;
1507         }
1508         continue;
1509       }
1510     }
1511 
1512     auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1513                                     paramIndex](Type *Ty) -> Align {
1514       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1515       MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1516       return std::max(TypeAlign, ParamAlign.valueOrOne());
1517     };
1518 
1519     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1520       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1521         // Just print .param .align <a> .b8 .param[size];
1522         // <a>  = optimal alignment for the element type; always multiple of
1523         //        PAL.getParamAlignment
1524         // size = typeallocsize of element type
1525         Align OptimalAlign = getOptimalAlignForParam(Ty);
1526 
1527         O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1528         printParamName(I, paramIndex, O);
1529         O << "[" << DL.getTypeAllocSize(Ty) << "]";
1530 
1531         continue;
1532       }
1533       // Just a scalar
1534       auto *PTy = dyn_cast<PointerType>(Ty);
1535       unsigned PTySizeInBits = 0;
1536       if (PTy) {
1537         PTySizeInBits =
1538             TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1539         assert(PTySizeInBits && "Invalid pointer size");
1540       }
1541 
1542       if (isKernelFunc) {
1543         if (PTy) {
1544           // Special handling for pointer arguments to kernel
1545           O << "\t.param .u" << PTySizeInBits << " ";
1546 
1547           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1548               NVPTX::CUDA) {
1549             int addrSpace = PTy->getAddressSpace();
1550             switch (addrSpace) {
1551             default:
1552               O << ".ptr ";
1553               break;
1554             case ADDRESS_SPACE_CONST:
1555               O << ".ptr .const ";
1556               break;
1557             case ADDRESS_SPACE_SHARED:
1558               O << ".ptr .shared ";
1559               break;
1560             case ADDRESS_SPACE_GLOBAL:
1561               O << ".ptr .global ";
1562               break;
1563             }
1564             Align ParamAlign = I->getParamAlign().valueOrOne();
1565             O << ".align " << ParamAlign.value() << " ";
1566           }
1567           printParamName(I, paramIndex, O);
1568           continue;
1569         }
1570 
1571         // non-pointer scalar to kernel func
1572         O << "\t.param .";
1573         // Special case: predicate operands become .u8 types
1574         if (Ty->isIntegerTy(1))
1575           O << "u8";
1576         else
1577           O << getPTXFundamentalTypeStr(Ty);
1578         O << " ";
1579         printParamName(I, paramIndex, O);
1580         continue;
1581       }
1582       // Non-kernel function, just print .param .b<size> for ABI
1583       // and .reg .b<size> for non-ABI
1584       unsigned sz = 0;
1585       if (isa<IntegerType>(Ty)) {
1586         sz = cast<IntegerType>(Ty)->getBitWidth();
1587         sz = promoteScalarArgumentSize(sz);
1588       } else if (PTy) {
1589         assert(PTySizeInBits && "Invalid pointer size");
1590         sz = PTySizeInBits;
1591       } else if (Ty->isHalfTy())
1592         // PTX ABI requires all scalar parameters to be at least 32
1593         // bits in size.  fp16 normally uses .b16 as its storage type
1594         // in PTX, so its size must be adjusted here, too.
1595         sz = 32;
1596       else
1597         sz = Ty->getPrimitiveSizeInBits();
1598       if (isABI)
1599         O << "\t.param .b" << sz << " ";
1600       else
1601         O << "\t.reg .b" << sz << " ";
1602       printParamName(I, paramIndex, O);
1603       continue;
1604     }
1605 
1606     // param has byVal attribute.
1607     Type *ETy = PAL.getParamByValType(paramIndex);
1608     assert(ETy && "Param should have byval type");
1609 
1610     if (isABI || isKernelFunc) {
1611       // Just print .param .align <a> .b8 .param[size];
1612       // <a>  = optimal alignment for the element type; always multiple of
1613       //        PAL.getParamAlignment
1614       // size = typeallocsize of element type
1615       Align OptimalAlign =
1616           isKernelFunc
1617               ? getOptimalAlignForParam(ETy)
1618               : TLI->getFunctionByValParamAlign(
1619                     F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1620 
1621       unsigned sz = DL.getTypeAllocSize(ETy);
1622       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1623       printParamName(I, paramIndex, O);
1624       O << "[" << sz << "]";
1625       continue;
1626     } else {
1627       // Split the ETy into constituent parts and
1628       // print .param .b<size> <name> for each part.
1629       // Further, if a part is vector, print the above for
1630       // each vector element.
1631       SmallVector<EVT, 16> vtparts;
1632       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1633       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1634         unsigned elems = 1;
1635         EVT elemtype = vtparts[i];
1636         if (vtparts[i].isVector()) {
1637           elems = vtparts[i].getVectorNumElements();
1638           elemtype = vtparts[i].getVectorElementType();
1639         }
1640 
1641         for (unsigned j = 0, je = elems; j != je; ++j) {
1642           unsigned sz = elemtype.getSizeInBits();
1643           if (elemtype.isInteger())
1644             sz = promoteScalarArgumentSize(sz);
1645           O << "\t.reg .b" << sz << " ";
1646           printParamName(I, paramIndex, O);
1647           if (j < je - 1)
1648             O << ",\n";
1649           ++paramIndex;
1650         }
1651         if (i < e - 1)
1652           O << ",\n";
1653       }
1654       --paramIndex;
1655       continue;
1656     }
1657   }
1658 
1659   if (F->isVarArg()) {
1660     if (!first)
1661       O << ",\n";
1662     O << "\t.param .align " << STI.getMaxRequiredAlignment();
1663     O << " .b8 ";
1664     getSymbol(F)->print(O, MAI);
1665     O << "_vararg[]";
1666   }
1667 
1668   O << "\n)\n";
1669 }
1670 
1671 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1672                                             raw_ostream &O) {
1673   const Function &F = MF.getFunction();
1674   emitFunctionParamList(&F, O);
1675 }
1676 
1677 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1678     const MachineFunction &MF) {
1679   SmallString<128> Str;
1680   raw_svector_ostream O(Str);
1681 
1682   // Map the global virtual register number to a register class specific
1683   // virtual register number starting from 1 with that class.
1684   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1685   //unsigned numRegClasses = TRI->getNumRegClasses();
1686 
1687   // Emit the Fake Stack Object
1688   const MachineFrameInfo &MFI = MF.getFrameInfo();
1689   int NumBytes = (int) MFI.getStackSize();
1690   if (NumBytes) {
1691     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1692       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1693     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1694       O << "\t.reg .b64 \t%SP;\n";
1695       O << "\t.reg .b64 \t%SPL;\n";
1696     } else {
1697       O << "\t.reg .b32 \t%SP;\n";
1698       O << "\t.reg .b32 \t%SPL;\n";
1699     }
1700   }
1701 
1702   // Go through all virtual registers to establish the mapping between the
1703   // global virtual
1704   // register number and the per class virtual register number.
1705   // We use the per class virtual register number in the ptx output.
1706   unsigned int numVRs = MRI->getNumVirtRegs();
1707   for (unsigned i = 0; i < numVRs; i++) {
1708     Register vr = Register::index2VirtReg(i);
1709     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1710     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1711     int n = regmap.size();
1712     regmap.insert(std::make_pair(vr, n + 1));
1713   }
1714 
1715   // Emit register declarations
1716   // @TODO: Extract out the real register usage
1717   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1718   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1719   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1720   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1721   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1722   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1723   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1724 
1725   // Emit declaration of the virtual registers or 'physical' registers for
1726   // each register class
1727   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1728     const TargetRegisterClass *RC = TRI->getRegClass(i);
1729     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1730     std::string rcname = getNVPTXRegClassName(RC);
1731     std::string rcStr = getNVPTXRegClassStr(RC);
1732     int n = regmap.size();
1733 
1734     // Only declare those registers that may be used.
1735     if (n) {
1736        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1737          << ">;\n";
1738     }
1739   }
1740 
1741   OutStreamer->emitRawText(O.str());
1742 }
1743 
1744 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1745   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1746   bool ignored;
1747   unsigned int numHex;
1748   const char *lead;
1749 
1750   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1751     numHex = 8;
1752     lead = "0f";
1753     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1754   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1755     numHex = 16;
1756     lead = "0d";
1757     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1758   } else
1759     llvm_unreachable("unsupported fp type");
1760 
1761   APInt API = APF.bitcastToAPInt();
1762   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1763 }
1764 
1765 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1766   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1767     O << CI->getValue();
1768     return;
1769   }
1770   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1771     printFPConstant(CFP, O);
1772     return;
1773   }
1774   if (isa<ConstantPointerNull>(CPV)) {
1775     O << "0";
1776     return;
1777   }
1778   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1779     bool IsNonGenericPointer = false;
1780     if (GVar->getType()->getAddressSpace() != 0) {
1781       IsNonGenericPointer = true;
1782     }
1783     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1784       O << "generic(";
1785       getSymbol(GVar)->print(O, MAI);
1786       O << ")";
1787     } else {
1788       getSymbol(GVar)->print(O, MAI);
1789     }
1790     return;
1791   }
1792   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1793     const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1794     printMCExpr(*E, O);
1795     return;
1796   }
1797   llvm_unreachable("Not scalar type found in printScalarConstant()");
1798 }
1799 
1800 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1801                                    AggBuffer *AggBuffer) {
1802   const DataLayout &DL = getDataLayout();
1803   int AllocSize = DL.getTypeAllocSize(CPV->getType());
1804   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1805     // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1806     // only the space allocated by CPV.
1807     AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1808     return;
1809   }
1810 
1811   // Helper for filling AggBuffer with APInts.
1812   auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1813     size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1814     SmallVector<unsigned char, 16> Buf(NumBytes);
1815     for (unsigned I = 0; I < NumBytes; ++I) {
1816       Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1817     }
1818     AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1819   };
1820 
1821   switch (CPV->getType()->getTypeID()) {
1822   case Type::IntegerTyID:
1823     if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1824       AddIntToBuffer(CI->getValue());
1825       break;
1826     }
1827     if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1828       if (const auto *CI =
1829               dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1830         AddIntToBuffer(CI->getValue());
1831         break;
1832       }
1833       if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1834         Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1835         AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1836         AggBuffer->addZeros(AllocSize);
1837         break;
1838       }
1839     }
1840     llvm_unreachable("unsupported integer const type");
1841     break;
1842 
1843   case Type::HalfTyID:
1844   case Type::BFloatTyID:
1845   case Type::FloatTyID:
1846   case Type::DoubleTyID:
1847     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1848     break;
1849 
1850   case Type::PointerTyID: {
1851     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1852       AggBuffer->addSymbol(GVar, GVar);
1853     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1854       const Value *v = Cexpr->stripPointerCasts();
1855       AggBuffer->addSymbol(v, Cexpr);
1856     }
1857     AggBuffer->addZeros(AllocSize);
1858     break;
1859   }
1860 
1861   case Type::ArrayTyID:
1862   case Type::FixedVectorTyID:
1863   case Type::StructTyID: {
1864     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1865       bufferAggregateConstant(CPV, AggBuffer);
1866       if (Bytes > AllocSize)
1867         AggBuffer->addZeros(Bytes - AllocSize);
1868     } else if (isa<ConstantAggregateZero>(CPV))
1869       AggBuffer->addZeros(Bytes);
1870     else
1871       llvm_unreachable("Unexpected Constant type");
1872     break;
1873   }
1874 
1875   default:
1876     llvm_unreachable("unsupported type");
1877   }
1878 }
1879 
1880 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1881                                               AggBuffer *aggBuffer) {
1882   const DataLayout &DL = getDataLayout();
1883   int Bytes;
1884 
1885   // Integers of arbitrary width
1886   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1887     APInt Val = CI->getValue();
1888     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1889       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1890       aggBuffer->addBytes(&Byte, 1, 1);
1891       Val.lshrInPlace(8);
1892     }
1893     return;
1894   }
1895 
1896   // Old constants
1897   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1898     if (CPV->getNumOperands())
1899       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1900         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1901     return;
1902   }
1903 
1904   if (const ConstantDataSequential *CDS =
1905           dyn_cast<ConstantDataSequential>(CPV)) {
1906     if (CDS->getNumElements())
1907       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1908         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1909                      aggBuffer);
1910     return;
1911   }
1912 
1913   if (isa<ConstantStruct>(CPV)) {
1914     if (CPV->getNumOperands()) {
1915       StructType *ST = cast<StructType>(CPV->getType());
1916       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1917         if (i == (e - 1))
1918           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1919                   DL.getTypeAllocSize(ST) -
1920                   DL.getStructLayout(ST)->getElementOffset(i);
1921         else
1922           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1923                   DL.getStructLayout(ST)->getElementOffset(i);
1924         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1925       }
1926     }
1927     return;
1928   }
1929   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1930 }
1931 
1932 /// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1933 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1934 /// expressions that are representable in PTX and create
1935 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1936 const MCExpr *
1937 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1938   MCContext &Ctx = OutContext;
1939 
1940   if (CV->isNullValue() || isa<UndefValue>(CV))
1941     return MCConstantExpr::create(0, Ctx);
1942 
1943   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1944     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1945 
1946   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1947     const MCSymbolRefExpr *Expr =
1948       MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1949     if (ProcessingGeneric) {
1950       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1951     } else {
1952       return Expr;
1953     }
1954   }
1955 
1956   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1957   if (!CE) {
1958     llvm_unreachable("Unknown constant value to lower!");
1959   }
1960 
1961   switch (CE->getOpcode()) {
1962   default: {
1963     // If the code isn't optimized, there may be outstanding folding
1964     // opportunities. Attempt to fold the expression using DataLayout as a
1965     // last resort before giving up.
1966     Constant *C = ConstantFoldConstant(CE, getDataLayout());
1967     if (C != CE)
1968       return lowerConstantForGV(C, ProcessingGeneric);
1969 
1970     // Otherwise report the problem to the user.
1971     std::string S;
1972     raw_string_ostream OS(S);
1973     OS << "Unsupported expression in static initializer: ";
1974     CE->printAsOperand(OS, /*PrintType=*/false,
1975                    !MF ? nullptr : MF->getFunction().getParent());
1976     report_fatal_error(Twine(OS.str()));
1977   }
1978 
1979   case Instruction::AddrSpaceCast: {
1980     // Strip the addrspacecast and pass along the operand
1981     PointerType *DstTy = cast<PointerType>(CE->getType());
1982     if (DstTy->getAddressSpace() == 0) {
1983       return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
1984     }
1985     std::string S;
1986     raw_string_ostream OS(S);
1987     OS << "Unsupported expression in static initializer: ";
1988     CE->printAsOperand(OS, /*PrintType=*/ false,
1989                        !MF ? nullptr : MF->getFunction().getParent());
1990     report_fatal_error(Twine(OS.str()));
1991   }
1992 
1993   case Instruction::GetElementPtr: {
1994     const DataLayout &DL = getDataLayout();
1995 
1996     // Generate a symbolic expression for the byte address
1997     APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
1998     cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
1999 
2000     const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2001                                             ProcessingGeneric);
2002     if (!OffsetAI)
2003       return Base;
2004 
2005     int64_t Offset = OffsetAI.getSExtValue();
2006     return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2007                                    Ctx);
2008   }
2009 
2010   case Instruction::Trunc:
2011     // We emit the value and depend on the assembler to truncate the generated
2012     // expression properly.  This is important for differences between
2013     // blockaddress labels.  Since the two labels are in the same function, it
2014     // is reasonable to treat their delta as a 32-bit value.
2015     [[fallthrough]];
2016   case Instruction::BitCast:
2017     return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2018 
2019   case Instruction::IntToPtr: {
2020     const DataLayout &DL = getDataLayout();
2021 
2022     // Handle casts to pointers by changing them into casts to the appropriate
2023     // integer type.  This promotes constant folding and simplifies this code.
2024     Constant *Op = CE->getOperand(0);
2025     Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2026                                       false/*ZExt*/);
2027     return lowerConstantForGV(Op, ProcessingGeneric);
2028   }
2029 
2030   case Instruction::PtrToInt: {
2031     const DataLayout &DL = getDataLayout();
2032 
2033     // Support only foldable casts to/from pointers that can be eliminated by
2034     // changing the pointer to the appropriately sized integer type.
2035     Constant *Op = CE->getOperand(0);
2036     Type *Ty = CE->getType();
2037 
2038     const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2039 
2040     // We can emit the pointer value into this slot if the slot is an
2041     // integer slot equal to the size of the pointer.
2042     if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2043       return OpExpr;
2044 
2045     // Otherwise the pointer is smaller than the resultant integer, mask off
2046     // the high bits so we are sure to get a proper truncation if the input is
2047     // a constant expr.
2048     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2049     const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2050     return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2051   }
2052 
2053   // The MC library also has a right-shift operator, but it isn't consistently
2054   // signed or unsigned between different targets.
2055   case Instruction::Add: {
2056     const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2057     const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2058     switch (CE->getOpcode()) {
2059     default: llvm_unreachable("Unknown binary operator constant cast expr");
2060     case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2061     }
2062   }
2063   }
2064 }
2065 
2066 // Copy of MCExpr::print customized for NVPTX
2067 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2068   switch (Expr.getKind()) {
2069   case MCExpr::Target:
2070     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2071   case MCExpr::Constant:
2072     OS << cast<MCConstantExpr>(Expr).getValue();
2073     return;
2074 
2075   case MCExpr::SymbolRef: {
2076     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2077     const MCSymbol &Sym = SRE.getSymbol();
2078     Sym.print(OS, MAI);
2079     return;
2080   }
2081 
2082   case MCExpr::Unary: {
2083     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2084     switch (UE.getOpcode()) {
2085     case MCUnaryExpr::LNot:  OS << '!'; break;
2086     case MCUnaryExpr::Minus: OS << '-'; break;
2087     case MCUnaryExpr::Not:   OS << '~'; break;
2088     case MCUnaryExpr::Plus:  OS << '+'; break;
2089     }
2090     printMCExpr(*UE.getSubExpr(), OS);
2091     return;
2092   }
2093 
2094   case MCExpr::Binary: {
2095     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2096 
2097     // Only print parens around the LHS if it is non-trivial.
2098     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2099         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2100       printMCExpr(*BE.getLHS(), OS);
2101     } else {
2102       OS << '(';
2103       printMCExpr(*BE.getLHS(), OS);
2104       OS<< ')';
2105     }
2106 
2107     switch (BE.getOpcode()) {
2108     case MCBinaryExpr::Add:
2109       // Print "X-42" instead of "X+-42".
2110       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2111         if (RHSC->getValue() < 0) {
2112           OS << RHSC->getValue();
2113           return;
2114         }
2115       }
2116 
2117       OS <<  '+';
2118       break;
2119     default: llvm_unreachable("Unhandled binary operator");
2120     }
2121 
2122     // Only print parens around the LHS if it is non-trivial.
2123     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2124       printMCExpr(*BE.getRHS(), OS);
2125     } else {
2126       OS << '(';
2127       printMCExpr(*BE.getRHS(), OS);
2128       OS << ')';
2129     }
2130     return;
2131   }
2132   }
2133 
2134   llvm_unreachable("Invalid expression kind!");
2135 }
2136 
2137 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2138 ///
2139 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2140                                       const char *ExtraCode, raw_ostream &O) {
2141   if (ExtraCode && ExtraCode[0]) {
2142     if (ExtraCode[1] != 0)
2143       return true; // Unknown modifier.
2144 
2145     switch (ExtraCode[0]) {
2146     default:
2147       // See if this is a generic print operand
2148       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2149     case 'r':
2150       break;
2151     }
2152   }
2153 
2154   printOperand(MI, OpNo, O);
2155 
2156   return false;
2157 }
2158 
2159 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2160                                             unsigned OpNo,
2161                                             const char *ExtraCode,
2162                                             raw_ostream &O) {
2163   if (ExtraCode && ExtraCode[0])
2164     return true; // Unknown modifier
2165 
2166   O << '[';
2167   printMemOperand(MI, OpNo, O);
2168   O << ']';
2169 
2170   return false;
2171 }
2172 
2173 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2174                                    raw_ostream &O) {
2175   const MachineOperand &MO = MI->getOperand(opNum);
2176   switch (MO.getType()) {
2177   case MachineOperand::MO_Register:
2178     if (MO.getReg().isPhysical()) {
2179       if (MO.getReg() == NVPTX::VRDepot)
2180         O << DEPOTNAME << getFunctionNumber();
2181       else
2182         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2183     } else {
2184       emitVirtualRegister(MO.getReg(), O);
2185     }
2186     break;
2187 
2188   case MachineOperand::MO_Immediate:
2189     O << MO.getImm();
2190     break;
2191 
2192   case MachineOperand::MO_FPImmediate:
2193     printFPConstant(MO.getFPImm(), O);
2194     break;
2195 
2196   case MachineOperand::MO_GlobalAddress:
2197     PrintSymbolOperand(MO, O);
2198     break;
2199 
2200   case MachineOperand::MO_MachineBasicBlock:
2201     MO.getMBB()->getSymbol()->print(O, MAI);
2202     break;
2203 
2204   default:
2205     llvm_unreachable("Operand type not supported.");
2206   }
2207 }
2208 
2209 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2210                                       raw_ostream &O, const char *Modifier) {
2211   printOperand(MI, opNum, O);
2212 
2213   if (Modifier && strcmp(Modifier, "add") == 0) {
2214     O << ", ";
2215     printOperand(MI, opNum + 1, O);
2216   } else {
2217     if (MI->getOperand(opNum + 1).isImm() &&
2218         MI->getOperand(opNum + 1).getImm() == 0)
2219       return; // don't print ',0' or '+0'
2220     O << "+";
2221     printOperand(MI, opNum + 1, O);
2222   }
2223 }
2224 
2225 // Force static initialization.
2226 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2227   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2228   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2229 }
2230