xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (revision 90b5fc95832da64a5f56295e687379732c33718f)
1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Triple.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Analysis/ConstantFolding.h"
39 #include "llvm/CodeGen/Analysis.h"
40 #include "llvm/CodeGen/MachineBasicBlock.h"
41 #include "llvm/CodeGen/MachineFrameInfo.h"
42 #include "llvm/CodeGen/MachineFunction.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineLoopInfo.h"
45 #include "llvm/CodeGen/MachineModuleInfo.h"
46 #include "llvm/CodeGen/MachineOperand.h"
47 #include "llvm/CodeGen/MachineRegisterInfo.h"
48 #include "llvm/CodeGen/TargetLowering.h"
49 #include "llvm/CodeGen/TargetRegisterInfo.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/Attributes.h"
52 #include "llvm/IR/BasicBlock.h"
53 #include "llvm/IR/Constant.h"
54 #include "llvm/IR/Constants.h"
55 #include "llvm/IR/DataLayout.h"
56 #include "llvm/IR/DebugInfo.h"
57 #include "llvm/IR/DebugInfoMetadata.h"
58 #include "llvm/IR/DebugLoc.h"
59 #include "llvm/IR/DerivedTypes.h"
60 #include "llvm/IR/Function.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/ErrorHandling.h"
77 #include "llvm/Support/MachineValueType.h"
78 #include "llvm/Support/Path.h"
79 #include "llvm/Support/TargetRegistry.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/Target/TargetLoweringObjectFile.h"
82 #include "llvm/Target/TargetMachine.h"
83 #include "llvm/Transforms/Utils/UnrollLoop.h"
84 #include <cassert>
85 #include <cstdint>
86 #include <cstring>
87 #include <new>
88 #include <string>
89 #include <utility>
90 #include <vector>
91 
92 using namespace llvm;
93 
94 #define DEPOTNAME "__local_depot"
95 
96 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
97 /// depends.
98 static void
99 DiscoverDependentGlobals(const Value *V,
100                          DenseSet<const GlobalVariable *> &Globals) {
101   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
102     Globals.insert(GV);
103   else {
104     if (const User *U = dyn_cast<User>(V)) {
105       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
106         DiscoverDependentGlobals(U->getOperand(i), Globals);
107       }
108     }
109   }
110 }
111 
112 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
113 /// instances to be emitted, but only after any dependents have been added
114 /// first.s
115 static void
116 VisitGlobalVariableForEmission(const GlobalVariable *GV,
117                                SmallVectorImpl<const GlobalVariable *> &Order,
118                                DenseSet<const GlobalVariable *> &Visited,
119                                DenseSet<const GlobalVariable *> &Visiting) {
120   // Have we already visited this one?
121   if (Visited.count(GV))
122     return;
123 
124   // Do we have a circular dependency?
125   if (!Visiting.insert(GV).second)
126     report_fatal_error("Circular dependency found in global variable set");
127 
128   // Make sure we visit all dependents first
129   DenseSet<const GlobalVariable *> Others;
130   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
131     DiscoverDependentGlobals(GV->getOperand(i), Others);
132 
133   for (DenseSet<const GlobalVariable *>::iterator I = Others.begin(),
134                                                   E = Others.end();
135        I != E; ++I)
136     VisitGlobalVariableForEmission(*I, Order, Visited, Visiting);
137 
138   // Now we can visit ourself
139   Order.push_back(GV);
140   Visited.insert(GV);
141   Visiting.erase(GV);
142 }
143 
144 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
145   MCInst Inst;
146   lowerToMCInst(MI, Inst);
147   EmitToStreamer(*OutStreamer, Inst);
148 }
149 
150 // Handle symbol backtracking for targets that do not support image handles
151 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
152                                            unsigned OpNo, MCOperand &MCOp) {
153   const MachineOperand &MO = MI->getOperand(OpNo);
154   const MCInstrDesc &MCID = MI->getDesc();
155 
156   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
157     // This is a texture fetch, so operand 4 is a texref and operand 5 is
158     // a samplerref
159     if (OpNo == 4 && MO.isImm()) {
160       lowerImageHandleSymbol(MO.getImm(), MCOp);
161       return true;
162     }
163     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
164       lowerImageHandleSymbol(MO.getImm(), MCOp);
165       return true;
166     }
167 
168     return false;
169   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
170     unsigned VecSize =
171       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
172 
173     // For a surface load of vector size N, the Nth operand will be the surfref
174     if (OpNo == VecSize && MO.isImm()) {
175       lowerImageHandleSymbol(MO.getImm(), MCOp);
176       return true;
177     }
178 
179     return false;
180   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
181     // This is a surface store, so operand 0 is a surfref
182     if (OpNo == 0 && MO.isImm()) {
183       lowerImageHandleSymbol(MO.getImm(), MCOp);
184       return true;
185     }
186 
187     return false;
188   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
189     // This is a query, so operand 1 is a surfref/texref
190     if (OpNo == 1 && MO.isImm()) {
191       lowerImageHandleSymbol(MO.getImm(), MCOp);
192       return true;
193     }
194 
195     return false;
196   }
197 
198   return false;
199 }
200 
201 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
202   // Ewwww
203   LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
204   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
205   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
206   const char *Sym = MFI->getImageHandleSymbol(Index);
207   std::string *SymNamePtr =
208     nvTM.getManagedStrPool()->getManagedString(Sym);
209   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr)));
210 }
211 
212 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
213   OutMI.setOpcode(MI->getOpcode());
214   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
215   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
216     const MachineOperand &MO = MI->getOperand(0);
217     OutMI.addOperand(GetSymbolRef(
218       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
219     return;
220   }
221 
222   const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
223   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
224     const MachineOperand &MO = MI->getOperand(i);
225 
226     MCOperand MCOp;
227     if (!STI.hasImageHandles()) {
228       if (lowerImageHandleOperand(MI, i, MCOp)) {
229         OutMI.addOperand(MCOp);
230         continue;
231       }
232     }
233 
234     if (lowerOperand(MO, MCOp))
235       OutMI.addOperand(MCOp);
236   }
237 }
238 
239 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
240                                    MCOperand &MCOp) {
241   switch (MO.getType()) {
242   default: llvm_unreachable("unknown operand type");
243   case MachineOperand::MO_Register:
244     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
245     break;
246   case MachineOperand::MO_Immediate:
247     MCOp = MCOperand::createImm(MO.getImm());
248     break;
249   case MachineOperand::MO_MachineBasicBlock:
250     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
251         MO.getMBB()->getSymbol(), OutContext));
252     break;
253   case MachineOperand::MO_ExternalSymbol:
254     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
255     break;
256   case MachineOperand::MO_GlobalAddress:
257     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
258     break;
259   case MachineOperand::MO_FPImmediate: {
260     const ConstantFP *Cnt = MO.getFPImm();
261     const APFloat &Val = Cnt->getValueAPF();
262 
263     switch (Cnt->getType()->getTypeID()) {
264     default: report_fatal_error("Unsupported FP type"); break;
265     case Type::HalfTyID:
266       MCOp = MCOperand::createExpr(
267         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
268       break;
269     case Type::FloatTyID:
270       MCOp = MCOperand::createExpr(
271         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
272       break;
273     case Type::DoubleTyID:
274       MCOp = MCOperand::createExpr(
275         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
276       break;
277     }
278     break;
279   }
280   }
281   return true;
282 }
283 
284 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
285   if (Register::isVirtualRegister(Reg)) {
286     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
287 
288     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
289     unsigned RegNum = RegMap[Reg];
290 
291     // Encode the register class in the upper 4 bits
292     // Must be kept in sync with NVPTXInstPrinter::printRegName
293     unsigned Ret = 0;
294     if (RC == &NVPTX::Int1RegsRegClass) {
295       Ret = (1 << 28);
296     } else if (RC == &NVPTX::Int16RegsRegClass) {
297       Ret = (2 << 28);
298     } else if (RC == &NVPTX::Int32RegsRegClass) {
299       Ret = (3 << 28);
300     } else if (RC == &NVPTX::Int64RegsRegClass) {
301       Ret = (4 << 28);
302     } else if (RC == &NVPTX::Float32RegsRegClass) {
303       Ret = (5 << 28);
304     } else if (RC == &NVPTX::Float64RegsRegClass) {
305       Ret = (6 << 28);
306     } else if (RC == &NVPTX::Float16RegsRegClass) {
307       Ret = (7 << 28);
308     } else if (RC == &NVPTX::Float16x2RegsRegClass) {
309       Ret = (8 << 28);
310     } else {
311       report_fatal_error("Bad register class");
312     }
313 
314     // Insert the vreg number
315     Ret |= (RegNum & 0x0FFFFFFF);
316     return Ret;
317   } else {
318     // Some special-use registers are actually physical registers.
319     // Encode this as the register class ID of 0 and the real register ID.
320     return Reg & 0x0FFFFFFF;
321   }
322 }
323 
324 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
325   const MCExpr *Expr;
326   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
327                                  OutContext);
328   return MCOperand::createExpr(Expr);
329 }
330 
331 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
332   const DataLayout &DL = getDataLayout();
333   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
334   const TargetLowering *TLI = STI.getTargetLowering();
335 
336   Type *Ty = F->getReturnType();
337 
338   bool isABI = (STI.getSmVersion() >= 20);
339 
340   if (Ty->getTypeID() == Type::VoidTyID)
341     return;
342 
343   O << " (";
344 
345   if (isABI) {
346     if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
347       unsigned size = 0;
348       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
349         size = ITy->getBitWidth();
350       } else {
351         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
352         size = Ty->getPrimitiveSizeInBits();
353       }
354       // PTX ABI requires all scalar return values to be at least 32
355       // bits in size.  fp16 normally uses .b16 as its storage type in
356       // PTX, so its size must be adjusted here, too.
357       if (size < 32)
358         size = 32;
359 
360       O << ".param .b" << size << " func_retval0";
361     } else if (isa<PointerType>(Ty)) {
362       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
363         << " func_retval0";
364     } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
365       unsigned totalsz = DL.getTypeAllocSize(Ty);
366       unsigned retAlignment = 0;
367       if (!getAlign(*F, 0, retAlignment))
368         retAlignment = DL.getABITypeAlignment(Ty);
369       O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
370         << "]";
371     } else
372       llvm_unreachable("Unknown return type");
373   } else {
374     SmallVector<EVT, 16> vtparts;
375     ComputeValueVTs(*TLI, DL, Ty, vtparts);
376     unsigned idx = 0;
377     for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
378       unsigned elems = 1;
379       EVT elemtype = vtparts[i];
380       if (vtparts[i].isVector()) {
381         elems = vtparts[i].getVectorNumElements();
382         elemtype = vtparts[i].getVectorElementType();
383       }
384 
385       for (unsigned j = 0, je = elems; j != je; ++j) {
386         unsigned sz = elemtype.getSizeInBits();
387         if (elemtype.isInteger() && (sz < 32))
388           sz = 32;
389         O << ".reg .b" << sz << " func_retval" << idx;
390         if (j < je - 1)
391           O << ", ";
392         ++idx;
393       }
394       if (i < e - 1)
395         O << ", ";
396     }
397   }
398   O << ") ";
399 }
400 
401 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
402                                         raw_ostream &O) {
403   const Function &F = MF.getFunction();
404   printReturnValStr(&F, O);
405 }
406 
407 // Return true if MBB is the header of a loop marked with
408 // llvm.loop.unroll.disable.
409 // TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll".
410 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
411     const MachineBasicBlock &MBB) const {
412   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
413   // We insert .pragma "nounroll" only to the loop header.
414   if (!LI.isLoopHeader(&MBB))
415     return false;
416 
417   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
418   // we iterate through each back edge of the loop with header MBB, and check
419   // whether its metadata contains llvm.loop.unroll.disable.
420   for (auto I = MBB.pred_begin(); I != MBB.pred_end(); ++I) {
421     const MachineBasicBlock *PMBB = *I;
422     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
423       // Edges from other loops to MBB are not back edges.
424       continue;
425     }
426     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
427       if (MDNode *LoopID =
428               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
429         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
430           return true;
431       }
432     }
433   }
434   return false;
435 }
436 
437 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
438   AsmPrinter::emitBasicBlockStart(MBB);
439   if (isLoopHeaderOfNoUnroll(MBB))
440     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
441 }
442 
443 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
444   SmallString<128> Str;
445   raw_svector_ostream O(Str);
446 
447   if (!GlobalsEmitted) {
448     emitGlobals(*MF->getFunction().getParent());
449     GlobalsEmitted = true;
450   }
451 
452   // Set up
453   MRI = &MF->getRegInfo();
454   F = &MF->getFunction();
455   emitLinkageDirective(F, O);
456   if (isKernelFunction(*F))
457     O << ".entry ";
458   else {
459     O << ".func ";
460     printReturnValStr(*MF, O);
461   }
462 
463   CurrentFnSym->print(O, MAI);
464 
465   emitFunctionParamList(*MF, O);
466 
467   if (isKernelFunction(*F))
468     emitKernelFunctionDirectives(*F, O);
469 
470   OutStreamer->emitRawText(O.str());
471 
472   VRegMapping.clear();
473   // Emit open brace for function body.
474   OutStreamer->emitRawText(StringRef("{\n"));
475   setAndEmitFunctionVirtualRegisters(*MF);
476   // Emit initial .loc debug directive for correct relocation symbol data.
477   if (MMI && MMI->hasDebugInfo())
478     emitInitialRawDwarfLocDirective(*MF);
479 }
480 
481 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
482   bool Result = AsmPrinter::runOnMachineFunction(F);
483   // Emit closing brace for the body of function F.
484   // The closing brace must be emitted here because we need to emit additional
485   // debug labels/data after the last basic block.
486   // We need to emit the closing brace here because we don't have function that
487   // finished emission of the function body.
488   OutStreamer->emitRawText(StringRef("}\n"));
489   return Result;
490 }
491 
492 void NVPTXAsmPrinter::emitFunctionBodyStart() {
493   SmallString<128> Str;
494   raw_svector_ostream O(Str);
495   emitDemotedVars(&MF->getFunction(), O);
496   OutStreamer->emitRawText(O.str());
497 }
498 
499 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
500   VRegMapping.clear();
501 }
502 
503 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
504     SmallString<128> Str;
505     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
506     return OutContext.getOrCreateSymbol(Str);
507 }
508 
509 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
510   Register RegNo = MI->getOperand(0).getReg();
511   if (Register::isVirtualRegister(RegNo)) {
512     OutStreamer->AddComment(Twine("implicit-def: ") +
513                             getVirtualRegisterName(RegNo));
514   } else {
515     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
516     OutStreamer->AddComment(Twine("implicit-def: ") +
517                             STI.getRegisterInfo()->getName(RegNo));
518   }
519   OutStreamer->AddBlankLine();
520 }
521 
522 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
523                                                    raw_ostream &O) const {
524   // If the NVVM IR has some of reqntid* specified, then output
525   // the reqntid directive, and set the unspecified ones to 1.
526   // If none of reqntid* is specified, don't output reqntid directive.
527   unsigned reqntidx, reqntidy, reqntidz;
528   bool specified = false;
529   if (!getReqNTIDx(F, reqntidx))
530     reqntidx = 1;
531   else
532     specified = true;
533   if (!getReqNTIDy(F, reqntidy))
534     reqntidy = 1;
535   else
536     specified = true;
537   if (!getReqNTIDz(F, reqntidz))
538     reqntidz = 1;
539   else
540     specified = true;
541 
542   if (specified)
543     O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
544       << "\n";
545 
546   // If the NVVM IR has some of maxntid* specified, then output
547   // the maxntid directive, and set the unspecified ones to 1.
548   // If none of maxntid* is specified, don't output maxntid directive.
549   unsigned maxntidx, maxntidy, maxntidz;
550   specified = false;
551   if (!getMaxNTIDx(F, maxntidx))
552     maxntidx = 1;
553   else
554     specified = true;
555   if (!getMaxNTIDy(F, maxntidy))
556     maxntidy = 1;
557   else
558     specified = true;
559   if (!getMaxNTIDz(F, maxntidz))
560     maxntidz = 1;
561   else
562     specified = true;
563 
564   if (specified)
565     O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
566       << "\n";
567 
568   unsigned mincta;
569   if (getMinCTASm(F, mincta))
570     O << ".minnctapersm " << mincta << "\n";
571 
572   unsigned maxnreg;
573   if (getMaxNReg(F, maxnreg))
574     O << ".maxnreg " << maxnreg << "\n";
575 }
576 
577 std::string
578 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
579   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
580 
581   std::string Name;
582   raw_string_ostream NameStr(Name);
583 
584   VRegRCMap::const_iterator I = VRegMapping.find(RC);
585   assert(I != VRegMapping.end() && "Bad register class");
586   const DenseMap<unsigned, unsigned> &RegMap = I->second;
587 
588   VRegMap::const_iterator VI = RegMap.find(Reg);
589   assert(VI != RegMap.end() && "Bad virtual register");
590   unsigned MappedVR = VI->second;
591 
592   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
593 
594   NameStr.flush();
595   return Name;
596 }
597 
598 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
599                                           raw_ostream &O) {
600   O << getVirtualRegisterName(vr);
601 }
602 
603 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
604   emitLinkageDirective(F, O);
605   if (isKernelFunction(*F))
606     O << ".entry ";
607   else
608     O << ".func ";
609   printReturnValStr(F, O);
610   getSymbol(F)->print(O, MAI);
611   O << "\n";
612   emitFunctionParamList(F, O);
613   O << ";\n";
614 }
615 
616 static bool usedInGlobalVarDef(const Constant *C) {
617   if (!C)
618     return false;
619 
620   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
621     return GV->getName() != "llvm.used";
622   }
623 
624   for (const User *U : C->users())
625     if (const Constant *C = dyn_cast<Constant>(U))
626       if (usedInGlobalVarDef(C))
627         return true;
628 
629   return false;
630 }
631 
632 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
633   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
634     if (othergv->getName() == "llvm.used")
635       return true;
636   }
637 
638   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
639     if (instr->getParent() && instr->getParent()->getParent()) {
640       const Function *curFunc = instr->getParent()->getParent();
641       if (oneFunc && (curFunc != oneFunc))
642         return false;
643       oneFunc = curFunc;
644       return true;
645     } else
646       return false;
647   }
648 
649   for (const User *UU : U->users())
650     if (!usedInOneFunc(UU, oneFunc))
651       return false;
652 
653   return true;
654 }
655 
656 /* Find out if a global variable can be demoted to local scope.
657  * Currently, this is valid for CUDA shared variables, which have local
658  * scope and global lifetime. So the conditions to check are :
659  * 1. Is the global variable in shared address space?
660  * 2. Does it have internal linkage?
661  * 3. Is the global variable referenced only in one function?
662  */
663 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
664   if (!gv->hasInternalLinkage())
665     return false;
666   PointerType *Pty = gv->getType();
667   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
668     return false;
669 
670   const Function *oneFunc = nullptr;
671 
672   bool flag = usedInOneFunc(gv, oneFunc);
673   if (!flag)
674     return false;
675   if (!oneFunc)
676     return false;
677   f = oneFunc;
678   return true;
679 }
680 
681 static bool useFuncSeen(const Constant *C,
682                         DenseMap<const Function *, bool> &seenMap) {
683   for (const User *U : C->users()) {
684     if (const Constant *cu = dyn_cast<Constant>(U)) {
685       if (useFuncSeen(cu, seenMap))
686         return true;
687     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
688       const BasicBlock *bb = I->getParent();
689       if (!bb)
690         continue;
691       const Function *caller = bb->getParent();
692       if (!caller)
693         continue;
694       if (seenMap.find(caller) != seenMap.end())
695         return true;
696     }
697   }
698   return false;
699 }
700 
701 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
702   DenseMap<const Function *, bool> seenMap;
703   for (Module::const_iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) {
704     const Function *F = &*FI;
705 
706     if (F->getAttributes().hasFnAttribute("nvptx-libcall-callee")) {
707       emitDeclaration(F, O);
708       continue;
709     }
710 
711     if (F->isDeclaration()) {
712       if (F->use_empty())
713         continue;
714       if (F->getIntrinsicID())
715         continue;
716       emitDeclaration(F, O);
717       continue;
718     }
719     for (const User *U : F->users()) {
720       if (const Constant *C = dyn_cast<Constant>(U)) {
721         if (usedInGlobalVarDef(C)) {
722           // The use is in the initialization of a global variable
723           // that is a function pointer, so print a declaration
724           // for the original function
725           emitDeclaration(F, O);
726           break;
727         }
728         // Emit a declaration of this function if the function that
729         // uses this constant expr has already been seen.
730         if (useFuncSeen(C, seenMap)) {
731           emitDeclaration(F, O);
732           break;
733         }
734       }
735 
736       if (!isa<Instruction>(U))
737         continue;
738       const Instruction *instr = cast<Instruction>(U);
739       const BasicBlock *bb = instr->getParent();
740       if (!bb)
741         continue;
742       const Function *caller = bb->getParent();
743       if (!caller)
744         continue;
745 
746       // If a caller has already been seen, then the caller is
747       // appearing in the module before the callee. so print out
748       // a declaration for the callee.
749       if (seenMap.find(caller) != seenMap.end()) {
750         emitDeclaration(F, O);
751         break;
752       }
753     }
754     seenMap[F] = true;
755   }
756 }
757 
758 static bool isEmptyXXStructor(GlobalVariable *GV) {
759   if (!GV) return true;
760   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
761   if (!InitList) return true;  // Not an array; we don't know how to parse.
762   return InitList->getNumOperands() == 0;
763 }
764 
765 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
766   // Construct a default subtarget off of the TargetMachine defaults. The
767   // rest of NVPTX isn't friendly to change subtargets per function and
768   // so the default TargetMachine will have all of the options.
769   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
770   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
771   SmallString<128> Str1;
772   raw_svector_ostream OS1(Str1);
773 
774   // Emit header before any dwarf directives are emitted below.
775   emitHeader(M, OS1, *STI);
776   OutStreamer->emitRawText(OS1.str());
777 }
778 
779 bool NVPTXAsmPrinter::doInitialization(Module &M) {
780   if (M.alias_size()) {
781     report_fatal_error("Module has aliases, which NVPTX does not support.");
782     return true; // error
783   }
784   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
785     report_fatal_error(
786         "Module has a nontrivial global ctor, which NVPTX does not support.");
787     return true;  // error
788   }
789   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
790     report_fatal_error(
791         "Module has a nontrivial global dtor, which NVPTX does not support.");
792     return true;  // error
793   }
794 
795   // We need to call the parent's one explicitly.
796   bool Result = AsmPrinter::doInitialization(M);
797 
798   GlobalsEmitted = false;
799 
800   return Result;
801 }
802 
803 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
804   SmallString<128> Str2;
805   raw_svector_ostream OS2(Str2);
806 
807   emitDeclarations(M, OS2);
808 
809   // As ptxas does not support forward references of globals, we need to first
810   // sort the list of module-level globals in def-use order. We visit each
811   // global variable in order, and ensure that we emit it *after* its dependent
812   // globals. We use a little extra memory maintaining both a set and a list to
813   // have fast searches while maintaining a strict ordering.
814   SmallVector<const GlobalVariable *, 8> Globals;
815   DenseSet<const GlobalVariable *> GVVisited;
816   DenseSet<const GlobalVariable *> GVVisiting;
817 
818   // Visit each global variable, in order
819   for (const GlobalVariable &I : M.globals())
820     VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
821 
822   assert(GVVisited.size() == M.getGlobalList().size() &&
823          "Missed a global variable");
824   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
825 
826   // Print out module-level global variables in proper order
827   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
828     printModuleLevelGV(Globals[i], OS2);
829 
830   OS2 << '\n';
831 
832   OutStreamer->emitRawText(OS2.str());
833 }
834 
835 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
836                                  const NVPTXSubtarget &STI) {
837   O << "//\n";
838   O << "// Generated by LLVM NVPTX Back-End\n";
839   O << "//\n";
840   O << "\n";
841 
842   unsigned PTXVersion = STI.getPTXVersion();
843   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
844 
845   O << ".target ";
846   O << STI.getTargetName();
847 
848   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
849   if (NTM.getDrvInterface() == NVPTX::NVCL)
850     O << ", texmode_independent";
851 
852   bool HasFullDebugInfo = false;
853   for (DICompileUnit *CU : M.debug_compile_units()) {
854     switch(CU->getEmissionKind()) {
855     case DICompileUnit::NoDebug:
856     case DICompileUnit::DebugDirectivesOnly:
857       break;
858     case DICompileUnit::LineTablesOnly:
859     case DICompileUnit::FullDebug:
860       HasFullDebugInfo = true;
861       break;
862     }
863     if (HasFullDebugInfo)
864       break;
865   }
866   if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
867     O << ", debug";
868 
869   O << "\n";
870 
871   O << ".address_size ";
872   if (NTM.is64Bit())
873     O << "64";
874   else
875     O << "32";
876   O << "\n";
877 
878   O << "\n";
879 }
880 
881 bool NVPTXAsmPrinter::doFinalization(Module &M) {
882   bool HasDebugInfo = MMI && MMI->hasDebugInfo();
883 
884   // If we did not emit any functions, then the global declarations have not
885   // yet been emitted.
886   if (!GlobalsEmitted) {
887     emitGlobals(M);
888     GlobalsEmitted = true;
889   }
890 
891   // XXX Temproarily remove global variables so that doFinalization() will not
892   // emit them again (global variables are emitted at beginning).
893 
894   Module::GlobalListType &global_list = M.getGlobalList();
895   int i, n = global_list.size();
896   GlobalVariable **gv_array = new GlobalVariable *[n];
897 
898   // first, back-up GlobalVariable in gv_array
899   i = 0;
900   for (Module::global_iterator I = global_list.begin(), E = global_list.end();
901        I != E; ++I)
902     gv_array[i++] = &*I;
903 
904   // second, empty global_list
905   while (!global_list.empty())
906     global_list.remove(global_list.begin());
907 
908   // call doFinalization
909   bool ret = AsmPrinter::doFinalization(M);
910 
911   // now we restore global variables
912   for (i = 0; i < n; i++)
913     global_list.insert(global_list.end(), gv_array[i]);
914 
915   clearAnnotationCache(&M);
916 
917   delete[] gv_array;
918   // Close the last emitted section
919   if (HasDebugInfo) {
920     static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer())
921         ->closeLastSection();
922     // Emit empty .debug_loc section for better support of the empty files.
923     OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
924   }
925 
926   // Output last DWARF .file directives, if any.
927   static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer())
928       ->outputDwarfFileDirectives();
929 
930   return ret;
931 
932   //bool Result = AsmPrinter::doFinalization(M);
933   // Instead of calling the parents doFinalization, we may
934   // clone parents doFinalization and customize here.
935   // Currently, we if NVISA out the EmitGlobals() in
936   // parent's doFinalization, which is too intrusive.
937   //
938   // Same for the doInitialization.
939   //return Result;
940 }
941 
942 // This function emits appropriate linkage directives for
943 // functions and global variables.
944 //
945 // extern function declaration            -> .extern
946 // extern function definition             -> .visible
947 // external global variable with init     -> .visible
948 // external without init                  -> .extern
949 // appending                              -> not allowed, assert.
950 // for any linkage other than
951 // internal, private, linker_private,
952 // linker_private_weak, linker_private_weak_def_auto,
953 // we emit                                -> .weak.
954 
955 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
956                                            raw_ostream &O) {
957   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
958     if (V->hasExternalLinkage()) {
959       if (isa<GlobalVariable>(V)) {
960         const GlobalVariable *GVar = cast<GlobalVariable>(V);
961         if (GVar) {
962           if (GVar->hasInitializer())
963             O << ".visible ";
964           else
965             O << ".extern ";
966         }
967       } else if (V->isDeclaration())
968         O << ".extern ";
969       else
970         O << ".visible ";
971     } else if (V->hasAppendingLinkage()) {
972       std::string msg;
973       msg.append("Error: ");
974       msg.append("Symbol ");
975       if (V->hasName())
976         msg.append(std::string(V->getName()));
977       msg.append("has unsupported appending linkage type");
978       llvm_unreachable(msg.c_str());
979     } else if (!V->hasInternalLinkage() &&
980                !V->hasPrivateLinkage()) {
981       O << ".weak ";
982     }
983   }
984 }
985 
986 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
987                                          raw_ostream &O,
988                                          bool processDemoted) {
989   // Skip meta data
990   if (GVar->hasSection()) {
991     if (GVar->getSection() == "llvm.metadata")
992       return;
993   }
994 
995   // Skip LLVM intrinsic global variables
996   if (GVar->getName().startswith("llvm.") ||
997       GVar->getName().startswith("nvvm."))
998     return;
999 
1000   const DataLayout &DL = getDataLayout();
1001 
1002   // GlobalVariables are always constant pointers themselves.
1003   PointerType *PTy = GVar->getType();
1004   Type *ETy = GVar->getValueType();
1005 
1006   if (GVar->hasExternalLinkage()) {
1007     if (GVar->hasInitializer())
1008       O << ".visible ";
1009     else
1010       O << ".extern ";
1011   } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1012              GVar->hasAvailableExternallyLinkage() ||
1013              GVar->hasCommonLinkage()) {
1014     O << ".weak ";
1015   }
1016 
1017   if (isTexture(*GVar)) {
1018     O << ".global .texref " << getTextureName(*GVar) << ";\n";
1019     return;
1020   }
1021 
1022   if (isSurface(*GVar)) {
1023     O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1024     return;
1025   }
1026 
1027   if (GVar->isDeclaration()) {
1028     // (extern) declarations, no definition or initializer
1029     // Currently the only known declaration is for an automatic __local
1030     // (.shared) promoted to global.
1031     emitPTXGlobalVariable(GVar, O);
1032     O << ";\n";
1033     return;
1034   }
1035 
1036   if (isSampler(*GVar)) {
1037     O << ".global .samplerref " << getSamplerName(*GVar);
1038 
1039     const Constant *Initializer = nullptr;
1040     if (GVar->hasInitializer())
1041       Initializer = GVar->getInitializer();
1042     const ConstantInt *CI = nullptr;
1043     if (Initializer)
1044       CI = dyn_cast<ConstantInt>(Initializer);
1045     if (CI) {
1046       unsigned sample = CI->getZExtValue();
1047 
1048       O << " = { ";
1049 
1050       for (int i = 0,
1051                addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1052            i < 3; i++) {
1053         O << "addr_mode_" << i << " = ";
1054         switch (addr) {
1055         case 0:
1056           O << "wrap";
1057           break;
1058         case 1:
1059           O << "clamp_to_border";
1060           break;
1061         case 2:
1062           O << "clamp_to_edge";
1063           break;
1064         case 3:
1065           O << "wrap";
1066           break;
1067         case 4:
1068           O << "mirror";
1069           break;
1070         }
1071         O << ", ";
1072       }
1073       O << "filter_mode = ";
1074       switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1075       case 0:
1076         O << "nearest";
1077         break;
1078       case 1:
1079         O << "linear";
1080         break;
1081       case 2:
1082         llvm_unreachable("Anisotropic filtering is not supported");
1083       default:
1084         O << "nearest";
1085         break;
1086       }
1087       if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1088         O << ", force_unnormalized_coords = 1";
1089       }
1090       O << " }";
1091     }
1092 
1093     O << ";\n";
1094     return;
1095   }
1096 
1097   if (GVar->hasPrivateLinkage()) {
1098     if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1099       return;
1100 
1101     // FIXME - need better way (e.g. Metadata) to avoid generating this global
1102     if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1103       return;
1104     if (GVar->use_empty())
1105       return;
1106   }
1107 
1108   const Function *demotedFunc = nullptr;
1109   if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1110     O << "// " << GVar->getName() << " has been demoted\n";
1111     if (localDecls.find(demotedFunc) != localDecls.end())
1112       localDecls[demotedFunc].push_back(GVar);
1113     else {
1114       std::vector<const GlobalVariable *> temp;
1115       temp.push_back(GVar);
1116       localDecls[demotedFunc] = temp;
1117     }
1118     return;
1119   }
1120 
1121   O << ".";
1122   emitPTXAddressSpace(PTy->getAddressSpace(), O);
1123 
1124   if (isManaged(*GVar)) {
1125     O << " .attribute(.managed)";
1126   }
1127 
1128   if (GVar->getAlignment() == 0)
1129     O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1130   else
1131     O << " .align " << GVar->getAlignment();
1132 
1133   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1134       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1135     O << " .";
1136     // Special case: ABI requires that we use .u8 for predicates
1137     if (ETy->isIntegerTy(1))
1138       O << "u8";
1139     else
1140       O << getPTXFundamentalTypeStr(ETy, false);
1141     O << " ";
1142     getSymbol(GVar)->print(O, MAI);
1143 
1144     // Ptx allows variable initilization only for constant and global state
1145     // spaces.
1146     if (GVar->hasInitializer()) {
1147       if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1148           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1149         const Constant *Initializer = GVar->getInitializer();
1150         // 'undef' is treated as there is no value specified.
1151         if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1152           O << " = ";
1153           printScalarConstant(Initializer, O);
1154         }
1155       } else {
1156         // The frontend adds zero-initializer to device and constant variables
1157         // that don't have an initial value, and UndefValue to shared
1158         // variables, so skip warning for this case.
1159         if (!GVar->getInitializer()->isNullValue() &&
1160             !isa<UndefValue>(GVar->getInitializer())) {
1161           report_fatal_error("initial value of '" + GVar->getName() +
1162                              "' is not allowed in addrspace(" +
1163                              Twine(PTy->getAddressSpace()) + ")");
1164         }
1165       }
1166     }
1167   } else {
1168     unsigned int ElementSize = 0;
1169 
1170     // Although PTX has direct support for struct type and array type and
1171     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1172     // targets that support these high level field accesses. Structs, arrays
1173     // and vectors are lowered into arrays of bytes.
1174     switch (ETy->getTypeID()) {
1175     case Type::IntegerTyID: // Integers larger than 64 bits
1176     case Type::StructTyID:
1177     case Type::ArrayTyID:
1178     case Type::FixedVectorTyID:
1179       ElementSize = DL.getTypeStoreSize(ETy);
1180       // Ptx allows variable initilization only for constant and
1181       // global state spaces.
1182       if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1183            (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1184           GVar->hasInitializer()) {
1185         const Constant *Initializer = GVar->getInitializer();
1186         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1187           AggBuffer aggBuffer(ElementSize, O, *this);
1188           bufferAggregateConstant(Initializer, &aggBuffer);
1189           if (aggBuffer.numSymbols) {
1190             if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit()) {
1191               O << " .u64 ";
1192               getSymbol(GVar)->print(O, MAI);
1193               O << "[";
1194               O << ElementSize / 8;
1195             } else {
1196               O << " .u32 ";
1197               getSymbol(GVar)->print(O, MAI);
1198               O << "[";
1199               O << ElementSize / 4;
1200             }
1201             O << "]";
1202           } else {
1203             O << " .b8 ";
1204             getSymbol(GVar)->print(O, MAI);
1205             O << "[";
1206             O << ElementSize;
1207             O << "]";
1208           }
1209           O << " = {";
1210           aggBuffer.print();
1211           O << "}";
1212         } else {
1213           O << " .b8 ";
1214           getSymbol(GVar)->print(O, MAI);
1215           if (ElementSize) {
1216             O << "[";
1217             O << ElementSize;
1218             O << "]";
1219           }
1220         }
1221       } else {
1222         O << " .b8 ";
1223         getSymbol(GVar)->print(O, MAI);
1224         if (ElementSize) {
1225           O << "[";
1226           O << ElementSize;
1227           O << "]";
1228         }
1229       }
1230       break;
1231     default:
1232       llvm_unreachable("type not supported yet");
1233     }
1234   }
1235   O << ";\n";
1236 }
1237 
1238 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1239   if (localDecls.find(f) == localDecls.end())
1240     return;
1241 
1242   std::vector<const GlobalVariable *> &gvars = localDecls[f];
1243 
1244   for (unsigned i = 0, e = gvars.size(); i != e; ++i) {
1245     O << "\t// demoted variable\n\t";
1246     printModuleLevelGV(gvars[i], O, true);
1247   }
1248 }
1249 
1250 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1251                                           raw_ostream &O) const {
1252   switch (AddressSpace) {
1253   case ADDRESS_SPACE_LOCAL:
1254     O << "local";
1255     break;
1256   case ADDRESS_SPACE_GLOBAL:
1257     O << "global";
1258     break;
1259   case ADDRESS_SPACE_CONST:
1260     O << "const";
1261     break;
1262   case ADDRESS_SPACE_SHARED:
1263     O << "shared";
1264     break;
1265   default:
1266     report_fatal_error("Bad address space found while emitting PTX: " +
1267                        llvm::Twine(AddressSpace));
1268     break;
1269   }
1270 }
1271 
1272 std::string
1273 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1274   switch (Ty->getTypeID()) {
1275   default:
1276     llvm_unreachable("unexpected type");
1277     break;
1278   case Type::IntegerTyID: {
1279     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1280     if (NumBits == 1)
1281       return "pred";
1282     else if (NumBits <= 64) {
1283       std::string name = "u";
1284       return name + utostr(NumBits);
1285     } else {
1286       llvm_unreachable("Integer too large");
1287       break;
1288     }
1289     break;
1290   }
1291   case Type::HalfTyID:
1292     // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1293     return "b16";
1294   case Type::FloatTyID:
1295     return "f32";
1296   case Type::DoubleTyID:
1297     return "f64";
1298   case Type::PointerTyID:
1299     if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit())
1300       if (useB4PTR)
1301         return "b64";
1302       else
1303         return "u64";
1304     else if (useB4PTR)
1305       return "b32";
1306     else
1307       return "u32";
1308   }
1309   llvm_unreachable("unexpected type");
1310   return nullptr;
1311 }
1312 
1313 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1314                                             raw_ostream &O) {
1315   const DataLayout &DL = getDataLayout();
1316 
1317   // GlobalVariables are always constant pointers themselves.
1318   Type *ETy = GVar->getValueType();
1319 
1320   O << ".";
1321   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1322   if (GVar->getAlignment() == 0)
1323     O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1324   else
1325     O << " .align " << GVar->getAlignment();
1326 
1327   // Special case for i128
1328   if (ETy->isIntegerTy(128)) {
1329     O << " .b8 ";
1330     getSymbol(GVar)->print(O, MAI);
1331     O << "[16]";
1332     return;
1333   }
1334 
1335   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1336     O << " .";
1337     O << getPTXFundamentalTypeStr(ETy);
1338     O << " ";
1339     getSymbol(GVar)->print(O, MAI);
1340     return;
1341   }
1342 
1343   int64_t ElementSize = 0;
1344 
1345   // Although PTX has direct support for struct type and array type and LLVM IR
1346   // is very similar to PTX, the LLVM CodeGen does not support for targets that
1347   // support these high level field accesses. Structs and arrays are lowered
1348   // into arrays of bytes.
1349   switch (ETy->getTypeID()) {
1350   case Type::StructTyID:
1351   case Type::ArrayTyID:
1352   case Type::FixedVectorTyID:
1353     ElementSize = DL.getTypeStoreSize(ETy);
1354     O << " .b8 ";
1355     getSymbol(GVar)->print(O, MAI);
1356     O << "[";
1357     if (ElementSize) {
1358       O << ElementSize;
1359     }
1360     O << "]";
1361     break;
1362   default:
1363     llvm_unreachable("type not supported yet");
1364   }
1365 }
1366 
1367 static unsigned int getOpenCLAlignment(const DataLayout &DL, Type *Ty) {
1368   if (Ty->isSingleValueType())
1369     return DL.getPrefTypeAlignment(Ty);
1370 
1371   auto *ATy = dyn_cast<ArrayType>(Ty);
1372   if (ATy)
1373     return getOpenCLAlignment(DL, ATy->getElementType());
1374 
1375   auto *STy = dyn_cast<StructType>(Ty);
1376   if (STy) {
1377     unsigned int alignStruct = 1;
1378     // Go through each element of the struct and find the
1379     // largest alignment.
1380     for (unsigned i = 0, e = STy->getNumElements(); i != e; i++) {
1381       Type *ETy = STy->getElementType(i);
1382       unsigned int align = getOpenCLAlignment(DL, ETy);
1383       if (align > alignStruct)
1384         alignStruct = align;
1385     }
1386     return alignStruct;
1387   }
1388 
1389   auto *FTy = dyn_cast<FunctionType>(Ty);
1390   if (FTy)
1391     return DL.getPointerPrefAlignment().value();
1392   return DL.getPrefTypeAlignment(Ty);
1393 }
1394 
1395 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1396                                      int paramIndex, raw_ostream &O) {
1397   getSymbol(I->getParent())->print(O, MAI);
1398   O << "_param_" << paramIndex;
1399 }
1400 
1401 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1402   const DataLayout &DL = getDataLayout();
1403   const AttributeList &PAL = F->getAttributes();
1404   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1405   const TargetLowering *TLI = STI.getTargetLowering();
1406   Function::const_arg_iterator I, E;
1407   unsigned paramIndex = 0;
1408   bool first = true;
1409   bool isKernelFunc = isKernelFunction(*F);
1410   bool isABI = (STI.getSmVersion() >= 20);
1411   bool hasImageHandles = STI.hasImageHandles();
1412   MVT thePointerTy = TLI->getPointerTy(DL);
1413 
1414   if (F->arg_empty()) {
1415     O << "()\n";
1416     return;
1417   }
1418 
1419   O << "(\n";
1420 
1421   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1422     Type *Ty = I->getType();
1423 
1424     if (!first)
1425       O << ",\n";
1426 
1427     first = false;
1428 
1429     // Handle image/sampler parameters
1430     if (isKernelFunction(*F)) {
1431       if (isSampler(*I) || isImage(*I)) {
1432         if (isImage(*I)) {
1433           std::string sname = std::string(I->getName());
1434           if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1435             if (hasImageHandles)
1436               O << "\t.param .u64 .ptr .surfref ";
1437             else
1438               O << "\t.param .surfref ";
1439             CurrentFnSym->print(O, MAI);
1440             O << "_param_" << paramIndex;
1441           }
1442           else { // Default image is read_only
1443             if (hasImageHandles)
1444               O << "\t.param .u64 .ptr .texref ";
1445             else
1446               O << "\t.param .texref ";
1447             CurrentFnSym->print(O, MAI);
1448             O << "_param_" << paramIndex;
1449           }
1450         } else {
1451           if (hasImageHandles)
1452             O << "\t.param .u64 .ptr .samplerref ";
1453           else
1454             O << "\t.param .samplerref ";
1455           CurrentFnSym->print(O, MAI);
1456           O << "_param_" << paramIndex;
1457         }
1458         continue;
1459       }
1460     }
1461 
1462     if (!PAL.hasParamAttribute(paramIndex, Attribute::ByVal)) {
1463       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1464         // Just print .param .align <a> .b8 .param[size];
1465         // <a> = PAL.getparamalignment
1466         // size = typeallocsize of element type
1467         const Align align = DL.getValueOrABITypeAlignment(
1468             PAL.getParamAlignment(paramIndex), Ty);
1469 
1470         unsigned sz = DL.getTypeAllocSize(Ty);
1471         O << "\t.param .align " << align.value() << " .b8 ";
1472         printParamName(I, paramIndex, O);
1473         O << "[" << sz << "]";
1474 
1475         continue;
1476       }
1477       // Just a scalar
1478       auto *PTy = dyn_cast<PointerType>(Ty);
1479       if (isKernelFunc) {
1480         if (PTy) {
1481           // Special handling for pointer arguments to kernel
1482           O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";
1483 
1484           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1485               NVPTX::CUDA) {
1486             Type *ETy = PTy->getElementType();
1487             int addrSpace = PTy->getAddressSpace();
1488             switch (addrSpace) {
1489             default:
1490               O << ".ptr ";
1491               break;
1492             case ADDRESS_SPACE_CONST:
1493               O << ".ptr .const ";
1494               break;
1495             case ADDRESS_SPACE_SHARED:
1496               O << ".ptr .shared ";
1497               break;
1498             case ADDRESS_SPACE_GLOBAL:
1499               O << ".ptr .global ";
1500               break;
1501             }
1502             O << ".align " << (int)getOpenCLAlignment(DL, ETy) << " ";
1503           }
1504           printParamName(I, paramIndex, O);
1505           continue;
1506         }
1507 
1508         // non-pointer scalar to kernel func
1509         O << "\t.param .";
1510         // Special case: predicate operands become .u8 types
1511         if (Ty->isIntegerTy(1))
1512           O << "u8";
1513         else
1514           O << getPTXFundamentalTypeStr(Ty);
1515         O << " ";
1516         printParamName(I, paramIndex, O);
1517         continue;
1518       }
1519       // Non-kernel function, just print .param .b<size> for ABI
1520       // and .reg .b<size> for non-ABI
1521       unsigned sz = 0;
1522       if (isa<IntegerType>(Ty)) {
1523         sz = cast<IntegerType>(Ty)->getBitWidth();
1524         if (sz < 32)
1525           sz = 32;
1526       } else if (isa<PointerType>(Ty))
1527         sz = thePointerTy.getSizeInBits();
1528       else if (Ty->isHalfTy())
1529         // PTX ABI requires all scalar parameters to be at least 32
1530         // bits in size.  fp16 normally uses .b16 as its storage type
1531         // in PTX, so its size must be adjusted here, too.
1532         sz = 32;
1533       else
1534         sz = Ty->getPrimitiveSizeInBits();
1535       if (isABI)
1536         O << "\t.param .b" << sz << " ";
1537       else
1538         O << "\t.reg .b" << sz << " ";
1539       printParamName(I, paramIndex, O);
1540       continue;
1541     }
1542 
1543     // param has byVal attribute. So should be a pointer
1544     auto *PTy = dyn_cast<PointerType>(Ty);
1545     assert(PTy && "Param with byval attribute should be a pointer type");
1546     Type *ETy = PTy->getElementType();
1547 
1548     if (isABI || isKernelFunc) {
1549       // Just print .param .align <a> .b8 .param[size];
1550       // <a> = PAL.getparamalignment
1551       // size = typeallocsize of element type
1552       Align align =
1553           DL.getValueOrABITypeAlignment(PAL.getParamAlignment(paramIndex), ETy);
1554       // Work around a bug in ptxas. When PTX code takes address of
1555       // byval parameter with alignment < 4, ptxas generates code to
1556       // spill argument into memory. Alas on sm_50+ ptxas generates
1557       // SASS code that fails with misaligned access. To work around
1558       // the problem, make sure that we align byval parameters by at
1559       // least 4. Matching change must be made in LowerCall() where we
1560       // prepare parameters for the call.
1561       //
1562       // TODO: this will need to be undone when we get to support multi-TU
1563       // device-side compilation as it breaks ABI compatibility with nvcc.
1564       // Hopefully ptxas bug is fixed by then.
1565       if (!isKernelFunc && align < Align(4))
1566         align = Align(4);
1567       unsigned sz = DL.getTypeAllocSize(ETy);
1568       O << "\t.param .align " << align.value() << " .b8 ";
1569       printParamName(I, paramIndex, O);
1570       O << "[" << sz << "]";
1571       continue;
1572     } else {
1573       // Split the ETy into constituent parts and
1574       // print .param .b<size> <name> for each part.
1575       // Further, if a part is vector, print the above for
1576       // each vector element.
1577       SmallVector<EVT, 16> vtparts;
1578       ComputeValueVTs(*TLI, DL, ETy, vtparts);
1579       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1580         unsigned elems = 1;
1581         EVT elemtype = vtparts[i];
1582         if (vtparts[i].isVector()) {
1583           elems = vtparts[i].getVectorNumElements();
1584           elemtype = vtparts[i].getVectorElementType();
1585         }
1586 
1587         for (unsigned j = 0, je = elems; j != je; ++j) {
1588           unsigned sz = elemtype.getSizeInBits();
1589           if (elemtype.isInteger() && (sz < 32))
1590             sz = 32;
1591           O << "\t.reg .b" << sz << " ";
1592           printParamName(I, paramIndex, O);
1593           if (j < je - 1)
1594             O << ",\n";
1595           ++paramIndex;
1596         }
1597         if (i < e - 1)
1598           O << ",\n";
1599       }
1600       --paramIndex;
1601       continue;
1602     }
1603   }
1604 
1605   O << "\n)\n";
1606 }
1607 
1608 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1609                                             raw_ostream &O) {
1610   const Function &F = MF.getFunction();
1611   emitFunctionParamList(&F, O);
1612 }
1613 
1614 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1615     const MachineFunction &MF) {
1616   SmallString<128> Str;
1617   raw_svector_ostream O(Str);
1618 
1619   // Map the global virtual register number to a register class specific
1620   // virtual register number starting from 1 with that class.
1621   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1622   //unsigned numRegClasses = TRI->getNumRegClasses();
1623 
1624   // Emit the Fake Stack Object
1625   const MachineFrameInfo &MFI = MF.getFrameInfo();
1626   int NumBytes = (int) MFI.getStackSize();
1627   if (NumBytes) {
1628     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1629       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1630     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1631       O << "\t.reg .b64 \t%SP;\n";
1632       O << "\t.reg .b64 \t%SPL;\n";
1633     } else {
1634       O << "\t.reg .b32 \t%SP;\n";
1635       O << "\t.reg .b32 \t%SPL;\n";
1636     }
1637   }
1638 
1639   // Go through all virtual registers to establish the mapping between the
1640   // global virtual
1641   // register number and the per class virtual register number.
1642   // We use the per class virtual register number in the ptx output.
1643   unsigned int numVRs = MRI->getNumVirtRegs();
1644   for (unsigned i = 0; i < numVRs; i++) {
1645     unsigned int vr = Register::index2VirtReg(i);
1646     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1647     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1648     int n = regmap.size();
1649     regmap.insert(std::make_pair(vr, n + 1));
1650   }
1651 
1652   // Emit register declarations
1653   // @TODO: Extract out the real register usage
1654   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1655   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1656   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1657   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1658   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1659   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1660   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1661 
1662   // Emit declaration of the virtual registers or 'physical' registers for
1663   // each register class
1664   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1665     const TargetRegisterClass *RC = TRI->getRegClass(i);
1666     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1667     std::string rcname = getNVPTXRegClassName(RC);
1668     std::string rcStr = getNVPTXRegClassStr(RC);
1669     int n = regmap.size();
1670 
1671     // Only declare those registers that may be used.
1672     if (n) {
1673        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1674          << ">;\n";
1675     }
1676   }
1677 
1678   OutStreamer->emitRawText(O.str());
1679 }
1680 
1681 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1682   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1683   bool ignored;
1684   unsigned int numHex;
1685   const char *lead;
1686 
1687   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1688     numHex = 8;
1689     lead = "0f";
1690     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1691   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1692     numHex = 16;
1693     lead = "0d";
1694     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1695   } else
1696     llvm_unreachable("unsupported fp type");
1697 
1698   APInt API = APF.bitcastToAPInt();
1699   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1700 }
1701 
1702 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1703   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1704     O << CI->getValue();
1705     return;
1706   }
1707   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1708     printFPConstant(CFP, O);
1709     return;
1710   }
1711   if (isa<ConstantPointerNull>(CPV)) {
1712     O << "0";
1713     return;
1714   }
1715   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1716     bool IsNonGenericPointer = false;
1717     if (GVar->getType()->getAddressSpace() != 0) {
1718       IsNonGenericPointer = true;
1719     }
1720     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1721       O << "generic(";
1722       getSymbol(GVar)->print(O, MAI);
1723       O << ")";
1724     } else {
1725       getSymbol(GVar)->print(O, MAI);
1726     }
1727     return;
1728   }
1729   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1730     const Value *v = Cexpr->stripPointerCasts();
1731     PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType());
1732     bool IsNonGenericPointer = false;
1733     if (PTy && PTy->getAddressSpace() != 0) {
1734       IsNonGenericPointer = true;
1735     }
1736     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1737       if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
1738         O << "generic(";
1739         getSymbol(GVar)->print(O, MAI);
1740         O << ")";
1741       } else {
1742         getSymbol(GVar)->print(O, MAI);
1743       }
1744       return;
1745     } else {
1746       lowerConstant(CPV)->print(O, MAI);
1747       return;
1748     }
1749   }
1750   llvm_unreachable("Not scalar type found in printScalarConstant()");
1751 }
1752 
// These utility functions assure we get the right sequence of bytes for a given
// type even for big-endian machines: bytes are derived from the numeric value
// (least significant byte first), never from the host's memory layout.

