xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (revision 51015e6d0f570239b0c2088dc6cf2b018928375d)
1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Triple.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Analysis/ConstantFolding.h"
39 #include "llvm/CodeGen/Analysis.h"
40 #include "llvm/CodeGen/MachineBasicBlock.h"
41 #include "llvm/CodeGen/MachineFrameInfo.h"
42 #include "llvm/CodeGen/MachineFunction.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineLoopInfo.h"
45 #include "llvm/CodeGen/MachineModuleInfo.h"
46 #include "llvm/CodeGen/MachineOperand.h"
47 #include "llvm/CodeGen/MachineRegisterInfo.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/MachineValueType.h"
79 #include "llvm/Support/NativeFormatting.h"
80 #include "llvm/Support/Path.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <new>
89 #include <string>
90 #include <utility>
91 #include <vector>
92 
93 using namespace llvm;
94 
95 #define DEPOTNAME "__local_depot"
96 
97 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
98 /// depends.
99 static void
100 DiscoverDependentGlobals(const Value *V,
101                          DenseSet<const GlobalVariable *> &Globals) {
102   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
103     Globals.insert(GV);
104   else {
105     if (const User *U = dyn_cast<User>(V)) {
106       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
107         DiscoverDependentGlobals(U->getOperand(i), Globals);
108       }
109     }
110   }
111 }
112 
113 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
114 /// instances to be emitted, but only after any dependents have been added
115 /// first.s
116 static void
117 VisitGlobalVariableForEmission(const GlobalVariable *GV,
118                                SmallVectorImpl<const GlobalVariable *> &Order,
119                                DenseSet<const GlobalVariable *> &Visited,
120                                DenseSet<const GlobalVariable *> &Visiting) {
121   // Have we already visited this one?
122   if (Visited.count(GV))
123     return;
124 
125   // Do we have a circular dependency?
126   if (!Visiting.insert(GV).second)
127     report_fatal_error("Circular dependency found in global variable set");
128 
129   // Make sure we visit all dependents first
130   DenseSet<const GlobalVariable *> Others;
131   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
132     DiscoverDependentGlobals(GV->getOperand(i), Others);
133 
134   for (const GlobalVariable *GV : Others)
135     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
136 
137   // Now we can visit ourself
138   Order.push_back(GV);
139   Visited.insert(GV);
140   Visiting.erase(GV);
141 }
142 
143 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
144   NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
145                                         getSubtargetInfo().getFeatureBits());
146 
147   MCInst Inst;
148   lowerToMCInst(MI, Inst);
149   EmitToStreamer(*OutStreamer, Inst);
150 }
151 
152 // Handle symbol backtracking for targets that do not support image handles
153 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154                                            unsigned OpNo, MCOperand &MCOp) {
155   const MachineOperand &MO = MI->getOperand(OpNo);
156   const MCInstrDesc &MCID = MI->getDesc();
157 
158   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159     // This is a texture fetch, so operand 4 is a texref and operand 5 is
160     // a samplerref
161     if (OpNo == 4 && MO.isImm()) {
162       lowerImageHandleSymbol(MO.getImm(), MCOp);
163       return true;
164     }
165     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166       lowerImageHandleSymbol(MO.getImm(), MCOp);
167       return true;
168     }
169 
170     return false;
171   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
172     unsigned VecSize =
173       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174 
175     // For a surface load of vector size N, the Nth operand will be the surfref
176     if (OpNo == VecSize && MO.isImm()) {
177       lowerImageHandleSymbol(MO.getImm(), MCOp);
178       return true;
179     }
180 
181     return false;
182   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183     // This is a surface store, so operand 0 is a surfref
184     if (OpNo == 0 && MO.isImm()) {
185       lowerImageHandleSymbol(MO.getImm(), MCOp);
186       return true;
187     }
188 
189     return false;
190   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191     // This is a query, so operand 1 is a surfref/texref
192     if (OpNo == 1 && MO.isImm()) {
193       lowerImageHandleSymbol(MO.getImm(), MCOp);
194       return true;
195     }
196 
197     return false;
198   }
199 
200   return false;
201 }
202 
203 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204   // Ewwww
205   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
206   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
207   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208   const char *Sym = MFI->getImageHandleSymbol(Index);
209   std::string *SymNamePtr =
210     nvTM.getManagedStrPool()->getManagedString(Sym);
211   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr)));
212 }
213 
214 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
215   OutMI.setOpcode(MI->getOpcode());
216   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
217   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
218     const MachineOperand &MO = MI->getOperand(0);
219     OutMI.addOperand(GetSymbolRef(
220       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
221     return;
222   }
223 
224   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
225   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
226     const MachineOperand &MO = MI->getOperand(i);
227 
228     MCOperand MCOp;
229     if (!STI.hasImageHandles()) {
230       if (lowerImageHandleOperand(MI, i, MCOp)) {
231         OutMI.addOperand(MCOp);
232         continue;
233       }
234     }
235 
236     if (lowerOperand(MO, MCOp))
237       OutMI.addOperand(MCOp);
238   }
239 }
240 
241 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
242                                    MCOperand &MCOp) {
243   switch (MO.getType()) {
244   default: llvm_unreachable("unknown operand type");
245   case MachineOperand::MO_Register:
246     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
247     break;
248   case MachineOperand::MO_Immediate:
249     MCOp = MCOperand::createImm(MO.getImm());
250     break;
251   case MachineOperand::MO_MachineBasicBlock:
252     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
253         MO.getMBB()->getSymbol(), OutContext));
254     break;
255   case MachineOperand::MO_ExternalSymbol:
256     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
257     break;
258   case MachineOperand::MO_GlobalAddress:
259     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
260     break;
261   case MachineOperand::MO_FPImmediate: {
262     const ConstantFP *Cnt = MO.getFPImm();
263     const APFloat &Val = Cnt->getValueAPF();
264 
265     switch (Cnt->getType()->getTypeID()) {
266     default: report_fatal_error("Unsupported FP type"); break;
267     case Type::HalfTyID:
268       MCOp = MCOperand::createExpr(
269         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
270       break;
271     case Type::FloatTyID:
272       MCOp = MCOperand::createExpr(
273         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
274       break;
275     case Type::DoubleTyID:
276       MCOp = MCOperand::createExpr(
277         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
278       break;
279     }
280     break;
281   }
282   }
283   return true;
284 }
285 
286 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
287   if (Register::isVirtualRegister(Reg)) {
288     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
289 
290     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
291     unsigned RegNum = RegMap[Reg];
292 
293     // Encode the register class in the upper 4 bits
294     // Must be kept in sync with NVPTXInstPrinter::printRegName
295     unsigned Ret = 0;
296     if (RC == &NVPTX::Int1RegsRegClass) {
297       Ret = (1 << 28);
298     } else if (RC == &NVPTX::Int16RegsRegClass) {
299       Ret = (2 << 28);
300     } else if (RC == &NVPTX::Int32RegsRegClass) {
301       Ret = (3 << 28);
302     } else if (RC == &NVPTX::Int64RegsRegClass) {
303       Ret = (4 << 28);
304     } else if (RC == &NVPTX::Float32RegsRegClass) {
305       Ret = (5 << 28);
306     } else if (RC == &NVPTX::Float64RegsRegClass) {
307       Ret = (6 << 28);
308     } else if (RC == &NVPTX::Float16RegsRegClass) {
309       Ret = (7 << 28);
310     } else if (RC == &NVPTX::Float16x2RegsRegClass) {
311       Ret = (8 << 28);
312     } else {
313       report_fatal_error("Bad register class");
314     }
315 
316     // Insert the vreg number
317     Ret |= (RegNum & 0x0FFFFFFF);
318     return Ret;
319   } else {
320     // Some special-use registers are actually physical registers.
321     // Encode this as the register class ID of 0 and the real register ID.
322     return Reg & 0x0FFFFFFF;
323   }
324 }
325 
326 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
327   const MCExpr *Expr;
328   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
329                                  OutContext);
330   return MCOperand::createExpr(Expr);
331 }
332 
333 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
334   const DataLayout &DL = getDataLayout();
335   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
336   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
337 
338   Type *Ty = F->getReturnType();
339 
340   bool isABI = (STI.getSmVersion() >= 20);
341 
342   if (Ty->getTypeID() == Type::VoidTyID)
343     return;
344 
345   O << " (";
346 
347   if (isABI) {
348     if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
349       unsigned size = 0;
350       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
351         size = ITy->getBitWidth();
352       } else {
353         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
354         size = Ty->getPrimitiveSizeInBits();
355       }
356       // PTX ABI requires all scalar return values to be at least 32
357       // bits in size.  fp16 normally uses .b16 as its storage type in
358       // PTX, so its size must be adjusted here, too.
359       size = promoteScalarArgumentSize(size);
360 
361       O << ".param .b" << size << " func_retval0";
362     } else if (isa<PointerType>(Ty)) {
363       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
364         << " func_retval0";
365     } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
366       unsigned totalsz = DL.getTypeAllocSize(Ty);
367       unsigned retAlignment = 0;
368       if (!getAlign(*F, 0, retAlignment))
369         retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
370       O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
371         << "]";
372     } else
373       llvm_unreachable("Unknown return type");
374   } else {
375     SmallVector<EVT, 16> vtparts;
376     ComputeValueVTs(*TLI, DL, Ty, vtparts);
377     unsigned idx = 0;
378     for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
379       unsigned elems = 1;
380       EVT elemtype = vtparts[i];
381       if (vtparts[i].isVector()) {
382         elems = vtparts[i].getVectorNumElements();
383         elemtype = vtparts[i].getVectorElementType();
384       }
385 
386       for (unsigned j = 0, je = elems; j != je; ++j) {
387         unsigned sz = elemtype.getSizeInBits();
388         if (elemtype.isInteger())
389           sz = promoteScalarArgumentSize(sz);
390         O << ".reg .b" << sz << " func_retval" << idx;
391         if (j < je - 1)
392           O << ", ";
393         ++idx;
394       }
395       if (i < e - 1)
396         O << ", ";
397     }
398   }
399   O << ") ";
400 }
401 
402 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
403                                         raw_ostream &O) {
404   const Function &F = MF.getFunction();
405   printReturnValStr(&F, O);
406 }
407 
408 // Return true if MBB is the header of a loop marked with
409 // llvm.loop.unroll.disable.
410 // TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll".
411 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
412     const MachineBasicBlock &MBB) const {
413   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
414   // We insert .pragma "nounroll" only to the loop header.
415   if (!LI.isLoopHeader(&MBB))
416     return false;
417 
418   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
419   // we iterate through each back edge of the loop with header MBB, and check
420   // whether its metadata contains llvm.loop.unroll.disable.
421   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
422     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
423       // Edges from other loops to MBB are not back edges.
424       continue;
425     }
426     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
427       if (MDNode *LoopID =
428               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
429         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
430           return true;
431       }
432     }
433   }
434   return false;
435 }
436 
437 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
438   AsmPrinter::emitBasicBlockStart(MBB);
439   if (isLoopHeaderOfNoUnroll(MBB))
440     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
441 }
442 
443 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
444   SmallString<128> Str;
445   raw_svector_ostream O(Str);
446 
447   if (!GlobalsEmitted) {
448     emitGlobals(*MF->getFunction().getParent());
449     GlobalsEmitted = true;
450   }
451 
452   // Set up
453   MRI = &MF->getRegInfo();
454   F = &MF->getFunction();
455   emitLinkageDirective(F, O);
456   if (isKernelFunction(*F))
457     O << ".entry ";
458   else {
459     O << ".func ";
460     printReturnValStr(*MF, O);
461   }
462 
463   CurrentFnSym->print(O, MAI);
464 
465   emitFunctionParamList(*MF, O);
466 
467   if (isKernelFunction(*F))
468     emitKernelFunctionDirectives(*F, O);
469 
470   OutStreamer->emitRawText(O.str());
471 
472   VRegMapping.clear();
473   // Emit open brace for function body.
474   OutStreamer->emitRawText(StringRef("{\n"));
475   setAndEmitFunctionVirtualRegisters(*MF);
476   // Emit initial .loc debug directive for correct relocation symbol data.
477   if (MMI && MMI->hasDebugInfo())
478     emitInitialRawDwarfLocDirective(*MF);
479 }
480 
481 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
482   bool Result = AsmPrinter::runOnMachineFunction(F);
483   // Emit closing brace for the body of function F.
484   // The closing brace must be emitted here because we need to emit additional
485   // debug labels/data after the last basic block.
486   // We need to emit the closing brace here because we don't have function that
487   // finished emission of the function body.
488   OutStreamer->emitRawText(StringRef("}\n"));
489   return Result;
490 }
491 
492 void NVPTXAsmPrinter::emitFunctionBodyStart() {
493   SmallString<128> Str;
494   raw_svector_ostream O(Str);
495   emitDemotedVars(&MF->getFunction(), O);
496   OutStreamer->emitRawText(O.str());
497 }
498 
499 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
500   VRegMapping.clear();
501 }
502 
503 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
504     SmallString<128> Str;
505     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
506     return OutContext.getOrCreateSymbol(Str);
507 }
508 
509 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
510   Register RegNo = MI->getOperand(0).getReg();
511   if (Register::isVirtualRegister(RegNo)) {
512     OutStreamer->AddComment(Twine("implicit-def: ") +
513                             getVirtualRegisterName(RegNo));
514   } else {
515     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
516     OutStreamer->AddComment(Twine("implicit-def: ") +
517                             STI.getRegisterInfo()->getName(RegNo));
518   }
519   OutStreamer->addBlankLine();
520 }
521 
522 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
523                                                    raw_ostream &O) const {
524   // If the NVVM IR has some of reqntid* specified, then output
525   // the reqntid directive, and set the unspecified ones to 1.
526   // If none of reqntid* is specified, don't output reqntid directive.
527   unsigned reqntidx, reqntidy, reqntidz;
528   bool specified = false;
529   if (!getReqNTIDx(F, reqntidx))
530     reqntidx = 1;
531   else
532     specified = true;
533   if (!getReqNTIDy(F, reqntidy))
534     reqntidy = 1;
535   else
536     specified = true;
537   if (!getReqNTIDz(F, reqntidz))
538     reqntidz = 1;
539   else
540     specified = true;
541 
542   if (specified)
543     O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
544       << "\n";
545 
546   // If the NVVM IR has some of maxntid* specified, then output
547   // the maxntid directive, and set the unspecified ones to 1.
548   // If none of maxntid* is specified, don't output maxntid directive.
549   unsigned maxntidx, maxntidy, maxntidz;
550   specified = false;
551   if (!getMaxNTIDx(F, maxntidx))
552     maxntidx = 1;
553   else
554     specified = true;
555   if (!getMaxNTIDy(F, maxntidy))
556     maxntidy = 1;
557   else
558     specified = true;
559   if (!getMaxNTIDz(F, maxntidz))
560     maxntidz = 1;
561   else
562     specified = true;
563 
564   if (specified)
565     O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
566       << "\n";
567 
568   unsigned mincta;
569   if (getMinCTASm(F, mincta))
570     O << ".minnctapersm " << mincta << "\n";
571 
572   unsigned maxnreg;
573   if (getMaxNReg(F, maxnreg))
574     O << ".maxnreg " << maxnreg << "\n";
575 }
576 
577 std::string
578 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
579   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
580 
581   std::string Name;
582   raw_string_ostream NameStr(Name);
583 
584   VRegRCMap::const_iterator I = VRegMapping.find(RC);
585   assert(I != VRegMapping.end() && "Bad register class");
586   const DenseMap<unsigned, unsigned> &RegMap = I->second;
587 
588   VRegMap::const_iterator VI = RegMap.find(Reg);
589   assert(VI != RegMap.end() && "Bad virtual register");
590   unsigned MappedVR = VI->second;
591 
592   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
593 
594   NameStr.flush();
595   return Name;
596 }
597 
598 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
599                                           raw_ostream &O) {
600   O << getVirtualRegisterName(vr);
601 }
602 
603 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
604   emitLinkageDirective(F, O);
605   if (isKernelFunction(*F))
606     O << ".entry ";
607   else
608     O << ".func ";
609   printReturnValStr(F, O);
610   getSymbol(F)->print(O, MAI);
611   O << "\n";
612   emitFunctionParamList(F, O);
613   O << ";\n";
614 }
615 
616 static bool usedInGlobalVarDef(const Constant *C) {
617   if (!C)
618     return false;
619 
620   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
621     return GV->getName() != "llvm.used";
622   }
623 
624   for (const User *U : C->users())
625     if (const Constant *C = dyn_cast<Constant>(U))
626       if (usedInGlobalVarDef(C))
627         return true;
628 
629   return false;
630 }
631 
632 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
633   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
634     if (othergv->getName() == "llvm.used")
635       return true;
636   }
637 
638   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
639     if (instr->getParent() && instr->getParent()->getParent()) {
640       const Function *curFunc = instr->getParent()->getParent();
641       if (oneFunc && (curFunc != oneFunc))
642         return false;
643       oneFunc = curFunc;
644       return true;
645     } else
646       return false;
647   }
648 
649   for (const User *UU : U->users())
650     if (!usedInOneFunc(UU, oneFunc))
651       return false;
652 
653   return true;
654 }
655 
656 /* Find out if a global variable can be demoted to local scope.
657  * Currently, this is valid for CUDA shared variables, which have local
658  * scope and global lifetime. So the conditions to check are :
659  * 1. Is the global variable in shared address space?
660  * 2. Does it have internal linkage?
661  * 3. Is the global variable referenced only in one function?
662  */
663 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
664   if (!gv->hasInternalLinkage())
665     return false;
666   PointerType *Pty = gv->getType();
667   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
668     return false;
669 
670   const Function *oneFunc = nullptr;
671 
672   bool flag = usedInOneFunc(gv, oneFunc);
673   if (!flag)
674     return false;
675   if (!oneFunc)
676     return false;
677   f = oneFunc;
678   return true;
679 }
680 
681 static bool useFuncSeen(const Constant *C,
682                         DenseMap<const Function *, bool> &seenMap) {
683   for (const User *U : C->users()) {
684     if (const Constant *cu = dyn_cast<Constant>(U)) {
685       if (useFuncSeen(cu, seenMap))
686         return true;
687     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
688       const BasicBlock *bb = I->getParent();
689       if (!bb)
690         continue;
691       const Function *caller = bb->getParent();
692       if (!caller)
693         continue;
694       if (seenMap.find(caller) != seenMap.end())
695         return true;
696     }
697   }
698   return false;
699 }
700 
701 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
702   DenseMap<const Function *, bool> seenMap;
703   for (const Function &F : M) {
704     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
705       emitDeclaration(&F, O);
706       continue;
707     }
708 
709     if (F.isDeclaration()) {
710       if (F.use_empty())
711         continue;
712       if (F.getIntrinsicID())
713         continue;
714       emitDeclaration(&F, O);
715       continue;
716     }
717     for (const User *U : F.users()) {
718       if (const Constant *C = dyn_cast<Constant>(U)) {
719         if (usedInGlobalVarDef(C)) {
720           // The use is in the initialization of a global variable
721           // that is a function pointer, so print a declaration
722           // for the original function
723           emitDeclaration(&F, O);
724           break;
725         }
726         // Emit a declaration of this function if the function that
727         // uses this constant expr has already been seen.
728         if (useFuncSeen(C, seenMap)) {
729           emitDeclaration(&F, O);
730           break;
731         }
732       }
733 
734       if (!isa<Instruction>(U))
735         continue;
736       const Instruction *instr = cast<Instruction>(U);
737       const BasicBlock *bb = instr->getParent();
738       if (!bb)
739         continue;
740       const Function *caller = bb->getParent();
741       if (!caller)
742         continue;
743 
744       // If a caller has already been seen, then the caller is
745       // appearing in the module before the callee. so print out
746       // a declaration for the callee.
747       if (seenMap.find(caller) != seenMap.end()) {
748         emitDeclaration(&F, O);
749         break;
750       }
751     }
752     seenMap[&F] = true;
753   }
754 }
755 
756 static bool isEmptyXXStructor(GlobalVariable *GV) {
757   if (!GV) return true;
758   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
759   if (!InitList) return true;  // Not an array; we don't know how to parse.
760   return InitList->getNumOperands() == 0;
761 }
762 
763 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
764   // Construct a default subtarget off of the TargetMachine defaults. The
765   // rest of NVPTX isn't friendly to change subtargets per function and
766   // so the default TargetMachine will have all of the options.
767   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
768   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
769   SmallString<128> Str1;
770   raw_svector_ostream OS1(Str1);
771 
772   // Emit header before any dwarf directives are emitted below.
773   emitHeader(M, OS1, *STI);
774   OutStreamer->emitRawText(OS1.str());
775 }
776 
777 bool NVPTXAsmPrinter::doInitialization(Module &M) {
778   if (M.alias_size()) {
779     report_fatal_error("Module has aliases, which NVPTX does not support.");
780     return true; // error
781   }
782   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
783     report_fatal_error(
784         "Module has a nontrivial global ctor, which NVPTX does not support.");
785     return true;  // error
786   }
787   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
788     report_fatal_error(
789         "Module has a nontrivial global dtor, which NVPTX does not support.");
790     return true;  // error
791   }
792 
793   // We need to call the parent's one explicitly.
794   bool Result = AsmPrinter::doInitialization(M);
795 
796   GlobalsEmitted = false;
797 
798   return Result;
799 }
800 
801 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
802   SmallString<128> Str2;
803   raw_svector_ostream OS2(Str2);
804 
805   emitDeclarations(M, OS2);
806 
807   // As ptxas does not support forward references of globals, we need to first
808   // sort the list of module-level globals in def-use order. We visit each
809   // global variable in order, and ensure that we emit it *after* its dependent
810   // globals. We use a little extra memory maintaining both a set and a list to
811   // have fast searches while maintaining a strict ordering.
812   SmallVector<const GlobalVariable *, 8> Globals;
813   DenseSet<const GlobalVariable *> GVVisited;
814   DenseSet<const GlobalVariable *> GVVisiting;
815 
816   // Visit each global variable, in order
817   for (const GlobalVariable &I : M.globals())
818     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
819 
820   assert(GVVisited.size() == M.getGlobalList().size() &&
821          "Missed a global variable");
822   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
823 
824   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
825   const NVPTXSubtarget &STI =
826       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
827 
828   // Print out module-level global variables in proper order
829   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
830     printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
831 
832   OS2 << '\n';
833 
834   OutStreamer->emitRawText(OS2.str());
835 }
836 
837 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
838                                  const NVPTXSubtarget &STI) {
839   O << "//\n";
840   O << "// Generated by LLVM NVPTX Back-End\n";
841   O << "//\n";
842   O << "\n";
843 
844   unsigned PTXVersion = STI.getPTXVersion();
845   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
846 
847   O << ".target ";
848   O << STI.getTargetName();
849 
850   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
851   if (NTM.getDrvInterface() == NVPTX::NVCL)
852     O << ", texmode_independent";
853 
854   bool HasFullDebugInfo = false;
855   for (DICompileUnit *CU : M.debug_compile_units()) {
856     switch(CU->getEmissionKind()) {
857     case DICompileUnit::NoDebug:
858     case DICompileUnit::DebugDirectivesOnly:
859       break;
860     case DICompileUnit::LineTablesOnly:
861     case DICompileUnit::FullDebug:
862       HasFullDebugInfo = true;
863       break;
864     }
865     if (HasFullDebugInfo)
866       break;
867   }
868   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
869     O << ", debug";
870 
871   O << "\n";
872 
873   O << ".address_size ";
874   if (NTM.is64Bit())
875     O << "64";
876   else
877     O << "32";
878   O << "\n";
879 
880   O << "\n";
881 }
882 
883 bool NVPTXAsmPrinter::doFinalization(Module &M) {
884   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
885 
886   // If we did not emit any functions, then the global declarations have not
887   // yet been emitted.
888   if (!GlobalsEmitted) {
889     emitGlobals(M);
890     GlobalsEmitted = true;
891   }
892 
893   // call doFinalization
894   bool ret = AsmPrinter::doFinalization(M);
895 
896   clearAnnotationCache(&M);
897 
898   if (auto *TS = static_cast<NVPTXTargetStreamer *>(
899           OutStreamer->getTargetStreamer())) {
900     // Close the last emitted section
901     if (HasDebugInfo) {
902       TS->closeLastSection();
903       // Emit empty .debug_loc section for better support of the empty files.
904       OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
905     }
906 
907     // Output last DWARF .file directives, if any.
908     TS->outputDwarfFileDirectives();
909   }
910 
911   return ret;
912 
913   //bool Result = AsmPrinter::doFinalization(M);
914   // Instead of calling the parents doFinalization, we may
915   // clone parents doFinalization and customize here.
916   // Currently, we if NVISA out the EmitGlobals() in
917   // parent's doFinalization, which is too intrusive.
918   //
919   // Same for the doInitialization.
920   //return Result;
921 }
922 
923 // This function emits appropriate linkage directives for
924 // functions and global variables.
925 //
926 // extern function declaration            -> .extern
927 // extern function definition             -> .visible
928 // external global variable with init     -> .visible
929 // external without init                  -> .extern
930 // appending                              -> not allowed, assert.
931 // for any linkage other than
932 // internal, private, linker_private,
933 // linker_private_weak, linker_private_weak_def_auto,
934 // we emit                                -> .weak.
935 
936 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
937                                            raw_ostream &O) {
938   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
939     if (V->hasExternalLinkage()) {
940       if (isa<GlobalVariable>(V)) {
941         const GlobalVariable *GVar = cast<GlobalVariable>(V);
942         if (GVar) {
943           if (GVar->hasInitializer())
944             O << ".visible ";
945           else
946             O << ".extern ";
947         }
948       } else if (V->isDeclaration())
949         O << ".extern ";
950       else
951         O << ".visible ";
952     } else if (V->hasAppendingLinkage()) {
953       std::string msg;
954       msg.append("Error: ");
955       msg.append("Symbol ");
956       if (V->hasName())
957         msg.append(std::string(V->getName()));
958       msg.append("has unsupported appending linkage type");
959       llvm_unreachable(msg.c_str());
960     } else if (!V->hasInternalLinkage() &&
961                !V->hasPrivateLinkage()) {
962       O << ".weak ";
963     }
964   }
965 }
966 
967 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
968                                          raw_ostream &O, bool processDemoted,
969                                          const NVPTXSubtarget &STI) {
970   // Skip meta data
971   if (GVar->hasSection()) {
972     if (GVar->getSection() == "llvm.metadata")
973       return;
974   }
975 
976   // Skip LLVM intrinsic global variables
977   if (GVar->getName().startswith("llvm.") ||
978       GVar->getName().startswith("nvvm."))
979     return;
980 
981   const DataLayout &DL = getDataLayout();
982 
983   // GlobalVariables are always constant pointers themselves.
984   PointerType *PTy = GVar->getType();
985   Type *ETy = GVar->getValueType();
986 
987   if (GVar->hasExternalLinkage()) {
988     if (GVar->hasInitializer())
989       O << ".visible ";
990     else
991       O << ".extern ";
992   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
993              GVar->hasAvailableExternallyLinkage() ||
994              GVar->hasCommonLinkage()) {
995     O << ".weak ";
996   }
997 
998   if (isTexture(*GVar)) {
999     O << ".global .texref " << getTextureName(*GVar) << ";\n";
1000     return;
1001   }
1002 
1003   if (isSurface(*GVar)) {
1004     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1005     return;
1006   }
1007 
1008   if (GVar->isDeclaration()) {
1009     // (extern) declarations, no definition or initializer
1010     // Currently the only known declaration is for an automatic __local
1011     // (.shared) promoted to global.
1012     emitPTXGlobalVariable(GVar, O, STI);
1013     O << ";\n";
1014     return;
1015   }
1016 
1017   if (isSampler(*GVar)) {
1018     O << ".global .samplerref " << getSamplerName(*GVar);
1019 
1020     const Constant *Initializer = nullptr;
1021     if (GVar->hasInitializer())
1022       Initializer = GVar->getInitializer();
1023     const ConstantInt *CI = nullptr;
1024     if (Initializer)
1025       CI = dyn_cast<ConstantInt>(Initializer);
1026     if (CI) {
1027       unsigned sample = CI->getZExtValue();
1028 
1029       O << " = { ";
1030 
1031       for (int i = 0,
1032                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1033            i < 3; i++) {
1034         O << "addr_mode_" << i << " = ";
1035         switch (addr) {
1036         case 0:
1037           O << "wrap";
1038           break;
1039         case 1:
1040           O << "clamp_to_border";
1041           break;
1042         case 2:
1043           O << "clamp_to_edge";
1044           break;
1045         case 3:
1046           O << "wrap";
1047           break;
1048         case 4:
1049           O << "mirror";
1050           break;
1051         }
1052         O << ", ";
1053       }
1054       O << "filter_mode = ";
1055       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1056       case 0:
1057         O << "nearest";
1058         break;
1059       case 1:
1060         O << "linear";
1061         break;
1062       case 2:
1063         llvm_unreachable("Anisotropic filtering is not supported");
1064       default:
1065         O << "nearest";
1066         break;
1067       }
1068       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1069         O << ", force_unnormalized_coords = 1";
1070       }
1071       O << " }";
1072     }
1073 
1074     O << ";\n";
1075     return;
1076   }
1077 
1078   if (GVar->hasPrivateLinkage()) {
1079     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1080       return;
1081 
1082     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1083     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1084       return;
1085     if (GVar->use_empty())
1086       return;
1087   }
1088 
1089   const Function *demotedFunc = nullptr;
1090   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1091     O << "// " << GVar->getName() << " has been demoted\n";
1092     if (localDecls.find(demotedFunc) != localDecls.end())
1093       localDecls[demotedFunc].push_back(GVar);
1094     else {
1095       std::vector<const GlobalVariable *> temp;
1096       temp.push_back(GVar);
1097       localDecls[demotedFunc] = temp;
1098     }
1099     return;
1100   }
1101 
1102   O << ".";
1103   emitPTXAddressSpace(PTy->getAddressSpace(), O);
1104 
1105   if (isManaged(*GVar)) {
1106     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1107       report_fatal_error(
1108           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1109     }
1110     O << " .attribute(.managed)";
1111   }
1112 
1113   if (MaybeAlign A = GVar->getAlign())
1114     O << " .align " << A->value();
1115   else
1116     O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1117 
1118   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1119       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1120     O << " .";
1121     // Special case: ABI requires that we use .u8 for predicates
1122     if (ETy->isIntegerTy(1))
1123       O << "u8";
1124     else
1125       O << getPTXFundamentalTypeStr(ETy, false);
1126     O << " ";
1127     getSymbol(GVar)->print(O, MAI);
1128 
1129     // Ptx allows variable initilization only for constant and global state
1130     // spaces.
1131     if (GVar->hasInitializer()) {
1132       if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1133           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1134         const Constant *Initializer = GVar->getInitializer();
1135         // 'undef' is treated as there is no value specified.
1136         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1137           O << " = ";
1138           printScalarConstant(Initializer, O);
1139         }
1140       } else {
1141         // The frontend adds zero-initializer to device and constant variables
1142         // that don't have an initial value, and UndefValue to shared
1143         // variables, so skip warning for this case.
1144         if (!GVar->getInitializer()->isNullValue() &&
1145             !isa<UndefValue>(GVar->getInitializer())) {
1146           report_fatal_error("initial value of '" + GVar->getName() +
1147                              "' is not allowed in addrspace(" +
1148                              Twine(PTy->getAddressSpace()) + ")");
1149         }
1150       }
1151     }
1152   } else {
1153     unsigned int ElementSize = 0;
1154 
1155     // Although PTX has direct support for struct type and array type and
1156     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1157     // targets that support these high level field accesses. Structs, arrays
1158     // and vectors are lowered into arrays of bytes.
1159     switch (ETy->getTypeID()) {
1160     case Type::IntegerTyID: // Integers larger than 64 bits
1161     case Type::StructTyID:
1162     case Type::ArrayTyID:
1163     case Type::FixedVectorTyID:
1164       ElementSize = DL.getTypeStoreSize(ETy);
1165       // Ptx allows variable initilization only for constant and
1166       // global state spaces.
1167       if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1168            (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1169           GVar->hasInitializer()) {
1170         const Constant *Initializer = GVar->getInitializer();
1171         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1172           AggBuffer aggBuffer(ElementSize, *this);
1173           bufferAggregateConstant(Initializer, &aggBuffer);
1174           if (aggBuffer.numSymbols()) {
1175             unsigned int ptrSize = MAI->getCodePointerSize();
1176             if (ElementSize % ptrSize ||
1177                 !aggBuffer.allSymbolsAligned(ptrSize)) {
1178               // Print in bytes and use the mask() operator for pointers.
1179               if (!STI.hasMaskOperator())
1180                 report_fatal_error(
1181                     "initialized packed aggregate with pointers '" +
1182                     GVar->getName() +
1183                     "' requires at least PTX ISA version 7.1");
1184               O << " .u8 ";
1185               getSymbol(GVar)->print(O, MAI);
1186               O << "[" << ElementSize << "] = {";
1187               aggBuffer.printBytes(O);
1188               O << "}";
1189             } else {
1190               O << " .u" << ptrSize * 8 << " ";
1191               getSymbol(GVar)->print(O, MAI);
1192               O << "[" << ElementSize / ptrSize << "] = {";
1193               aggBuffer.printWords(O);
1194               O << "}";
1195             }
1196           } else {
1197             O << " .b8 ";
1198             getSymbol(GVar)->print(O, MAI);
1199             O << "[" << ElementSize << "] = {";
1200             aggBuffer.printBytes(O);
1201             O << "}";
1202           }
1203         } else {
1204           O << " .b8 ";
1205           getSymbol(GVar)->print(O, MAI);
1206           if (ElementSize) {
1207             O << "[";
1208             O << ElementSize;
1209             O << "]";
1210           }
1211         }
1212       } else {
1213         O << " .b8 ";
1214         getSymbol(GVar)->print(O, MAI);
1215         if (ElementSize) {
1216           O << "[";
1217           O << ElementSize;
1218           O << "]";
1219         }
1220       }
1221       break;
1222     default:
1223       llvm_unreachable("type not supported yet");
1224     }
1225   }
1226   O << ";\n";
1227 }
1228 
1229 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1230   const Value *v = Symbols[nSym];
1231   const Value *v0 = SymbolsBeforeStripping[nSym];
1232   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1233     MCSymbol *Name = AP.getSymbol(GVar);
1234     PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1235     // Is v0 a generic pointer?
1236     bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1237     if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1238       os << "generic(";
1239       Name->print(os, AP.MAI);
1240       os << ")";
1241     } else {
1242       Name->print(os, AP.MAI);
1243     }
1244   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1245     const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1246     AP.printMCExpr(*Expr, os);
1247   } else
1248     llvm_unreachable("symbol type unknown");
1249 }
1250 
1251 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1252   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1253   symbolPosInBuffer.push_back(size);
1254   unsigned int nSym = 0;
1255   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1256   for (unsigned int pos = 0; pos < size;) {
1257     if (pos)
1258       os << ", ";
1259     if (pos != nextSymbolPos) {
1260       os << (unsigned int)buffer[pos];
1261       ++pos;
1262       continue;
1263     }
1264     // Generate a per-byte mask() operator for the symbol, which looks like:
1265     //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1266     // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1267     std::string symText;
1268     llvm::raw_string_ostream oss(symText);
1269     printSymbol(nSym, oss);
1270     for (unsigned i = 0; i < ptrSize; ++i) {
1271       if (i)
1272         os << ", ";
1273       llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1274       os << "(" << symText << ")";
1275     }
1276     pos += ptrSize;
1277     nextSymbolPos = symbolPosInBuffer[++nSym];
1278     assert(nextSymbolPos >= pos);
1279   }
1280 }
1281 
1282 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1283   unsigned int ptrSize = AP.MAI->getCodePointerSize();
1284   symbolPosInBuffer.push_back(size);
1285   unsigned int nSym = 0;
1286   unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1287   assert(nextSymbolPos % ptrSize == 0);
1288   for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1289     if (pos)
1290       os << ", ";
1291     if (pos == nextSymbolPos) {
1292       printSymbol(nSym, os);
1293       nextSymbolPos = symbolPosInBuffer[++nSym];
1294       assert(nextSymbolPos % ptrSize == 0);
1295       assert(nextSymbolPos >= pos + ptrSize);
1296     } else if (ptrSize == 4)
1297       os << support::endian::read32le(&buffer[pos]);
1298     else
1299       os << support::endian::read64le(&buffer[pos]);
1300   }
1301 }
1302 
1303 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1304   if (localDecls.find(f) == localDecls.end())
1305     return;
1306 
1307   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1308 
1309   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1310   const NVPTXSubtarget &STI =
1311       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1312 
1313   for (const GlobalVariable *GV : gvars) {
1314     O << "\t// demoted variable\n\t";
1315     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1316   }
1317 }
1318 
1319 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1320                                           raw_ostream &O) const {
1321   switch (AddressSpace) {
1322   case ADDRESS_SPACE_LOCAL:
1323     O << "local";
1324     break;
1325   case ADDRESS_SPACE_GLOBAL:
1326     O << "global";
1327     break;
1328   case ADDRESS_SPACE_CONST:
1329     O << "const";
1330     break;
1331   case ADDRESS_SPACE_SHARED:
1332     O << "shared";
1333     break;
1334   default:
1335     report_fatal_error("Bad address space found while emitting PTX: " +
1336                        llvm::Twine(AddressSpace));
1337     break;
1338   }
1339 }
1340 
1341 std::string
1342 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1343   switch (Ty->getTypeID()) {
1344   case Type::IntegerTyID: {
1345     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1346     if (NumBits == 1)
1347       return "pred";
1348     else if (NumBits <= 64) {
1349       std::string name = "u";
1350       return name + utostr(NumBits);
1351     } else {
1352       llvm_unreachable("Integer too large");
1353       break;
1354     }
1355     break;
1356   }
1357   case Type::HalfTyID:
1358     // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1359     return "b16";
1360   case Type::FloatTyID:
1361     return "f32";
1362   case Type::DoubleTyID:
1363     return "f64";
1364   case Type::PointerTyID:
1365     if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit())
1366       if (useB4PTR)
1367         return "b64";
1368       else
1369         return "u64";
1370     else if (useB4PTR)
1371       return "b32";
1372     else
1373       return "u32";
1374   default:
1375     break;
1376   }
1377   llvm_unreachable("unexpected type");
1378 }
1379 
1380 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1381                                             raw_ostream &O,
1382                                             const NVPTXSubtarget &STI) {
1383   const DataLayout &DL = getDataLayout();
1384 
1385   // GlobalVariables are always constant pointers themselves.
1386   Type *ETy = GVar->getValueType();
1387 
1388   O << ".";
1389   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1390   if (isManaged(*GVar)) {
1391     if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1392       report_fatal_error(
1393           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1394     }
1395     O << " .attribute(.managed)";
1396   }
1397   if (MaybeAlign A = GVar->getAlign())
1398     O << " .align " << A->value();
1399   else
1400     O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1401 
1402   // Special case for i128
1403   if (ETy->isIntegerTy(128)) {
1404     O << " .b8 ";
1405     getSymbol(GVar)->print(O, MAI);
1406     O << "[16]";
1407     return;
1408   }
1409 
1410   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1411     O << " .";
1412     O << getPTXFundamentalTypeStr(ETy);
1413     O << " ";
1414     getSymbol(GVar)->print(O, MAI);
1415     return;
1416   }
1417 
1418   int64_t ElementSize = 0;
1419 
1420   // Although PTX has direct support for struct type and array type and LLVM IR
1421   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1422   // support these high level field accesses. Structs and arrays are lowered
1423   // into arrays of bytes.
1424   switch (ETy->getTypeID()) {
1425   case Type::StructTyID:
1426   case Type::ArrayTyID:
1427   case Type::FixedVectorTyID:
1428     ElementSize = DL.getTypeStoreSize(ETy);
1429     O << " .b8 ";
1430     getSymbol(GVar)->print(O, MAI);
1431     O << "[";
1432     if (ElementSize) {
1433       O << ElementSize;
1434     }
1435     O << "]";
1436     break;
1437   default:
1438     llvm_unreachable("type not supported yet");
1439   }
1440 }
1441 
1442 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1443                                      int paramIndex, raw_ostream &O) {
1444   getSymbol(I->getParent())->print(O, MAI);
1445   O << "_param_" << paramIndex;
1446 }
1447 
1448 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1449   const DataLayout &DL = getDataLayout();
1450   const AttributeList &PAL = F->getAttributes();
1451   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1452   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1453 
1454   Function::const_arg_iterator I, E;
1455   unsigned paramIndex = 0;
1456   bool first = true;
1457   bool isKernelFunc = isKernelFunction(*F);
1458   bool isABI = (STI.getSmVersion() >= 20);
1459   bool hasImageHandles = STI.hasImageHandles();
1460   MVT thePointerTy = TLI->getPointerTy(DL);
1461 
1462   if (F->arg_empty()) {
1463     O << "()\n";
1464     return;
1465   }
1466 
1467   O << "(\n";
1468 
1469   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1470     Type *Ty = I->getType();
1471 
1472     if (!first)
1473       O << ",\n";
1474 
1475     first = false;
1476 
1477     // Handle image/sampler parameters
1478     if (isKernelFunction(*F)) {
1479       if (isSampler(*I) || isImage(*I)) {
1480         if (isImage(*I)) {
1481           std::string sname = std::string(I->getName());
1482           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1483             if (hasImageHandles)
1484               O << "\t.param .u64 .ptr .surfref ";
1485             else
1486               O << "\t.param .surfref ";
1487             CurrentFnSym->print(O, MAI);
1488             O << "_param_" << paramIndex;
1489           }
1490           else { // Default image is read_only
1491             if (hasImageHandles)
1492               O << "\t.param .u64 .ptr .texref ";
1493             else
1494               O << "\t.param .texref ";
1495             CurrentFnSym->print(O, MAI);
1496             O << "_param_" << paramIndex;
1497           }
1498         } else {
1499           if (hasImageHandles)
1500             O << "\t.param .u64 .ptr .samplerref ";
1501           else
1502             O << "\t.param .samplerref ";
1503           CurrentFnSym->print(O, MAI);
1504           O << "_param_" << paramIndex;
1505         }
1506         continue;
1507       }
1508     }
1509 
1510     auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1511                                     paramIndex](Type *Ty) -> Align {
1512       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1513       MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1514       return std::max(TypeAlign, ParamAlign.valueOrOne());
1515     };
1516 
1517     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1518       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1519         // Just print .param .align <a> .b8 .param[size];
1520         // <a>  = optimal alignment for the element type; always multiple of
1521         //        PAL.getParamAlignment
1522         // size = typeallocsize of element type
1523         Align OptimalAlign = getOptimalAlignForParam(Ty);
1524 
1525         O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1526         printParamName(I, paramIndex, O);
1527         O << "[" << DL.getTypeAllocSize(Ty) << "]";
1528 
1529         continue;
1530       }
1531       // Just a scalar
1532       auto *PTy = dyn_cast<PointerType>(Ty);
1533       if (isKernelFunc) {
1534         if (PTy) {
1535           // Special handling for pointer arguments to kernel
1536           O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";
1537 
1538           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1539               NVPTX::CUDA) {
1540             int addrSpace = PTy->getAddressSpace();
1541             switch (addrSpace) {
1542             default:
1543               O << ".ptr ";
1544               break;
1545             case ADDRESS_SPACE_CONST:
1546               O << ".ptr .const ";
1547               break;
1548             case ADDRESS_SPACE_SHARED:
1549               O << ".ptr .shared ";
1550               break;
1551             case ADDRESS_SPACE_GLOBAL:
1552               O << ".ptr .global ";
1553               break;
1554             }
1555             Align ParamAlign = I->getParamAlign().valueOrOne();
1556             O << ".align " << ParamAlign.value() << " ";
1557           }
1558           printParamName(I, paramIndex, O);
1559           continue;
1560         }
1561 
1562         // non-pointer scalar to kernel func
1563         O << "\t.param .";
1564         // Special case: predicate operands become .u8 types
1565         if (Ty->isIntegerTy(1))
1566           O << "u8";
1567         else
1568           O << getPTXFundamentalTypeStr(Ty);
1569         O << " ";
1570         printParamName(I, paramIndex, O);
1571         continue;
1572       }
1573       // Non-kernel function, just print .param .b<size> for ABI
1574       // and .reg .b<size> for non-ABI
1575       unsigned sz = 0;
1576       if (isa<IntegerType>(Ty)) {
1577         sz = cast<IntegerType>(Ty)->getBitWidth();
1578         sz = promoteScalarArgumentSize(sz);
1579       } else if (isa<PointerType>(Ty))
1580         sz = thePointerTy.getSizeInBits();
1581       else if (Ty->isHalfTy())
1582         // PTX ABI requires all scalar parameters to be at least 32
1583         // bits in size.  fp16 normally uses .b16 as its storage type
1584         // in PTX, so its size must be adjusted here, too.
1585         sz = 32;
1586       else
1587         sz = Ty->getPrimitiveSizeInBits();
1588       if (isABI)
1589         O << "\t.param .b" << sz << " ";
1590       else
1591         O << "\t.reg .b" << sz << " ";
1592       printParamName(I, paramIndex, O);
1593       continue;
1594     }
1595 
1596     // param has byVal attribute.
1597     Type *ETy = PAL.getParamByValType(paramIndex);
1598     assert(ETy && "Param should have byval type");
1599 
1600     if (isABI || isKernelFunc) {
1601       // Just print .param .align <a> .b8 .param[size];
1602       // <a>  = optimal alignment for the element type; always multiple of
1603       //        PAL.getParamAlignment
1604       // size = typeallocsize of element type
1605       Align OptimalAlign = getOptimalAlignForParam(ETy);
1606 
1607       // Work around a bug in ptxas. When PTX code takes address of
1608       // byval parameter with alignment < 4, ptxas generates code to
1609       // spill argument into memory. Alas on sm_50+ ptxas generates
1610       // SASS code that fails with misaligned access. To work around
1611       // the problem, make sure that we align byval parameters by at
1612       // least 4. Matching change must be made in LowerCall() where we
1613       // prepare parameters for the call.
1614       //
1615       // TODO: this will need to be undone when we get to support multi-TU
1616       // device-side compilation as it breaks ABI compatibility with nvcc.
1617       // Hopefully ptxas bug is fixed by then.
1618       if (!isKernelFunc && OptimalAlign < Align(4))
1619         OptimalAlign = Align(4);
1620       unsigned sz = DL.getTypeAllocSize(ETy);
1621       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1622       printParamName(I, paramIndex, O);
1623       O << "[" << sz << "]";
1624       continue;
1625     } else {
1626       // Split the ETy into constituent parts and
1627       // print .param .b<size> <name> for each part.
1628       // Further, if a part is vector, print the above for
1629       // each vector element.
1630       SmallVector<EVT, 16> vtparts;
1631       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1632       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1633         unsigned elems = 1;
1634         EVT elemtype = vtparts[i];
1635         if (vtparts[i].isVector()) {
1636           elems = vtparts[i].getVectorNumElements();
1637           elemtype = vtparts[i].getVectorElementType();
1638         }
1639 
1640         for (unsigned j = 0, je = elems; j != je; ++j) {
1641           unsigned sz = elemtype.getSizeInBits();
1642           if (elemtype.isInteger())
1643             sz = promoteScalarArgumentSize(sz);
1644           O << "\t.reg .b" << sz << " ";
1645           printParamName(I, paramIndex, O);
1646           if (j < je - 1)
1647             O << ",\n";
1648           ++paramIndex;
1649         }
1650         if (i < e - 1)
1651           O << ",\n";
1652       }
1653       --paramIndex;
1654       continue;
1655     }
1656   }
1657 
1658   O << "\n)\n";
1659 }
1660 
1661 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1662                                             raw_ostream &O) {
1663   const Function &F = MF.getFunction();
1664   emitFunctionParamList(&F, O);
1665 }
1666 
1667 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1668     const MachineFunction &MF) {
1669   SmallString<128> Str;
1670   raw_svector_ostream O(Str);
1671 
1672   // Map the global virtual register number to a register class specific
1673   // virtual register number starting from 1 with that class.
1674   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1675   //unsigned numRegClasses = TRI->getNumRegClasses();
1676 
1677   // Emit the Fake Stack Object
1678   const MachineFrameInfo &MFI = MF.getFrameInfo();
1679   int NumBytes = (int) MFI.getStackSize();
1680   if (NumBytes) {
1681     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1682       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1683     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1684       O << "\t.reg .b64 \t%SP;\n";
1685       O << "\t.reg .b64 \t%SPL;\n";
1686     } else {
1687       O << "\t.reg .b32 \t%SP;\n";
1688       O << "\t.reg .b32 \t%SPL;\n";
1689     }
1690   }
1691 
1692   // Go through all virtual registers to establish the mapping between the
1693   // global virtual
1694   // register number and the per class virtual register number.
1695   // We use the per class virtual register number in the ptx output.
1696   unsigned int numVRs = MRI->getNumVirtRegs();
1697   for (unsigned i = 0; i < numVRs; i++) {
1698     Register vr = Register::index2VirtReg(i);
1699     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1700     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1701     int n = regmap.size();
1702     regmap.insert(std::make_pair(vr, n + 1));
1703   }
1704 
1705   // Emit register declarations
1706   // @TODO: Extract out the real register usage
1707   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1708   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1709   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1710   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1711   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1712   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1713   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1714 
1715   // Emit declaration of the virtual registers or 'physical' registers for
1716   // each register class
1717   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1718     const TargetRegisterClass *RC = TRI->getRegClass(i);
1719     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1720     std::string rcname = getNVPTXRegClassName(RC);
1721     std::string rcStr = getNVPTXRegClassStr(RC);
1722     int n = regmap.size();
1723 
1724     // Only declare those registers that may be used.
1725     if (n) {
1726        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1727          << ">;\n";
1728     }
1729   }
1730 
1731   OutStreamer->emitRawText(O.str());
1732 }
1733 
1734 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1735   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1736   bool ignored;
1737   unsigned int numHex;
1738   const char *lead;
1739 
1740   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1741     numHex = 8;
1742     lead = "0f";
1743     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1744   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1745     numHex = 16;
1746     lead = "0d";
1747     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1748   } else
1749     llvm_unreachable("unsupported fp type");
1750 
1751   APInt API = APF.bitcastToAPInt();
1752   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1753 }
1754 
1755 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1756   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1757     O << CI->getValue();
1758     return;
1759   }
1760   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1761     printFPConstant(CFP, O);
1762     return;
1763   }
1764   if (isa<ConstantPointerNull>(CPV)) {
1765     O << "0";
1766     return;
1767   }
1768   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1769     bool IsNonGenericPointer = false;
1770     if (GVar->getType()->getAddressSpace() != 0) {
1771       IsNonGenericPointer = true;
1772     }
1773     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1774       O << "generic(";
1775       getSymbol(GVar)->print(O, MAI);
1776       O << ")";
1777     } else {
1778       getSymbol(GVar)->print(O, MAI);
1779     }
1780     return;
1781   }
1782   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1783     const Value *v = Cexpr->stripPointerCasts();
1784     PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType());
1785     bool IsNonGenericPointer = false;
1786     if (PTy && PTy->getAddressSpace() != 0) {
1787       IsNonGenericPointer = true;
1788     }
1789     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1790       if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
1791         O << "generic(";
1792         getSymbol(GVar)->print(O, MAI);
1793         O << ")";
1794       } else {
1795         getSymbol(GVar)->print(O, MAI);
1796       }
1797       return;
1798     } else {
1799       lowerConstant(CPV)->print(O, MAI);
1800       return;
1801     }
1802   }
1803   llvm_unreachable("Not scalar type found in printScalarConstant()");
1804 }
1805 
1806 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1807                                    AggBuffer *AggBuffer) {
1808   const DataLayout &DL = getDataLayout();
1809   int AllocSize = DL.getTypeAllocSize(CPV->getType());
1810   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1811     // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1812     // only the space allocated by CPV.
1813     AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1814     return;
1815   }
1816 
1817   // Helper for filling AggBuffer with APInts.
1818   auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1819     size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1820     SmallVector<unsigned char, 16> Buf(NumBytes);
1821     for (unsigned I = 0; I < NumBytes; ++I) {
1822       Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1823     }
1824     AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1825   };
1826 
1827   switch (CPV->getType()->getTypeID()) {
1828   case Type::IntegerTyID:
1829     if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1830       AddIntToBuffer(CI->getValue());
1831       break;
1832     }
1833     if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1834       if (const auto *CI =
1835               dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1836         AddIntToBuffer(CI->getValue());
1837         break;
1838       }
1839       if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1840         Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1841         AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1842         AggBuffer->addZeros(AllocSize);
1843         break;
1844       }
1845     }
1846     llvm_unreachable("unsupported integer const type");
1847     break;
1848 
1849   case Type::HalfTyID:
1850   case Type::FloatTyID:
1851   case Type::DoubleTyID:
1852     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1853     break;
1854 
1855   case Type::PointerTyID: {
1856     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1857       AggBuffer->addSymbol(GVar, GVar);
1858     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1859       const Value *v = Cexpr->stripPointerCasts();
1860       AggBuffer->addSymbol(v, Cexpr);
1861     }
1862     AggBuffer->addZeros(AllocSize);
1863     break;
1864   }
1865 
1866   case Type::ArrayTyID:
1867   case Type::FixedVectorTyID:
1868   case Type::StructTyID: {
1869     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1870       bufferAggregateConstant(CPV, AggBuffer);
1871       if (Bytes > AllocSize)
1872         AggBuffer->addZeros(Bytes - AllocSize);
1873     } else if (isa<ConstantAggregateZero>(CPV))
1874       AggBuffer->addZeros(Bytes);
1875     else
1876       llvm_unreachable("Unexpected Constant type");
1877     break;
1878   }
1879 
1880   default:
1881     llvm_unreachable("unsupported type");
1882   }
1883 }
1884 
1885 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1886                                               AggBuffer *aggBuffer) {
1887   const DataLayout &DL = getDataLayout();
1888   int Bytes;
1889 
1890   // Integers of arbitrary width
1891   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1892     APInt Val = CI->getValue();
1893     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1894       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1895       aggBuffer->addBytes(&Byte, 1, 1);
1896       Val.lshrInPlace(8);
1897     }
1898     return;
1899   }
1900 
1901   // Old constants
1902   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1903     if (CPV->getNumOperands())
1904       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1905         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1906     return;
1907   }
1908 
1909   if (const ConstantDataSequential *CDS =
1910           dyn_cast<ConstantDataSequential>(CPV)) {
1911     if (CDS->getNumElements())
1912       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1913         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1914                      aggBuffer);
1915     return;
1916   }
1917 
1918   if (isa<ConstantStruct>(CPV)) {
1919     if (CPV->getNumOperands()) {
1920       StructType *ST = cast<StructType>(CPV->getType());
1921       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1922         if (i == (e - 1))
1923           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1924                   DL.getTypeAllocSize(ST) -
1925                   DL.getStructLayout(ST)->getElementOffset(i);
1926         else
1927           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1928                   DL.getStructLayout(ST)->getElementOffset(i);
1929         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1930       }
1931     }
1932     return;
1933   }
1934   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1935 }
1936 
/// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
///
/// \param CV                the constant to lower.
/// \param ProcessingGeneric when true, global-value symbols are wrapped in an
///                          NVPTXGenericMCSymbolRefExpr. It becomes true once
///                          an addrspacecast to address space 0 has been
///                          stripped, and is propagated to all sub-operands.
/// Expressions that cannot be represented are reported via report_fatal_error.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
  MCContext &Ctx = OutContext;

  // Null and undef both lower to the literal 0.
  if (CV->isNullValue() || isa<UndefValue>(CV))
    return MCConstantExpr::create(0, Ctx);

  // NOTE(review): getZExtValue() requires the value to fit in 64 bits;
  // presumably wider integers never reach this path — confirm.
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
    return MCConstantExpr::create(CI->getZExtValue(), Ctx);

  // A global is referenced by its symbol, wrapped as generic() if requested.
  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
    const MCSymbolRefExpr *Expr =
      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
    if (ProcessingGeneric) {
      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
    } else {
      return Expr;
    }
  }

  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default: {
    // If the code isn't optimized, there may be outstanding folding
    // opportunities. Attempt to fold the expression using DataLayout as a
    // last resort before giving up.
    Constant *C = ConstantFoldConstant(CE, getDataLayout());
    if (C != CE)
      return lowerConstantForGV(C, ProcessingGeneric);

    // Otherwise report the problem to the user.
    std::string S;
    raw_string_ostream OS(S);
    OS << "Unsupported expression in static initializer: ";
    CE->printAsOperand(OS, /*PrintType=*/false,
                   !MF ? nullptr : MF->getFunction().getParent());
    report_fatal_error(Twine(OS.str()));
  }

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand
    // Only casts to address space 0 are supported; everything beneath such a
    // cast is lowered with ProcessingGeneric = true. Casts to any other
    // address space are a fatal error.
    PointerType *DstTy = cast<PointerType>(CE->getType());
    if (DstTy->getAddressSpace() == 0) {
      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
    }
    std::string S;
    raw_string_ostream OS(S);
    OS << "Unsupported expression in static initializer: ";
    CE->printAsOperand(OS, /*PrintType=*/ false,
                       !MF ? nullptr : MF->getFunction().getParent());
    report_fatal_error(Twine(OS.str()));
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    // NOTE(review): the bool result of accumulateConstantOffset is ignored; if
    // it fails (e.g. an index that is not a ConstantInt) OffsetAI may not hold
    // the full offset — confirm such GEPs cannot reach here.
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
                                            ProcessingGeneric);
    // A zero offset lowers to just the base expression.
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly.  This is important for differences between
    // blockaddress labels.  Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    // (The blockaddress rationale is inherited from AsmPrinter::lowerConstant.)
    LLVM_FALLTHROUGH;
  case Instruction::BitCast:
    // Bitcasts (and truncs, see above) are transparent: lower the operand.
    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type.  This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(0);
    Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
                                      false/*ZExt*/);
    return lowerConstantForGV(Op, ProcessingGeneric);
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    // NOTE(review): this builds an And expression, but printMCExpr in this
    // file only handles MCBinaryExpr::Add, so printing the result through
    // printMCExpr would hit "Unhandled binary operator" — confirm whether this
    // path is reachable.
    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    // Only Add is supported among binary operators; both operands are lowered
    // with the same ProcessingGeneric setting.
    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }
}
2070 
2071 // Copy of MCExpr::print customized for NVPTX
2072 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2073   switch (Expr.getKind()) {
2074   case MCExpr::Target:
2075     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2076   case MCExpr::Constant:
2077     OS << cast<MCConstantExpr>(Expr).getValue();
2078     return;
2079 
2080   case MCExpr::SymbolRef: {
2081     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2082     const MCSymbol &Sym = SRE.getSymbol();
2083     Sym.print(OS, MAI);
2084     return;
2085   }
2086 
2087   case MCExpr::Unary: {
2088     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2089     switch (UE.getOpcode()) {
2090     case MCUnaryExpr::LNot:  OS << '!'; break;
2091     case MCUnaryExpr::Minus: OS << '-'; break;
2092     case MCUnaryExpr::Not:   OS << '~'; break;
2093     case MCUnaryExpr::Plus:  OS << '+'; break;
2094     }
2095     printMCExpr(*UE.getSubExpr(), OS);
2096     return;
2097   }
2098 
2099   case MCExpr::Binary: {
2100     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2101 
2102     // Only print parens around the LHS if it is non-trivial.
2103     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2104         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2105       printMCExpr(*BE.getLHS(), OS);
2106     } else {
2107       OS << '(';
2108       printMCExpr(*BE.getLHS(), OS);
2109       OS<< ')';
2110     }
2111 
2112     switch (BE.getOpcode()) {
2113     case MCBinaryExpr::Add:
2114       // Print "X-42" instead of "X+-42".
2115       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2116         if (RHSC->getValue() < 0) {
2117           OS << RHSC->getValue();
2118           return;
2119         }
2120       }
2121 
2122       OS <<  '+';
2123       break;
2124     default: llvm_unreachable("Unhandled binary operator");
2125     }
2126 
2127     // Only print parens around the LHS if it is non-trivial.
2128     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2129       printMCExpr(*BE.getRHS(), OS);
2130     } else {
2131       OS << '(';
2132       printMCExpr(*BE.getRHS(), OS);
2133       OS << ')';
2134     }
2135     return;
2136   }
2137   }
2138 
2139   llvm_unreachable("Invalid expression kind!");
2140 }
2141 
2142 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2143 ///
2144 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2145                                       const char *ExtraCode, raw_ostream &O) {
2146   if (ExtraCode && ExtraCode[0]) {
2147     if (ExtraCode[1] != 0)
2148       return true; // Unknown modifier.
2149 
2150     switch (ExtraCode[0]) {
2151     default:
2152       // See if this is a generic print operand
2153       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2154     case 'r':
2155       break;
2156     }
2157   }
2158 
2159   printOperand(MI, OpNo, O);
2160 
2161   return false;
2162 }
2163 
2164 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2165                                             unsigned OpNo,
2166                                             const char *ExtraCode,
2167                                             raw_ostream &O) {
2168   if (ExtraCode && ExtraCode[0])
2169     return true; // Unknown modifier
2170 
2171   O << '[';
2172   printMemOperand(MI, OpNo, O);
2173   O << ']';
2174 
2175   return false;
2176 }
2177 
2178 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2179                                    raw_ostream &O) {
2180   const MachineOperand &MO = MI->getOperand(opNum);
2181   switch (MO.getType()) {
2182   case MachineOperand::MO_Register:
2183     if (Register::isPhysicalRegister(MO.getReg())) {
2184       if (MO.getReg() == NVPTX::VRDepot)
2185         O << DEPOTNAME << getFunctionNumber();
2186       else
2187         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2188     } else {
2189       emitVirtualRegister(MO.getReg(), O);
2190     }
2191     break;
2192 
2193   case MachineOperand::MO_Immediate:
2194     O << MO.getImm();
2195     break;
2196 
2197   case MachineOperand::MO_FPImmediate:
2198     printFPConstant(MO.getFPImm(), O);
2199     break;
2200 
2201   case MachineOperand::MO_GlobalAddress:
2202     PrintSymbolOperand(MO, O);
2203     break;
2204 
2205   case MachineOperand::MO_MachineBasicBlock:
2206     MO.getMBB()->getSymbol()->print(O, MAI);
2207     break;
2208 
2209   default:
2210     llvm_unreachable("Operand type not supported.");
2211   }
2212 }
2213 
2214 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2215                                       raw_ostream &O, const char *Modifier) {
2216   printOperand(MI, opNum, O);
2217 
2218   if (Modifier && strcmp(Modifier, "add") == 0) {
2219     O << ", ";
2220     printOperand(MI, opNum + 1, O);
2221   } else {
2222     if (MI->getOperand(opNum + 1).isImm() &&
2223         MI->getOperand(opNum + 1).getImm() == 0)
2224       return; // don't print ',0' or '+0'
2225     O << "+";
2226     printOperand(MI, opNum + 1, O);
2227   }
2228 }
2229 
2230 // Force static initialization.
2231 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2232   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2233   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2234 }
2235