// Write the sizeof(T) low-order bytes of the integer val to p, least
// significant byte first. p must have room for sizeof(T) bytes (T is
// expected to be an integral type of at most 8 bytes).
template <typename T> static void ConvertIntToBytes(unsigned char *p, T val) {
  // Conversion to uint64_t is well defined for every integral T (negative
  // values wrap modulo 2^64), and shifting an unsigned value is always a
  // logical shift, so no implementation-defined behavior is involved.
  uint64_t vp = static_cast<uint64_t>(val);
  for (unsigned i = 0; i < sizeof(T); ++i) {
    p[i] = static_cast<unsigned char>(vp);
    vp >>= 8;
  }
}

// Write the 4 bytes of val's bit pattern to p, least significant byte first.
static void ConvertFloatToBytes(unsigned char *p, float val) {
  static_assert(sizeof(float) == sizeof(uint32_t), "unexpected float size");
  // Copy the bits instead of reading the float through an int32_t pointer,
  // which violates strict aliasing.
  uint32_t Bits;
  std::memcpy(&Bits, &val, sizeof(Bits));
  ConvertIntToBytes<>(p, Bits);
}

// Write the 8 bytes of val's bit pattern to p, least significant byte first.
static void ConvertDoubleToBytes(unsigned char *p, double val) {
  static_assert(sizeof(double) == sizeof(uint64_t), "unexpected double size");
  // See ConvertFloatToBytes: memcpy is the well-defined way to type-pun.
  uint64_t Bits;
  std::memcpy(&Bits, &val, sizeof(Bits));
  ConvertIntToBytes<>(p, Bits);
}
1776 
1777 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1778                                    AggBuffer *aggBuffer) {
1779   const DataLayout &DL = getDataLayout();
1780 
1781   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1782     int s = DL.getTypeAllocSize(CPV->getType());
1783     if (s < Bytes)
1784       s = Bytes;
1785     aggBuffer->addZeros(s);
1786     return;
1787   }
1788 
1789   unsigned char ptr[8];
1790   switch (CPV->getType()->getTypeID()) {
1791 
1792   case Type::IntegerTyID: {
1793     Type *ETy = CPV->getType();
1794     if (ETy == Type::getInt8Ty(CPV->getContext())) {
1795       unsigned char c = (unsigned char)cast<ConstantInt>(CPV)->getZExtValue();
1796       ConvertIntToBytes<>(ptr, c);
1797       aggBuffer->addBytes(ptr, 1, Bytes);
1798     } else if (ETy == Type::getInt16Ty(CPV->getContext())) {
1799       short int16 = (short)cast<ConstantInt>(CPV)->getZExtValue();
1800       ConvertIntToBytes<>(ptr, int16);
1801       aggBuffer->addBytes(ptr, 2, Bytes);
1802     } else if (ETy == Type::getInt32Ty(CPV->getContext())) {
1803       if (const ConstantInt *constInt = dyn_cast<ConstantInt>(CPV)) {
1804         int int32 = (int)(constInt->getZExtValue());
1805         ConvertIntToBytes<>(ptr, int32);
1806         aggBuffer->addBytes(ptr, 4, Bytes);
1807         break;
1808       } else if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1809         if (const auto *constInt = dyn_cast<ConstantInt>(
1810                 ConstantFoldConstant(Cexpr, DL))) {
1811           int int32 = (int)(constInt->getZExtValue());
1812           ConvertIntToBytes<>(ptr, int32);
1813           aggBuffer->addBytes(ptr, 4, Bytes);
1814           break;
1815         }
1816         if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1817           Value *v = Cexpr->getOperand(0)->stripPointerCasts();
1818           aggBuffer->addSymbol(v, Cexpr->getOperand(0));
1819           aggBuffer->addZeros(4);
1820           break;
1821         }
1822       }
1823       llvm_unreachable("unsupported integer const type");
1824     } else if (ETy == Type::getInt64Ty(CPV->getContext())) {
1825       if (const ConstantInt *constInt = dyn_cast<ConstantInt>(CPV)) {
1826         long long int64 = (long long)(constInt->getZExtValue());
1827         ConvertIntToBytes<>(ptr, int64);
1828         aggBuffer->addBytes(ptr, 8, Bytes);
1829         break;
1830       } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1831         if (const auto *constInt = dyn_cast<ConstantInt>(
1832                 ConstantFoldConstant(Cexpr, DL))) {
1833           long long int64 = (long long)(constInt->getZExtValue());
1834           ConvertIntToBytes<>(ptr, int64);
1835           aggBuffer->addBytes(ptr, 8, Bytes);
1836           break;
1837         }
1838         if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1839           Value *v = Cexpr->getOperand(0)->stripPointerCasts();
1840           aggBuffer->addSymbol(v, Cexpr->getOperand(0));
1841           aggBuffer->addZeros(8);
1842           break;
1843         }
1844       }
1845       llvm_unreachable("unsupported integer const type");
1846     } else
1847       llvm_unreachable("unsupported integer const type");
1848     break;
1849   }
1850   case Type::HalfTyID:
1851   case Type::FloatTyID:
1852   case Type::DoubleTyID: {
1853     const auto *CFP = cast<ConstantFP>(CPV);
1854     Type *Ty = CFP->getType();
1855     if (Ty == Type::getHalfTy(CPV->getContext())) {
1856       APInt API = CFP->getValueAPF().bitcastToAPInt();
1857       uint16_t float16 = API.getLoBits(16).getZExtValue();
1858       ConvertIntToBytes<>(ptr, float16);
1859       aggBuffer->addBytes(ptr, 2, Bytes);
1860     } else if (Ty == Type::getFloatTy(CPV->getContext())) {
1861       float float32 = (float) CFP->getValueAPF().convertToFloat();
1862       ConvertFloatToBytes(ptr, float32);
1863       aggBuffer->addBytes(ptr, 4, Bytes);
1864     } else if (Ty == Type::getDoubleTy(CPV->getContext())) {
1865       double float64 = CFP->getValueAPF().convertToDouble();
1866       ConvertDoubleToBytes(ptr, float64);
1867       aggBuffer->addBytes(ptr, 8, Bytes);
1868     } else {
1869       llvm_unreachable("unsupported fp const type");
1870     }
1871     break;
1872   }
1873   case Type::PointerTyID: {
1874     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1875       aggBuffer->addSymbol(GVar, GVar);
1876     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1877       const Value *v = Cexpr->stripPointerCasts();
1878       aggBuffer->addSymbol(v, Cexpr);
1879     }
1880     unsigned int s = DL.getTypeAllocSize(CPV->getType());
1881     aggBuffer->addZeros(s);
1882     break;
1883   }
1884 
1885   case Type::ArrayTyID:
1886   case Type::FixedVectorTyID:
1887   case Type::StructTyID: {
1888     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1889       int ElementSize = DL.getTypeAllocSize(CPV->getType());
1890       bufferAggregateConstant(CPV, aggBuffer);
1891       if (Bytes > ElementSize)
1892         aggBuffer->addZeros(Bytes - ElementSize);
1893     } else if (isa<ConstantAggregateZero>(CPV))
1894       aggBuffer->addZeros(Bytes);
1895     else
1896       llvm_unreachable("Unexpected Constant type");
1897     break;
1898   }
1899 
1900   default:
1901     llvm_unreachable("unsupported type");
1902   }
1903 }
1904 
1905 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1906                                               AggBuffer *aggBuffer) {
1907   const DataLayout &DL = getDataLayout();
1908   int Bytes;
1909 
1910   // Integers of arbitrary width
1911   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1912     APInt Val = CI->getValue();
1913     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1914       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1915       aggBuffer->addBytes(&Byte, 1, 1);
1916       Val.lshrInPlace(8);
1917     }
1918     return;
1919   }
1920 
1921   // Old constants
1922   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1923     if (CPV->getNumOperands())
1924       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1925         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1926     return;
1927   }
1928 
1929   if (const ConstantDataSequential *CDS =
1930           dyn_cast<ConstantDataSequential>(CPV)) {
1931     if (CDS->getNumElements())
1932       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1933         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1934                      aggBuffer);
1935     return;
1936   }
1937 
1938   if (isa<ConstantStruct>(CPV)) {
1939     if (CPV->getNumOperands()) {
1940       StructType *ST = cast<StructType>(CPV->getType());
1941       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1942         if (i == (e - 1))
1943           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1944                   DL.getTypeAllocSize(ST) -
1945                   DL.getStructLayout(ST)->getElementOffset(i);
1946         else
1947           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1948                   DL.getStructLayout(ST)->getElementOffset(i);
1949         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1950       }
1951     }
1952     return;
1953   }
1954   llvm_unreachable("unsupported constant type in printAggregateConstant()");
1955 }
1956 
1957 /// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1958 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1959 /// expressions that are representable in PTX and create
1960 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1961 const MCExpr *
1962 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1963   MCContext &Ctx = OutContext;
1964 
1965   if (CV->isNullValue() || isa<UndefValue>(CV))
1966     return MCConstantExpr::create(0, Ctx);
1967 
1968   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1969     return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1970 
1971   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1972     const MCSymbolRefExpr *Expr =
1973       MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1974     if (ProcessingGeneric) {
1975       return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1976     } else {
1977       return Expr;
1978     }
1979   }
1980 
1981   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1982   if (!CE) {
1983     llvm_unreachable("Unknown constant value to lower!");
1984   }
1985 
1986   switch (CE->getOpcode()) {
1987   default: {
1988     // If the code isn't optimized, there may be outstanding folding
1989     // opportunities. Attempt to fold the expression using DataLayout as a
1990     // last resort before giving up.
1991     Constant *C = ConstantFoldConstant(CE, getDataLayout());
1992     if (C != CE)
1993       return lowerConstantForGV(C, ProcessingGeneric);
1994 
1995     // Otherwise report the problem to the user.
1996     std::string S;
1997     raw_string_ostream OS(S);
1998     OS << "Unsupported expression in static initializer: ";
1999     CE->printAsOperand(OS, /*PrintType=*/false,
2000                    !MF ? nullptr : MF->getFunction().getParent());
2001     report_fatal_error(OS.str());
2002   }
2003 
2004   case Instruction::AddrSpaceCast: {
2005     // Strip the addrspacecast and pass along the operand
2006     PointerType *DstTy = cast<PointerType>(CE->getType());
2007     if (DstTy->getAddressSpace() == 0) {
2008       return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2009     }
2010     std::string S;
2011     raw_string_ostream OS(S);
2012     OS << "Unsupported expression in static initializer: ";
2013     CE->printAsOperand(OS, /*PrintType=*/ false,
2014                        !MF ? nullptr : MF->getFunction().getParent());
2015     report_fatal_error(OS.str());
2016   }
2017 
2018   case Instruction::GetElementPtr: {
2019     const DataLayout &DL = getDataLayout();
2020 
2021     // Generate a symbolic expression for the byte address
2022     APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2023     cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2024 
2025     const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2026                                             ProcessingGeneric);
2027     if (!OffsetAI)
2028       return Base;
2029 
2030     int64_t Offset = OffsetAI.getSExtValue();
2031     return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2032                                    Ctx);
2033   }
2034 
2035   case Instruction::Trunc:
2036     // We emit the value and depend on the assembler to truncate the generated
2037     // expression properly.  This is important for differences between
2038     // blockaddress labels.  Since the two labels are in the same function, it
2039     // is reasonable to treat their delta as a 32-bit value.
2040     LLVM_FALLTHROUGH;
2041   case Instruction::BitCast:
2042     return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2043 
2044   case Instruction::IntToPtr: {
2045     const DataLayout &DL = getDataLayout();
2046 
2047     // Handle casts to pointers by changing them into casts to the appropriate
2048     // integer type.  This promotes constant folding and simplifies this code.
2049     Constant *Op = CE->getOperand(0);
2050     Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2051                                       false/*ZExt*/);
2052     return lowerConstantForGV(Op, ProcessingGeneric);
2053   }
2054 
2055   case Instruction::PtrToInt: {
2056     const DataLayout &DL = getDataLayout();
2057 
2058     // Support only foldable casts to/from pointers that can be eliminated by
2059     // changing the pointer to the appropriately sized integer type.
2060     Constant *Op = CE->getOperand(0);
2061     Type *Ty = CE->getType();
2062 
2063     const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2064 
2065     // We can emit the pointer value into this slot if the slot is an
2066     // integer slot equal to the size of the pointer.
2067     if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2068       return OpExpr;
2069 
2070     // Otherwise the pointer is smaller than the resultant integer, mask off
2071     // the high bits so we are sure to get a proper truncation if the input is
2072     // a constant expr.
2073     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2074     const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2075     return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2076   }
2077 
2078   // The MC library also has a right-shift operator, but it isn't consistently
2079   // signed or unsigned between different targets.
2080   case Instruction::Add: {
2081     const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2082     const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2083     switch (CE->getOpcode()) {
2084     default: llvm_unreachable("Unknown binary operator constant cast expr");
2085     case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2086     }
2087   }
2088   }
2089 }
2090 
2091 // Copy of MCExpr::print customized for NVPTX
2092 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2093   switch (Expr.getKind()) {
2094   case MCExpr::Target:
2095     return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2096   case MCExpr::Constant:
2097     OS << cast<MCConstantExpr>(Expr).getValue();
2098     return;
2099 
2100   case MCExpr::SymbolRef: {
2101     const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2102     const MCSymbol &Sym = SRE.getSymbol();
2103     Sym.print(OS, MAI);
2104     return;
2105   }
2106 
2107   case MCExpr::Unary: {
2108     const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2109     switch (UE.getOpcode()) {
2110     case MCUnaryExpr::LNot:  OS << '!'; break;
2111     case MCUnaryExpr::Minus: OS << '-'; break;
2112     case MCUnaryExpr::Not:   OS << '~'; break;
2113     case MCUnaryExpr::Plus:  OS << '+'; break;
2114     }
2115     printMCExpr(*UE.getSubExpr(), OS);
2116     return;
2117   }
2118 
2119   case MCExpr::Binary: {
2120     const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2121 
2122     // Only print parens around the LHS if it is non-trivial.
2123     if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2124         isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2125       printMCExpr(*BE.getLHS(), OS);
2126     } else {
2127       OS << '(';
2128       printMCExpr(*BE.getLHS(), OS);
2129       OS<< ')';
2130     }
2131 
2132     switch (BE.getOpcode()) {
2133     case MCBinaryExpr::Add:
2134       // Print "X-42" instead of "X+-42".
2135       if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2136         if (RHSC->getValue() < 0) {
2137           OS << RHSC->getValue();
2138           return;
2139         }
2140       }
2141 
2142       OS <<  '+';
2143       break;
2144     default: llvm_unreachable("Unhandled binary operator");
2145     }
2146 
2147     // Only print parens around the LHS if it is non-trivial.
2148     if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2149       printMCExpr(*BE.getRHS(), OS);
2150     } else {
2151       OS << '(';
2152       printMCExpr(*BE.getRHS(), OS);
2153       OS << ')';
2154     }
2155     return;
2156   }
2157   }
2158 
2159   llvm_unreachable("Invalid expression kind!");
2160 }
2161 
2162 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2163 ///
2164 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2165                                       const char *ExtraCode, raw_ostream &O) {
2166   if (ExtraCode && ExtraCode[0]) {
2167     if (ExtraCode[1] != 0)
2168       return true; // Unknown modifier.
2169 
2170     switch (ExtraCode[0]) {
2171     default:
2172       // See if this is a generic print operand
2173       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2174     case 'r':
2175       break;
2176     }
2177   }
2178 
2179   printOperand(MI, OpNo, O);
2180 
2181   return false;
2182 }
2183 
2184 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2185                                             unsigned OpNo,
2186                                             const char *ExtraCode,
2187                                             raw_ostream &O) {
2188   if (ExtraCode && ExtraCode[0])
2189     return true; // Unknown modifier
2190 
2191   O << '[';
2192   printMemOperand(MI, OpNo, O);
2193   O << ']';
2194 
2195   return false;
2196 }
2197 
2198 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2199                                    raw_ostream &O) {
2200   const MachineOperand &MO = MI->getOperand(opNum);
2201   switch (MO.getType()) {
2202   case MachineOperand::MO_Register:
2203     if (Register::isPhysicalRegister(MO.getReg())) {
2204       if (MO.getReg() == NVPTX::VRDepot)
2205         O << DEPOTNAME << getFunctionNumber();
2206       else
2207         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2208     } else {
2209       emitVirtualRegister(MO.getReg(), O);
2210     }
2211     break;
2212 
2213   case MachineOperand::MO_Immediate:
2214     O << MO.getImm();
2215     break;
2216 
2217   case MachineOperand::MO_FPImmediate:
2218     printFPConstant(MO.getFPImm(), O);
2219     break;
2220 
2221   case MachineOperand::MO_GlobalAddress:
2222     PrintSymbolOperand(MO, O);
2223     break;
2224 
2225   case MachineOperand::MO_MachineBasicBlock:
2226     MO.getMBB()->getSymbol()->print(O, MAI);
2227     break;
2228 
2229   default:
2230     llvm_unreachable("Operand type not supported.");
2231   }
2232 }
2233 
2234 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2235                                       raw_ostream &O, const char *Modifier) {
2236   printOperand(MI, opNum, O);
2237 
2238   if (Modifier && strcmp(Modifier, "add") == 0) {
2239     O << ", ";
2240     printOperand(MI, opNum + 1, O);
2241   } else {
2242     if (MI->getOperand(opNum + 1).isImm() &&
2243         MI->getOperand(opNum + 1).getImm() == 0)
2244       return; // don't print ',0' or '+0'
2245     O << "+";
2246     printOperand(MI, opNum + 1, O);
2247   }
2248 }
2249 
2250 // Force static initialization.
2251 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2252   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2253   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2254 }
2255