1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalAlias.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Alignment.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Endian.h"
79 #include "llvm/Support/ErrorHandling.h"
80 #include "llvm/Support/NativeFormatting.h"
81 #include "llvm/Support/Path.h"
82 #include "llvm/Support/raw_ostream.h"
83 #include "llvm/Target/TargetLoweringObjectFile.h"
84 #include "llvm/Target/TargetMachine.h"
85 #include "llvm/TargetParser/Triple.h"
86 #include "llvm/Transforms/Utils/UnrollLoop.h"
87 #include <cassert>
88 #include <cstdint>
89 #include <cstring>
90 #include <new>
91 #include <string>
92 #include <utility>
93 #include <vector>
94
95 using namespace llvm;
96
97 static cl::opt<bool>
98 LowerCtorDtor("nvptx-lower-global-ctor-dtor",
99 cl::desc("Lower GPU ctor / dtors to globals on the device."),
100 cl::init(false), cl::Hidden);
101
102 #define DEPOTNAME "__local_depot"
103
104 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
105 /// depends.
106 static void
DiscoverDependentGlobals(const Value * V,DenseSet<const GlobalVariable * > & Globals)107 DiscoverDependentGlobals(const Value *V,
108 DenseSet<const GlobalVariable *> &Globals) {
109 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
110 Globals.insert(GV);
111 else {
112 if (const User *U = dyn_cast<User>(V)) {
113 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
114 DiscoverDependentGlobals(U->getOperand(i), Globals);
115 }
116 }
117 }
118 }
119
120 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
121 /// instances to be emitted, but only after any dependents have been added
122 /// first.s
123 static void
VisitGlobalVariableForEmission(const GlobalVariable * GV,SmallVectorImpl<const GlobalVariable * > & Order,DenseSet<const GlobalVariable * > & Visited,DenseSet<const GlobalVariable * > & Visiting)124 VisitGlobalVariableForEmission(const GlobalVariable *GV,
125 SmallVectorImpl<const GlobalVariable *> &Order,
126 DenseSet<const GlobalVariable *> &Visited,
127 DenseSet<const GlobalVariable *> &Visiting) {
128 // Have we already visited this one?
129 if (Visited.count(GV))
130 return;
131
132 // Do we have a circular dependency?
133 if (!Visiting.insert(GV).second)
134 report_fatal_error("Circular dependency found in global variable set");
135
136 // Make sure we visit all dependents first
137 DenseSet<const GlobalVariable *> Others;
138 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
139 DiscoverDependentGlobals(GV->getOperand(i), Others);
140
141 for (const GlobalVariable *GV : Others)
142 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
143
144 // Now we can visit ourself
145 Order.push_back(GV);
146 Visited.insert(GV);
147 Visiting.erase(GV);
148 }
149
emitInstruction(const MachineInstr * MI)150 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
151 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
152 getSubtargetInfo().getFeatureBits());
153
154 MCInst Inst;
155 lowerToMCInst(MI, Inst);
156 EmitToStreamer(*OutStreamer, Inst);
157 }
158
159 // Handle symbol backtracking for targets that do not support image handles
lowerImageHandleOperand(const MachineInstr * MI,unsigned OpNo,MCOperand & MCOp)160 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
161 unsigned OpNo, MCOperand &MCOp) {
162 const MachineOperand &MO = MI->getOperand(OpNo);
163 const MCInstrDesc &MCID = MI->getDesc();
164
165 if (MCID.TSFlags & NVPTXII::IsTexFlag) {
166 // This is a texture fetch, so operand 4 is a texref and operand 5 is
167 // a samplerref
168 if (OpNo == 4 && MO.isImm()) {
169 lowerImageHandleSymbol(MO.getImm(), MCOp);
170 return true;
171 }
172 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
173 lowerImageHandleSymbol(MO.getImm(), MCOp);
174 return true;
175 }
176
177 return false;
178 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
179 unsigned VecSize =
180 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
181
182 // For a surface load of vector size N, the Nth operand will be the surfref
183 if (OpNo == VecSize && MO.isImm()) {
184 lowerImageHandleSymbol(MO.getImm(), MCOp);
185 return true;
186 }
187
188 return false;
189 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
190 // This is a surface store, so operand 0 is a surfref
191 if (OpNo == 0 && MO.isImm()) {
192 lowerImageHandleSymbol(MO.getImm(), MCOp);
193 return true;
194 }
195
196 return false;
197 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
198 // This is a query, so operand 1 is a surfref/texref
199 if (OpNo == 1 && MO.isImm()) {
200 lowerImageHandleSymbol(MO.getImm(), MCOp);
201 return true;
202 }
203
204 return false;
205 }
206
207 return false;
208 }
209
lowerImageHandleSymbol(unsigned Index,MCOperand & MCOp)210 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
211 // Ewwww
212 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
213 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
214 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
215 const char *Sym = MFI->getImageHandleSymbol(Index);
216 StringRef SymName = nvTM.getStrPool().save(Sym);
217 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
218 }
219
lowerToMCInst(const MachineInstr * MI,MCInst & OutMI)220 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
221 OutMI.setOpcode(MI->getOpcode());
222 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
223 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
224 const MachineOperand &MO = MI->getOperand(0);
225 OutMI.addOperand(GetSymbolRef(
226 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
227 return;
228 }
229
230 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
231 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
232 const MachineOperand &MO = MI->getOperand(i);
233
234 MCOperand MCOp;
235 if (!STI.hasImageHandles()) {
236 if (lowerImageHandleOperand(MI, i, MCOp)) {
237 OutMI.addOperand(MCOp);
238 continue;
239 }
240 }
241
242 if (lowerOperand(MO, MCOp))
243 OutMI.addOperand(MCOp);
244 }
245 }
246
lowerOperand(const MachineOperand & MO,MCOperand & MCOp)247 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
248 MCOperand &MCOp) {
249 switch (MO.getType()) {
250 default: llvm_unreachable("unknown operand type");
251 case MachineOperand::MO_Register:
252 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
253 break;
254 case MachineOperand::MO_Immediate:
255 MCOp = MCOperand::createImm(MO.getImm());
256 break;
257 case MachineOperand::MO_MachineBasicBlock:
258 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
259 MO.getMBB()->getSymbol(), OutContext));
260 break;
261 case MachineOperand::MO_ExternalSymbol:
262 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
263 break;
264 case MachineOperand::MO_GlobalAddress:
265 MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
266 break;
267 case MachineOperand::MO_FPImmediate: {
268 const ConstantFP *Cnt = MO.getFPImm();
269 const APFloat &Val = Cnt->getValueAPF();
270
271 switch (Cnt->getType()->getTypeID()) {
272 default: report_fatal_error("Unsupported FP type"); break;
273 case Type::HalfTyID:
274 MCOp = MCOperand::createExpr(
275 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
276 break;
277 case Type::BFloatTyID:
278 MCOp = MCOperand::createExpr(
279 NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
280 break;
281 case Type::FloatTyID:
282 MCOp = MCOperand::createExpr(
283 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
284 break;
285 case Type::DoubleTyID:
286 MCOp = MCOperand::createExpr(
287 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
288 break;
289 }
290 break;
291 }
292 }
293 return true;
294 }
295
encodeVirtualRegister(unsigned Reg)296 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
297 if (Register::isVirtualRegister(Reg)) {
298 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
299
300 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
301 unsigned RegNum = RegMap[Reg];
302
303 // Encode the register class in the upper 4 bits
304 // Must be kept in sync with NVPTXInstPrinter::printRegName
305 unsigned Ret = 0;
306 if (RC == &NVPTX::Int1RegsRegClass) {
307 Ret = (1 << 28);
308 } else if (RC == &NVPTX::Int16RegsRegClass) {
309 Ret = (2 << 28);
310 } else if (RC == &NVPTX::Int32RegsRegClass) {
311 Ret = (3 << 28);
312 } else if (RC == &NVPTX::Int64RegsRegClass) {
313 Ret = (4 << 28);
314 } else if (RC == &NVPTX::Float32RegsRegClass) {
315 Ret = (5 << 28);
316 } else if (RC == &NVPTX::Float64RegsRegClass) {
317 Ret = (6 << 28);
318 } else if (RC == &NVPTX::Int128RegsRegClass) {
319 Ret = (7 << 28);
320 } else {
321 report_fatal_error("Bad register class");
322 }
323
324 // Insert the vreg number
325 Ret |= (RegNum & 0x0FFFFFFF);
326 return Ret;
327 } else {
328 // Some special-use registers are actually physical registers.
329 // Encode this as the register class ID of 0 and the real register ID.
330 return Reg & 0x0FFFFFFF;
331 }
332 }
333
GetSymbolRef(const MCSymbol * Symbol)334 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
335 const MCExpr *Expr;
336 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
337 OutContext);
338 return MCOperand::createExpr(Expr);
339 }
340
ShouldPassAsArray(Type * Ty)341 static bool ShouldPassAsArray(Type *Ty) {
342 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
343 Ty->isHalfTy() || Ty->isBFloatTy();
344 }
345
printReturnValStr(const Function * F,raw_ostream & O)346 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
347 const DataLayout &DL = getDataLayout();
348 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
349 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
350
351 Type *Ty = F->getReturnType();
352
353 bool isABI = (STI.getSmVersion() >= 20);
354
355 if (Ty->getTypeID() == Type::VoidTyID)
356 return;
357 O << " (";
358
359 if (isABI) {
360 if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
361 !ShouldPassAsArray(Ty)) {
362 unsigned size = 0;
363 if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
364 size = ITy->getBitWidth();
365 } else {
366 assert(Ty->isFloatingPointTy() && "Floating point type expected here");
367 size = Ty->getPrimitiveSizeInBits();
368 }
369 size = promoteScalarArgumentSize(size);
370 O << ".param .b" << size << " func_retval0";
371 } else if (isa<PointerType>(Ty)) {
372 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
373 << " func_retval0";
374 } else if (ShouldPassAsArray(Ty)) {
375 unsigned totalsz = DL.getTypeAllocSize(Ty);
376 Align RetAlignment = TLI->getFunctionArgumentAlignment(
377 F, Ty, AttributeList::ReturnIndex, DL);
378 O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
379 << totalsz << "]";
380 } else
381 llvm_unreachable("Unknown return type");
382 } else {
383 SmallVector<EVT, 16> vtparts;
384 ComputeValueVTs(*TLI, DL, Ty, vtparts);
385 unsigned idx = 0;
386 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
387 unsigned elems = 1;
388 EVT elemtype = vtparts[i];
389 if (vtparts[i].isVector()) {
390 elems = vtparts[i].getVectorNumElements();
391 elemtype = vtparts[i].getVectorElementType();
392 }
393
394 for (unsigned j = 0, je = elems; j != je; ++j) {
395 unsigned sz = elemtype.getSizeInBits();
396 if (elemtype.isInteger())
397 sz = promoteScalarArgumentSize(sz);
398 O << ".reg .b" << sz << " func_retval" << idx;
399 if (j < je - 1)
400 O << ", ";
401 ++idx;
402 }
403 if (i < e - 1)
404 O << ", ";
405 }
406 }
407 O << ") ";
408 }
409
printReturnValStr(const MachineFunction & MF,raw_ostream & O)410 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
411 raw_ostream &O) {
412 const Function &F = MF.getFunction();
413 printReturnValStr(&F, O);
414 }
415
416 // Return true if MBB is the header of a loop marked with
417 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
isLoopHeaderOfNoUnroll(const MachineBasicBlock & MBB) const418 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
419 const MachineBasicBlock &MBB) const {
420 MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
421 // We insert .pragma "nounroll" only to the loop header.
422 if (!LI.isLoopHeader(&MBB))
423 return false;
424
425 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
426 // we iterate through each back edge of the loop with header MBB, and check
427 // whether its metadata contains llvm.loop.unroll.disable.
428 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
429 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
430 // Edges from other loops to MBB are not back edges.
431 continue;
432 }
433 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
434 if (MDNode *LoopID =
435 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
436 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
437 return true;
438 if (MDNode *UnrollCountMD =
439 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
440 if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
441 ->isOne())
442 return true;
443 }
444 }
445 }
446 }
447 return false;
448 }
449
emitBasicBlockStart(const MachineBasicBlock & MBB)450 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
451 AsmPrinter::emitBasicBlockStart(MBB);
452 if (isLoopHeaderOfNoUnroll(MBB))
453 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
454 }
455
emitFunctionEntryLabel()456 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
457 SmallString<128> Str;
458 raw_svector_ostream O(Str);
459
460 if (!GlobalsEmitted) {
461 emitGlobals(*MF->getFunction().getParent());
462 GlobalsEmitted = true;
463 }
464
465 // Set up
466 MRI = &MF->getRegInfo();
467 F = &MF->getFunction();
468 emitLinkageDirective(F, O);
469 if (isKernelFunction(*F))
470 O << ".entry ";
471 else {
472 O << ".func ";
473 printReturnValStr(*MF, O);
474 }
475
476 CurrentFnSym->print(O, MAI);
477
478 emitFunctionParamList(F, O);
479 O << "\n";
480
481 if (isKernelFunction(*F))
482 emitKernelFunctionDirectives(*F, O);
483
484 if (shouldEmitPTXNoReturn(F, TM))
485 O << ".noreturn";
486
487 OutStreamer->emitRawText(O.str());
488
489 VRegMapping.clear();
490 // Emit open brace for function body.
491 OutStreamer->emitRawText(StringRef("{\n"));
492 setAndEmitFunctionVirtualRegisters(*MF);
493 // Emit initial .loc debug directive for correct relocation symbol data.
494 if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
495 assert(SP->getUnit());
496 if (!SP->getUnit()->isDebugDirectivesOnly() && MMI && MMI->hasDebugInfo())
497 emitInitialRawDwarfLocDirective(*MF);
498 }
499 }
500
runOnMachineFunction(MachineFunction & F)501 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
502 bool Result = AsmPrinter::runOnMachineFunction(F);
503 // Emit closing brace for the body of function F.
504 // The closing brace must be emitted here because we need to emit additional
505 // debug labels/data after the last basic block.
506 // We need to emit the closing brace here because we don't have function that
507 // finished emission of the function body.
508 OutStreamer->emitRawText(StringRef("}\n"));
509 return Result;
510 }
511
emitFunctionBodyStart()512 void NVPTXAsmPrinter::emitFunctionBodyStart() {
513 SmallString<128> Str;
514 raw_svector_ostream O(Str);
515 emitDemotedVars(&MF->getFunction(), O);
516 OutStreamer->emitRawText(O.str());
517 }
518
emitFunctionBodyEnd()519 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
520 VRegMapping.clear();
521 }
522
getFunctionFrameSymbol() const523 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
524 SmallString<128> Str;
525 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
526 return OutContext.getOrCreateSymbol(Str);
527 }
528
emitImplicitDef(const MachineInstr * MI) const529 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
530 Register RegNo = MI->getOperand(0).getReg();
531 if (RegNo.isVirtual()) {
532 OutStreamer->AddComment(Twine("implicit-def: ") +
533 getVirtualRegisterName(RegNo));
534 } else {
535 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
536 OutStreamer->AddComment(Twine("implicit-def: ") +
537 STI.getRegisterInfo()->getName(RegNo));
538 }
539 OutStreamer->addBlankLine();
540 }
541
emitKernelFunctionDirectives(const Function & F,raw_ostream & O) const542 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
543 raw_ostream &O) const {
544 // If the NVVM IR has some of reqntid* specified, then output
545 // the reqntid directive, and set the unspecified ones to 1.
546 // If none of Reqntid* is specified, don't output reqntid directive.
547 std::optional<unsigned> Reqntidx = getReqNTIDx(F);
548 std::optional<unsigned> Reqntidy = getReqNTIDy(F);
549 std::optional<unsigned> Reqntidz = getReqNTIDz(F);
550
551 if (Reqntidx || Reqntidy || Reqntidz)
552 O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
553 << ", " << Reqntidz.value_or(1) << "\n";
554
555 // If the NVVM IR has some of maxntid* specified, then output
556 // the maxntid directive, and set the unspecified ones to 1.
557 // If none of maxntid* is specified, don't output maxntid directive.
558 std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
559 std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
560 std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
561
562 if (Maxntidx || Maxntidy || Maxntidz)
563 O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
564 << ", " << Maxntidz.value_or(1) << "\n";
565
566 unsigned Mincta = 0;
567 if (getMinCTASm(F, Mincta))
568 O << ".minnctapersm " << Mincta << "\n";
569
570 unsigned Maxnreg = 0;
571 if (getMaxNReg(F, Maxnreg))
572 O << ".maxnreg " << Maxnreg << "\n";
573
574 // .maxclusterrank directive requires SM_90 or higher, make sure that we
575 // filter it out for lower SM versions, as it causes a hard ptxas crash.
576 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
577 const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
578 unsigned Maxclusterrank = 0;
579 if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
580 O << ".maxclusterrank " << Maxclusterrank << "\n";
581 }
582
getVirtualRegisterName(unsigned Reg) const583 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
584 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
585
586 std::string Name;
587 raw_string_ostream NameStr(Name);
588
589 VRegRCMap::const_iterator I = VRegMapping.find(RC);
590 assert(I != VRegMapping.end() && "Bad register class");
591 const DenseMap<unsigned, unsigned> &RegMap = I->second;
592
593 VRegMap::const_iterator VI = RegMap.find(Reg);
594 assert(VI != RegMap.end() && "Bad virtual register");
595 unsigned MappedVR = VI->second;
596
597 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
598
599 NameStr.flush();
600 return Name;
601 }
602
emitVirtualRegister(unsigned int vr,raw_ostream & O)603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
604 raw_ostream &O) {
605 O << getVirtualRegisterName(vr);
606 }
607
emitAliasDeclaration(const GlobalAlias * GA,raw_ostream & O)608 void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
609 raw_ostream &O) {
610 const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
611 if (!F || isKernelFunction(*F) || F->isDeclaration())
612 report_fatal_error(
613 "NVPTX aliasee must be a non-kernel function definition");
614
615 if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
616 GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
617 report_fatal_error("NVPTX aliasee must not be '.weak'");
618
619 emitDeclarationWithName(F, getSymbol(GA), O);
620 }
621
emitDeclaration(const Function * F,raw_ostream & O)622 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
623 emitDeclarationWithName(F, getSymbol(F), O);
624 }
625
emitDeclarationWithName(const Function * F,MCSymbol * S,raw_ostream & O)626 void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
627 raw_ostream &O) {
628 emitLinkageDirective(F, O);
629 if (isKernelFunction(*F))
630 O << ".entry ";
631 else
632 O << ".func ";
633 printReturnValStr(F, O);
634 S->print(O, MAI);
635 O << "\n";
636 emitFunctionParamList(F, O);
637 O << "\n";
638 if (shouldEmitPTXNoReturn(F, TM))
639 O << ".noreturn";
640 O << ";\n";
641 }
642
usedInGlobalVarDef(const Constant * C)643 static bool usedInGlobalVarDef(const Constant *C) {
644 if (!C)
645 return false;
646
647 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
648 return GV->getName() != "llvm.used";
649 }
650
651 for (const User *U : C->users())
652 if (const Constant *C = dyn_cast<Constant>(U))
653 if (usedInGlobalVarDef(C))
654 return true;
655
656 return false;
657 }
658
usedInOneFunc(const User * U,Function const * & oneFunc)659 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
660 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
661 if (othergv->getName() == "llvm.used")
662 return true;
663 }
664
665 if (const Instruction *instr = dyn_cast<Instruction>(U)) {
666 if (instr->getParent() && instr->getParent()->getParent()) {
667 const Function *curFunc = instr->getParent()->getParent();
668 if (oneFunc && (curFunc != oneFunc))
669 return false;
670 oneFunc = curFunc;
671 return true;
672 } else
673 return false;
674 }
675
676 for (const User *UU : U->users())
677 if (!usedInOneFunc(UU, oneFunc))
678 return false;
679
680 return true;
681 }
682
683 /* Find out if a global variable can be demoted to local scope.
684 * Currently, this is valid for CUDA shared variables, which have local
685 * scope and global lifetime. So the conditions to check are :
686 * 1. Is the global variable in shared address space?
687 * 2. Does it have local linkage?
688 * 3. Is the global variable referenced only in one function?
689 */
canDemoteGlobalVar(const GlobalVariable * gv,Function const * & f)690 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
691 if (!gv->hasLocalLinkage())
692 return false;
693 PointerType *Pty = gv->getType();
694 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
695 return false;
696
697 const Function *oneFunc = nullptr;
698
699 bool flag = usedInOneFunc(gv, oneFunc);
700 if (!flag)
701 return false;
702 if (!oneFunc)
703 return false;
704 f = oneFunc;
705 return true;
706 }
707
useFuncSeen(const Constant * C,DenseMap<const Function *,bool> & seenMap)708 static bool useFuncSeen(const Constant *C,
709 DenseMap<const Function *, bool> &seenMap) {
710 for (const User *U : C->users()) {
711 if (const Constant *cu = dyn_cast<Constant>(U)) {
712 if (useFuncSeen(cu, seenMap))
713 return true;
714 } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
715 const BasicBlock *bb = I->getParent();
716 if (!bb)
717 continue;
718 const Function *caller = bb->getParent();
719 if (!caller)
720 continue;
721 if (seenMap.contains(caller))
722 return true;
723 }
724 }
725 return false;
726 }
727
emitDeclarations(const Module & M,raw_ostream & O)728 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
729 DenseMap<const Function *, bool> seenMap;
730 for (const Function &F : M) {
731 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
732 emitDeclaration(&F, O);
733 continue;
734 }
735
736 if (F.isDeclaration()) {
737 if (F.use_empty())
738 continue;
739 if (F.getIntrinsicID())
740 continue;
741 emitDeclaration(&F, O);
742 continue;
743 }
744 for (const User *U : F.users()) {
745 if (const Constant *C = dyn_cast<Constant>(U)) {
746 if (usedInGlobalVarDef(C)) {
747 // The use is in the initialization of a global variable
748 // that is a function pointer, so print a declaration
749 // for the original function
750 emitDeclaration(&F, O);
751 break;
752 }
753 // Emit a declaration of this function if the function that
754 // uses this constant expr has already been seen.
755 if (useFuncSeen(C, seenMap)) {
756 emitDeclaration(&F, O);
757 break;
758 }
759 }
760
761 if (!isa<Instruction>(U))
762 continue;
763 const Instruction *instr = cast<Instruction>(U);
764 const BasicBlock *bb = instr->getParent();
765 if (!bb)
766 continue;
767 const Function *caller = bb->getParent();
768 if (!caller)
769 continue;
770
771 // If a caller has already been seen, then the caller is
772 // appearing in the module before the callee. so print out
773 // a declaration for the callee.
774 if (seenMap.contains(caller)) {
775 emitDeclaration(&F, O);
776 break;
777 }
778 }
779 seenMap[&F] = true;
780 }
781 for (const GlobalAlias &GA : M.aliases())
782 emitAliasDeclaration(&GA, O);
783 }
784
isEmptyXXStructor(GlobalVariable * GV)785 static bool isEmptyXXStructor(GlobalVariable *GV) {
786 if (!GV) return true;
787 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
788 if (!InitList) return true; // Not an array; we don't know how to parse.
789 return InitList->getNumOperands() == 0;
790 }
791
emitStartOfAsmFile(Module & M)792 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
793 // Construct a default subtarget off of the TargetMachine defaults. The
794 // rest of NVPTX isn't friendly to change subtargets per function and
795 // so the default TargetMachine will have all of the options.
796 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
797 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
798 SmallString<128> Str1;
799 raw_svector_ostream OS1(Str1);
800
801 // Emit header before any dwarf directives are emitted below.
802 emitHeader(M, OS1, *STI);
803 OutStreamer->emitRawText(OS1.str());
804 }
805
doInitialization(Module & M)806 bool NVPTXAsmPrinter::doInitialization(Module &M) {
807 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
808 const NVPTXSubtarget &STI =
809 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
810 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
811 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
812
813 // OpenMP supports NVPTX global constructors and destructors.
814 bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;
815
816 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
817 !LowerCtorDtor && !IsOpenMP) {
818 report_fatal_error(
819 "Module has a nontrivial global ctor, which NVPTX does not support.");
820 return true; // error
821 }
822 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
823 !LowerCtorDtor && !IsOpenMP) {
824 report_fatal_error(
825 "Module has a nontrivial global dtor, which NVPTX does not support.");
826 return true; // error
827 }
828
829 // We need to call the parent's one explicitly.
830 bool Result = AsmPrinter::doInitialization(M);
831
832 GlobalsEmitted = false;
833
834 return Result;
835 }
836
emitGlobals(const Module & M)837 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
838 SmallString<128> Str2;
839 raw_svector_ostream OS2(Str2);
840
841 emitDeclarations(M, OS2);
842
843 // As ptxas does not support forward references of globals, we need to first
844 // sort the list of module-level globals in def-use order. We visit each
845 // global variable in order, and ensure that we emit it *after* its dependent
846 // globals. We use a little extra memory maintaining both a set and a list to
847 // have fast searches while maintaining a strict ordering.
848 SmallVector<const GlobalVariable *, 8> Globals;
849 DenseSet<const GlobalVariable *> GVVisited;
850 DenseSet<const GlobalVariable *> GVVisiting;
851
852 // Visit each global variable, in order
853 for (const GlobalVariable &I : M.globals())
854 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
855
856 assert(GVVisited.size() == M.global_size() && "Missed a global variable");
857 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
858
859 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
860 const NVPTXSubtarget &STI =
861 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
862
863 // Print out module-level global variables in proper order
864 for (const GlobalVariable *GV : Globals)
865 printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
866
867 OS2 << '\n';
868
869 OutStreamer->emitRawText(OS2.str());
870 }
871
emitGlobalAlias(const Module & M,const GlobalAlias & GA)872 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
873 SmallString<128> Str;
874 raw_svector_ostream OS(Str);
875
876 MCSymbol *Name = getSymbol(&GA);
877
878 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
879 << ";\n";
880
881 OutStreamer->emitRawText(OS.str());
882 }
883
emitHeader(Module & M,raw_ostream & O,const NVPTXSubtarget & STI)884 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
885 const NVPTXSubtarget &STI) {
886 O << "//\n";
887 O << "// Generated by LLVM NVPTX Back-End\n";
888 O << "//\n";
889 O << "\n";
890
891 unsigned PTXVersion = STI.getPTXVersion();
892 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
893
894 O << ".target ";
895 O << STI.getTargetName();
896
897 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
898 if (NTM.getDrvInterface() == NVPTX::NVCL)
899 O << ", texmode_independent";
900
901 bool HasFullDebugInfo = false;
902 for (DICompileUnit *CU : M.debug_compile_units()) {
903 switch(CU->getEmissionKind()) {
904 case DICompileUnit::NoDebug:
905 case DICompileUnit::DebugDirectivesOnly:
906 break;
907 case DICompileUnit::LineTablesOnly:
908 case DICompileUnit::FullDebug:
909 HasFullDebugInfo = true;
910 break;
911 }
912 if (HasFullDebugInfo)
913 break;
914 }
915 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
916 O << ", debug";
917
918 O << "\n";
919
920 O << ".address_size ";
921 if (NTM.is64Bit())
922 O << "64";
923 else
924 O << "32";
925 O << "\n";
926
927 O << "\n";
928 }
929
doFinalization(Module & M)930 bool NVPTXAsmPrinter::doFinalization(Module &M) {
931 bool HasDebugInfo = MMI && MMI->hasDebugInfo();
932
933 // If we did not emit any functions, then the global declarations have not
934 // yet been emitted.
935 if (!GlobalsEmitted) {
936 emitGlobals(M);
937 GlobalsEmitted = true;
938 }
939
940 // call doFinalization
941 bool ret = AsmPrinter::doFinalization(M);
942
943 clearAnnotationCache(&M);
944
945 auto *TS =
946 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
947 // Close the last emitted section
948 if (HasDebugInfo) {
949 TS->closeLastSection();
950 // Emit empty .debug_loc section for better support of the empty files.
951 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
952 }
953
954 // Output last DWARF .file directives, if any.
955 TS->outputDwarfFileDirectives();
956
957 return ret;
958 }
959
960 // This function emits appropriate linkage directives for
961 // functions and global variables.
962 //
963 // extern function declaration -> .extern
964 // extern function definition -> .visible
965 // external global variable with init -> .visible
966 // external without init -> .extern
967 // appending -> not allowed, assert.
968 // for any linkage other than
969 // internal, private, linker_private,
970 // linker_private_weak, linker_private_weak_def_auto,
971 // we emit -> .weak.
972
emitLinkageDirective(const GlobalValue * V,raw_ostream & O)973 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
974 raw_ostream &O) {
975 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
976 if (V->hasExternalLinkage()) {
977 if (isa<GlobalVariable>(V)) {
978 const GlobalVariable *GVar = cast<GlobalVariable>(V);
979 if (GVar) {
980 if (GVar->hasInitializer())
981 O << ".visible ";
982 else
983 O << ".extern ";
984 }
985 } else if (V->isDeclaration())
986 O << ".extern ";
987 else
988 O << ".visible ";
989 } else if (V->hasAppendingLinkage()) {
990 std::string msg;
991 msg.append("Error: ");
992 msg.append("Symbol ");
993 if (V->hasName())
994 msg.append(std::string(V->getName()));
995 msg.append("has unsupported appending linkage type");
996 llvm_unreachable(msg.c_str());
997 } else if (!V->hasInternalLinkage() &&
998 !V->hasPrivateLinkage()) {
999 O << ".weak ";
1000 }
1001 }
1002 }
1003
printModuleLevelGV(const GlobalVariable * GVar,raw_ostream & O,bool processDemoted,const NVPTXSubtarget & STI)1004 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1005 raw_ostream &O, bool processDemoted,
1006 const NVPTXSubtarget &STI) {
1007 // Skip meta data
1008 if (GVar->hasSection()) {
1009 if (GVar->getSection() == "llvm.metadata")
1010 return;
1011 }
1012
1013 // Skip LLVM intrinsic global variables
1014 if (GVar->getName().starts_with("llvm.") ||
1015 GVar->getName().starts_with("nvvm."))
1016 return;
1017
1018 const DataLayout &DL = getDataLayout();
1019
1020 // GlobalVariables are always constant pointers themselves.
1021 Type *ETy = GVar->getValueType();
1022
1023 if (GVar->hasExternalLinkage()) {
1024 if (GVar->hasInitializer())
1025 O << ".visible ";
1026 else
1027 O << ".extern ";
1028 } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
1029 GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
1030 O << ".common ";
1031 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1032 GVar->hasAvailableExternallyLinkage() ||
1033 GVar->hasCommonLinkage()) {
1034 O << ".weak ";
1035 }
1036
1037 if (isTexture(*GVar)) {
1038 O << ".global .texref " << getTextureName(*GVar) << ";\n";
1039 return;
1040 }
1041
1042 if (isSurface(*GVar)) {
1043 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1044 return;
1045 }
1046
1047 if (GVar->isDeclaration()) {
1048 // (extern) declarations, no definition or initializer
1049 // Currently the only known declaration is for an automatic __local
1050 // (.shared) promoted to global.
1051 emitPTXGlobalVariable(GVar, O, STI);
1052 O << ";\n";
1053 return;
1054 }
1055
1056 if (isSampler(*GVar)) {
1057 O << ".global .samplerref " << getSamplerName(*GVar);
1058
1059 const Constant *Initializer = nullptr;
1060 if (GVar->hasInitializer())
1061 Initializer = GVar->getInitializer();
1062 const ConstantInt *CI = nullptr;
1063 if (Initializer)
1064 CI = dyn_cast<ConstantInt>(Initializer);
1065 if (CI) {
1066 unsigned sample = CI->getZExtValue();
1067
1068 O << " = { ";
1069
1070 for (int i = 0,
1071 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1072 i < 3; i++) {
1073 O << "addr_mode_" << i << " = ";
1074 switch (addr) {
1075 case 0:
1076 O << "wrap";
1077 break;
1078 case 1:
1079 O << "clamp_to_border";
1080 break;
1081 case 2:
1082 O << "clamp_to_edge";
1083 break;
1084 case 3:
1085 O << "wrap";
1086 break;
1087 case 4:
1088 O << "mirror";
1089 break;
1090 }
1091 O << ", ";
1092 }
1093 O << "filter_mode = ";
1094 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1095 case 0:
1096 O << "nearest";
1097 break;
1098 case 1:
1099 O << "linear";
1100 break;
1101 case 2:
1102 llvm_unreachable("Anisotropic filtering is not supported");
1103 default:
1104 O << "nearest";
1105 break;
1106 }
1107 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1108 O << ", force_unnormalized_coords = 1";
1109 }
1110 O << " }";
1111 }
1112
1113 O << ";\n";
1114 return;
1115 }
1116
1117 if (GVar->hasPrivateLinkage()) {
1118 if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1119 return;
1120
1121 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1122 if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1123 return;
1124 if (GVar->use_empty())
1125 return;
1126 }
1127
1128 const Function *demotedFunc = nullptr;
1129 if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1130 O << "// " << GVar->getName() << " has been demoted\n";
1131 if (localDecls.find(demotedFunc) != localDecls.end())
1132 localDecls[demotedFunc].push_back(GVar);
1133 else {
1134 std::vector<const GlobalVariable *> temp;
1135 temp.push_back(GVar);
1136 localDecls[demotedFunc] = temp;
1137 }
1138 return;
1139 }
1140
1141 O << ".";
1142 emitPTXAddressSpace(GVar->getAddressSpace(), O);
1143
1144 if (isManaged(*GVar)) {
1145 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1146 report_fatal_error(
1147 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1148 }
1149 O << " .attribute(.managed)";
1150 }
1151
1152 if (MaybeAlign A = GVar->getAlign())
1153 O << " .align " << A->value();
1154 else
1155 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1156
1157 if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1158 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1159 O << " .";
1160 // Special case: ABI requires that we use .u8 for predicates
1161 if (ETy->isIntegerTy(1))
1162 O << "u8";
1163 else
1164 O << getPTXFundamentalTypeStr(ETy, false);
1165 O << " ";
1166 getSymbol(GVar)->print(O, MAI);
1167
1168 // Ptx allows variable initilization only for constant and global state
1169 // spaces.
1170 if (GVar->hasInitializer()) {
1171 if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1172 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1173 const Constant *Initializer = GVar->getInitializer();
1174 // 'undef' is treated as there is no value specified.
1175 if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1176 O << " = ";
1177 printScalarConstant(Initializer, O);
1178 }
1179 } else {
1180 // The frontend adds zero-initializer to device and constant variables
1181 // that don't have an initial value, and UndefValue to shared
1182 // variables, so skip warning for this case.
1183 if (!GVar->getInitializer()->isNullValue() &&
1184 !isa<UndefValue>(GVar->getInitializer())) {
1185 report_fatal_error("initial value of '" + GVar->getName() +
1186 "' is not allowed in addrspace(" +
1187 Twine(GVar->getAddressSpace()) + ")");
1188 }
1189 }
1190 }
1191 } else {
1192 uint64_t ElementSize = 0;
1193
1194 // Although PTX has direct support for struct type and array type and
1195 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1196 // targets that support these high level field accesses. Structs, arrays
1197 // and vectors are lowered into arrays of bytes.
1198 switch (ETy->getTypeID()) {
1199 case Type::IntegerTyID: // Integers larger than 64 bits
1200 case Type::StructTyID:
1201 case Type::ArrayTyID:
1202 case Type::FixedVectorTyID:
1203 ElementSize = DL.getTypeStoreSize(ETy);
1204 // Ptx allows variable initilization only for constant and
1205 // global state spaces.
1206 if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1207 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1208 GVar->hasInitializer()) {
1209 const Constant *Initializer = GVar->getInitializer();
1210 if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1211 AggBuffer aggBuffer(ElementSize, *this);
1212 bufferAggregateConstant(Initializer, &aggBuffer);
1213 if (aggBuffer.numSymbols()) {
1214 unsigned int ptrSize = MAI->getCodePointerSize();
1215 if (ElementSize % ptrSize ||
1216 !aggBuffer.allSymbolsAligned(ptrSize)) {
1217 // Print in bytes and use the mask() operator for pointers.
1218 if (!STI.hasMaskOperator())
1219 report_fatal_error(
1220 "initialized packed aggregate with pointers '" +
1221 GVar->getName() +
1222 "' requires at least PTX ISA version 7.1");
1223 O << " .u8 ";
1224 getSymbol(GVar)->print(O, MAI);
1225 O << "[" << ElementSize << "] = {";
1226 aggBuffer.printBytes(O);
1227 O << "}";
1228 } else {
1229 O << " .u" << ptrSize * 8 << " ";
1230 getSymbol(GVar)->print(O, MAI);
1231 O << "[" << ElementSize / ptrSize << "] = {";
1232 aggBuffer.printWords(O);
1233 O << "}";
1234 }
1235 } else {
1236 O << " .b8 ";
1237 getSymbol(GVar)->print(O, MAI);
1238 O << "[" << ElementSize << "] = {";
1239 aggBuffer.printBytes(O);
1240 O << "}";
1241 }
1242 } else {
1243 O << " .b8 ";
1244 getSymbol(GVar)->print(O, MAI);
1245 if (ElementSize) {
1246 O << "[";
1247 O << ElementSize;
1248 O << "]";
1249 }
1250 }
1251 } else {
1252 O << " .b8 ";
1253 getSymbol(GVar)->print(O, MAI);
1254 if (ElementSize) {
1255 O << "[";
1256 O << ElementSize;
1257 O << "]";
1258 }
1259 }
1260 break;
1261 default:
1262 llvm_unreachable("type not supported yet");
1263 }
1264 }
1265 O << ";\n";
1266 }
1267
printSymbol(unsigned nSym,raw_ostream & os)1268 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1269 const Value *v = Symbols[nSym];
1270 const Value *v0 = SymbolsBeforeStripping[nSym];
1271 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1272 MCSymbol *Name = AP.getSymbol(GVar);
1273 PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1274 // Is v0 a generic pointer?
1275 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1276 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1277 os << "generic(";
1278 Name->print(os, AP.MAI);
1279 os << ")";
1280 } else {
1281 Name->print(os, AP.MAI);
1282 }
1283 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1284 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1285 AP.printMCExpr(*Expr, os);
1286 } else
1287 llvm_unreachable("symbol type unknown");
1288 }
1289
printBytes(raw_ostream & os)1290 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1291 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1292 // Do not emit trailing zero initializers. They will be zero-initialized by
1293 // ptxas. This saves on both space requirements for the generated PTX and on
1294 // memory use by ptxas. (See:
1295 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1296 unsigned int InitializerCount = size;
1297 // TODO: symbols make this harder, but it would still be good to trim trailing
1298 // 0s for aggs with symbols as well.
1299 if (numSymbols() == 0)
1300 while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
1301 InitializerCount--;
1302
1303 symbolPosInBuffer.push_back(InitializerCount);
1304 unsigned int nSym = 0;
1305 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1306 for (unsigned int pos = 0; pos < InitializerCount;) {
1307 if (pos)
1308 os << ", ";
1309 if (pos != nextSymbolPos) {
1310 os << (unsigned int)buffer[pos];
1311 ++pos;
1312 continue;
1313 }
1314 // Generate a per-byte mask() operator for the symbol, which looks like:
1315 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1316 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1317 std::string symText;
1318 llvm::raw_string_ostream oss(symText);
1319 printSymbol(nSym, oss);
1320 for (unsigned i = 0; i < ptrSize; ++i) {
1321 if (i)
1322 os << ", ";
1323 llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1324 os << "(" << symText << ")";
1325 }
1326 pos += ptrSize;
1327 nextSymbolPos = symbolPosInBuffer[++nSym];
1328 assert(nextSymbolPos >= pos);
1329 }
1330 }
1331
printWords(raw_ostream & os)1332 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1333 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1334 symbolPosInBuffer.push_back(size);
1335 unsigned int nSym = 0;
1336 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1337 assert(nextSymbolPos % ptrSize == 0);
1338 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1339 if (pos)
1340 os << ", ";
1341 if (pos == nextSymbolPos) {
1342 printSymbol(nSym, os);
1343 nextSymbolPos = symbolPosInBuffer[++nSym];
1344 assert(nextSymbolPos % ptrSize == 0);
1345 assert(nextSymbolPos >= pos + ptrSize);
1346 } else if (ptrSize == 4)
1347 os << support::endian::read32le(&buffer[pos]);
1348 else
1349 os << support::endian::read64le(&buffer[pos]);
1350 }
1351 }
1352
emitDemotedVars(const Function * f,raw_ostream & O)1353 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1354 if (localDecls.find(f) == localDecls.end())
1355 return;
1356
1357 std::vector<const GlobalVariable *> &gvars = localDecls[f];
1358
1359 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1360 const NVPTXSubtarget &STI =
1361 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1362
1363 for (const GlobalVariable *GV : gvars) {
1364 O << "\t// demoted variable\n\t";
1365 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1366 }
1367 }
1368
emitPTXAddressSpace(unsigned int AddressSpace,raw_ostream & O) const1369 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1370 raw_ostream &O) const {
1371 switch (AddressSpace) {
1372 case ADDRESS_SPACE_LOCAL:
1373 O << "local";
1374 break;
1375 case ADDRESS_SPACE_GLOBAL:
1376 O << "global";
1377 break;
1378 case ADDRESS_SPACE_CONST:
1379 O << "const";
1380 break;
1381 case ADDRESS_SPACE_SHARED:
1382 O << "shared";
1383 break;
1384 default:
1385 report_fatal_error("Bad address space found while emitting PTX: " +
1386 llvm::Twine(AddressSpace));
1387 break;
1388 }
1389 }
1390
1391 std::string
getPTXFundamentalTypeStr(Type * Ty,bool useB4PTR) const1392 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1393 switch (Ty->getTypeID()) {
1394 case Type::IntegerTyID: {
1395 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1396 if (NumBits == 1)
1397 return "pred";
1398 else if (NumBits <= 64) {
1399 std::string name = "u";
1400 return name + utostr(NumBits);
1401 } else {
1402 llvm_unreachable("Integer too large");
1403 break;
1404 }
1405 break;
1406 }
1407 case Type::BFloatTyID:
1408 case Type::HalfTyID:
1409 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1410 // PTX assembly.
1411 return "b16";
1412 case Type::FloatTyID:
1413 return "f32";
1414 case Type::DoubleTyID:
1415 return "f64";
1416 case Type::PointerTyID: {
1417 unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1418 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1419
1420 if (PtrSize == 64)
1421 if (useB4PTR)
1422 return "b64";
1423 else
1424 return "u64";
1425 else if (useB4PTR)
1426 return "b32";
1427 else
1428 return "u32";
1429 }
1430 default:
1431 break;
1432 }
1433 llvm_unreachable("unexpected type");
1434 }
1435
emitPTXGlobalVariable(const GlobalVariable * GVar,raw_ostream & O,const NVPTXSubtarget & STI)1436 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1437 raw_ostream &O,
1438 const NVPTXSubtarget &STI) {
1439 const DataLayout &DL = getDataLayout();
1440
1441 // GlobalVariables are always constant pointers themselves.
1442 Type *ETy = GVar->getValueType();
1443
1444 O << ".";
1445 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1446 if (isManaged(*GVar)) {
1447 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1448 report_fatal_error(
1449 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1450 }
1451 O << " .attribute(.managed)";
1452 }
1453 if (MaybeAlign A = GVar->getAlign())
1454 O << " .align " << A->value();
1455 else
1456 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1457
1458 // Special case for i128
1459 if (ETy->isIntegerTy(128)) {
1460 O << " .b8 ";
1461 getSymbol(GVar)->print(O, MAI);
1462 O << "[16]";
1463 return;
1464 }
1465
1466 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1467 O << " .";
1468 O << getPTXFundamentalTypeStr(ETy);
1469 O << " ";
1470 getSymbol(GVar)->print(O, MAI);
1471 return;
1472 }
1473
1474 int64_t ElementSize = 0;
1475
1476 // Although PTX has direct support for struct type and array type and LLVM IR
1477 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1478 // support these high level field accesses. Structs and arrays are lowered
1479 // into arrays of bytes.
1480 switch (ETy->getTypeID()) {
1481 case Type::StructTyID:
1482 case Type::ArrayTyID:
1483 case Type::FixedVectorTyID:
1484 ElementSize = DL.getTypeStoreSize(ETy);
1485 O << " .b8 ";
1486 getSymbol(GVar)->print(O, MAI);
1487 O << "[";
1488 if (ElementSize) {
1489 O << ElementSize;
1490 }
1491 O << "]";
1492 break;
1493 default:
1494 llvm_unreachable("type not supported yet");
1495 }
1496 }
1497
emitFunctionParamList(const Function * F,raw_ostream & O)1498 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1499 const DataLayout &DL = getDataLayout();
1500 const AttributeList &PAL = F->getAttributes();
1501 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1502 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1503
1504 Function::const_arg_iterator I, E;
1505 unsigned paramIndex = 0;
1506 bool first = true;
1507 bool isKernelFunc = isKernelFunction(*F);
1508 bool isABI = (STI.getSmVersion() >= 20);
1509 bool hasImageHandles = STI.hasImageHandles();
1510
1511 if (F->arg_empty() && !F->isVarArg()) {
1512 O << "()";
1513 return;
1514 }
1515
1516 O << "(\n";
1517
1518 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1519 Type *Ty = I->getType();
1520
1521 if (!first)
1522 O << ",\n";
1523
1524 first = false;
1525
1526 // Handle image/sampler parameters
1527 if (isKernelFunction(*F)) {
1528 if (isSampler(*I) || isImage(*I)) {
1529 if (isImage(*I)) {
1530 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1531 if (hasImageHandles)
1532 O << "\t.param .u64 .ptr .surfref ";
1533 else
1534 O << "\t.param .surfref ";
1535 O << TLI->getParamName(F, paramIndex);
1536 }
1537 else { // Default image is read_only
1538 if (hasImageHandles)
1539 O << "\t.param .u64 .ptr .texref ";
1540 else
1541 O << "\t.param .texref ";
1542 O << TLI->getParamName(F, paramIndex);
1543 }
1544 } else {
1545 if (hasImageHandles)
1546 O << "\t.param .u64 .ptr .samplerref ";
1547 else
1548 O << "\t.param .samplerref ";
1549 O << TLI->getParamName(F, paramIndex);
1550 }
1551 continue;
1552 }
1553 }
1554
1555 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1556 paramIndex](Type *Ty) -> Align {
1557 if (MaybeAlign StackAlign =
1558 getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
1559 return StackAlign.value();
1560
1561 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1562 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1563 return std::max(TypeAlign, ParamAlign.valueOrOne());
1564 };
1565
1566 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1567 if (ShouldPassAsArray(Ty)) {
1568 // Just print .param .align <a> .b8 .param[size];
1569 // <a> = optimal alignment for the element type; always multiple of
1570 // PAL.getParamAlignment
1571 // size = typeallocsize of element type
1572 Align OptimalAlign = getOptimalAlignForParam(Ty);
1573
1574 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1575 O << TLI->getParamName(F, paramIndex);
1576 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1577
1578 continue;
1579 }
1580 // Just a scalar
1581 auto *PTy = dyn_cast<PointerType>(Ty);
1582 unsigned PTySizeInBits = 0;
1583 if (PTy) {
1584 PTySizeInBits =
1585 TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1586 assert(PTySizeInBits && "Invalid pointer size");
1587 }
1588
1589 if (isKernelFunc) {
1590 if (PTy) {
1591 // Special handling for pointer arguments to kernel
1592 O << "\t.param .u" << PTySizeInBits << " ";
1593
1594 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1595 NVPTX::CUDA) {
1596 int addrSpace = PTy->getAddressSpace();
1597 switch (addrSpace) {
1598 default:
1599 O << ".ptr ";
1600 break;
1601 case ADDRESS_SPACE_CONST:
1602 O << ".ptr .const ";
1603 break;
1604 case ADDRESS_SPACE_SHARED:
1605 O << ".ptr .shared ";
1606 break;
1607 case ADDRESS_SPACE_GLOBAL:
1608 O << ".ptr .global ";
1609 break;
1610 }
1611 Align ParamAlign = I->getParamAlign().valueOrOne();
1612 O << ".align " << ParamAlign.value() << " ";
1613 }
1614 O << TLI->getParamName(F, paramIndex);
1615 continue;
1616 }
1617
1618 // non-pointer scalar to kernel func
1619 O << "\t.param .";
1620 // Special case: predicate operands become .u8 types
1621 if (Ty->isIntegerTy(1))
1622 O << "u8";
1623 else
1624 O << getPTXFundamentalTypeStr(Ty);
1625 O << " ";
1626 O << TLI->getParamName(F, paramIndex);
1627 continue;
1628 }
1629 // Non-kernel function, just print .param .b<size> for ABI
1630 // and .reg .b<size> for non-ABI
1631 unsigned sz = 0;
1632 if (isa<IntegerType>(Ty)) {
1633 sz = cast<IntegerType>(Ty)->getBitWidth();
1634 sz = promoteScalarArgumentSize(sz);
1635 } else if (PTy) {
1636 assert(PTySizeInBits && "Invalid pointer size");
1637 sz = PTySizeInBits;
1638 } else
1639 sz = Ty->getPrimitiveSizeInBits();
1640 if (isABI)
1641 O << "\t.param .b" << sz << " ";
1642 else
1643 O << "\t.reg .b" << sz << " ";
1644 O << TLI->getParamName(F, paramIndex);
1645 continue;
1646 }
1647
1648 // param has byVal attribute.
1649 Type *ETy = PAL.getParamByValType(paramIndex);
1650 assert(ETy && "Param should have byval type");
1651
1652 if (isABI || isKernelFunc) {
1653 // Just print .param .align <a> .b8 .param[size];
1654 // <a> = optimal alignment for the element type; always multiple of
1655 // PAL.getParamAlignment
1656 // size = typeallocsize of element type
1657 Align OptimalAlign =
1658 isKernelFunc
1659 ? getOptimalAlignForParam(ETy)
1660 : TLI->getFunctionByValParamAlign(
1661 F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1662
1663 unsigned sz = DL.getTypeAllocSize(ETy);
1664 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1665 O << TLI->getParamName(F, paramIndex);
1666 O << "[" << sz << "]";
1667 continue;
1668 } else {
1669 // Split the ETy into constituent parts and
1670 // print .param .b<size> <name> for each part.
1671 // Further, if a part is vector, print the above for
1672 // each vector element.
1673 SmallVector<EVT, 16> vtparts;
1674 ComputeValueVTs(*TLI, DL, ETy, vtparts);
1675 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1676 unsigned elems = 1;
1677 EVT elemtype = vtparts[i];
1678 if (vtparts[i].isVector()) {
1679 elems = vtparts[i].getVectorNumElements();
1680 elemtype = vtparts[i].getVectorElementType();
1681 }
1682
1683 for (unsigned j = 0, je = elems; j != je; ++j) {
1684 unsigned sz = elemtype.getSizeInBits();
1685 if (elemtype.isInteger())
1686 sz = promoteScalarArgumentSize(sz);
1687 O << "\t.reg .b" << sz << " ";
1688 O << TLI->getParamName(F, paramIndex);
1689 if (j < je - 1)
1690 O << ",\n";
1691 ++paramIndex;
1692 }
1693 if (i < e - 1)
1694 O << ",\n";
1695 }
1696 --paramIndex;
1697 continue;
1698 }
1699 }
1700
1701 if (F->isVarArg()) {
1702 if (!first)
1703 O << ",\n";
1704 O << "\t.param .align " << STI.getMaxRequiredAlignment();
1705 O << " .b8 ";
1706 O << TLI->getParamName(F, /* vararg */ -1) << "[]";
1707 }
1708
1709 O << "\n)";
1710 }
1711
setAndEmitFunctionVirtualRegisters(const MachineFunction & MF)1712 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1713 const MachineFunction &MF) {
1714 SmallString<128> Str;
1715 raw_svector_ostream O(Str);
1716
1717 // Map the global virtual register number to a register class specific
1718 // virtual register number starting from 1 with that class.
1719 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1720 //unsigned numRegClasses = TRI->getNumRegClasses();
1721
1722 // Emit the Fake Stack Object
1723 const MachineFrameInfo &MFI = MF.getFrameInfo();
1724 int64_t NumBytes = MFI.getStackSize();
1725 if (NumBytes) {
1726 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1727 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1728 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1729 O << "\t.reg .b64 \t%SP;\n";
1730 O << "\t.reg .b64 \t%SPL;\n";
1731 } else {
1732 O << "\t.reg .b32 \t%SP;\n";
1733 O << "\t.reg .b32 \t%SPL;\n";
1734 }
1735 }
1736
1737 // Go through all virtual registers to establish the mapping between the
1738 // global virtual
1739 // register number and the per class virtual register number.
1740 // We use the per class virtual register number in the ptx output.
1741 unsigned int numVRs = MRI->getNumVirtRegs();
1742 for (unsigned i = 0; i < numVRs; i++) {
1743 Register vr = Register::index2VirtReg(i);
1744 const TargetRegisterClass *RC = MRI->getRegClass(vr);
1745 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];
1746 int n = regmap.size();
1747 regmap.insert(std::make_pair(vr, n + 1));
1748 }
1749
1750 // Emit register declarations
1751 // @TODO: Extract out the real register usage
1752 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1753 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1754 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1755 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1756 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1757 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1758 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1759
1760 // Emit declaration of the virtual registers or 'physical' registers for
1761 // each register class
1762 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1763 const TargetRegisterClass *RC = TRI->getRegClass(i);
1764 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];
1765 std::string rcname = getNVPTXRegClassName(RC);
1766 std::string rcStr = getNVPTXRegClassStr(RC);
1767 int n = regmap.size();
1768
1769 // Only declare those registers that may be used.
1770 if (n) {
1771 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1772 << ">;\n";
1773 }
1774 }
1775
1776 OutStreamer->emitRawText(O.str());
1777 }
1778
printFPConstant(const ConstantFP * Fp,raw_ostream & O)1779 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1780 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1781 bool ignored;
1782 unsigned int numHex;
1783 const char *lead;
1784
1785 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1786 numHex = 8;
1787 lead = "0f";
1788 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1789 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1790 numHex = 16;
1791 lead = "0d";
1792 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1793 } else
1794 llvm_unreachable("unsupported fp type");
1795
1796 APInt API = APF.bitcastToAPInt();
1797 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1798 }
1799
printScalarConstant(const Constant * CPV,raw_ostream & O)1800 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1801 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1802 O << CI->getValue();
1803 return;
1804 }
1805 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1806 printFPConstant(CFP, O);
1807 return;
1808 }
1809 if (isa<ConstantPointerNull>(CPV)) {
1810 O << "0";
1811 return;
1812 }
1813 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1814 bool IsNonGenericPointer = false;
1815 if (GVar->getType()->getAddressSpace() != 0) {
1816 IsNonGenericPointer = true;
1817 }
1818 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1819 O << "generic(";
1820 getSymbol(GVar)->print(O, MAI);
1821 O << ")";
1822 } else {
1823 getSymbol(GVar)->print(O, MAI);
1824 }
1825 return;
1826 }
1827 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1828 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1829 printMCExpr(*E, O);
1830 return;
1831 }
1832 llvm_unreachable("Not scalar type found in printScalarConstant()");
1833 }
1834
bufferLEByte(const Constant * CPV,int Bytes,AggBuffer * AggBuffer)1835 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1836 AggBuffer *AggBuffer) {
1837 const DataLayout &DL = getDataLayout();
1838 int AllocSize = DL.getTypeAllocSize(CPV->getType());
1839 if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1840 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1841 // only the space allocated by CPV.
1842 AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1843 return;
1844 }
1845
1846 // Helper for filling AggBuffer with APInts.
1847 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1848 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1849 SmallVector<unsigned char, 16> Buf(NumBytes);
1850 // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1851 // input's bit width, and i1 arrays may not have a length that is a multuple
1852 // of 8. We handle the last byte separately, so we never request out of
1853 // bounds bits.
1854 for (unsigned I = 0; I < NumBytes - 1; ++I) {
1855 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1856 }
1857 size_t LastBytePosition = (NumBytes - 1) * 8;
1858 size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
1859 Buf[NumBytes - 1] =
1860 Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);
1861 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1862 };
1863
1864 switch (CPV->getType()->getTypeID()) {
1865 case Type::IntegerTyID:
1866 if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1867 AddIntToBuffer(CI->getValue());
1868 break;
1869 }
1870 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1871 if (const auto *CI =
1872 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1873 AddIntToBuffer(CI->getValue());
1874 break;
1875 }
1876 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1877 Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1878 AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1879 AggBuffer->addZeros(AllocSize);
1880 break;
1881 }
1882 }
1883 llvm_unreachable("unsupported integer const type");
1884 break;
1885
1886 case Type::HalfTyID:
1887 case Type::BFloatTyID:
1888 case Type::FloatTyID:
1889 case Type::DoubleTyID:
1890 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1891 break;
1892
1893 case Type::PointerTyID: {
1894 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1895 AggBuffer->addSymbol(GVar, GVar);
1896 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1897 const Value *v = Cexpr->stripPointerCasts();
1898 AggBuffer->addSymbol(v, Cexpr);
1899 }
1900 AggBuffer->addZeros(AllocSize);
1901 break;
1902 }
1903
1904 case Type::ArrayTyID:
1905 case Type::FixedVectorTyID:
1906 case Type::StructTyID: {
1907 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1908 bufferAggregateConstant(CPV, AggBuffer);
1909 if (Bytes > AllocSize)
1910 AggBuffer->addZeros(Bytes - AllocSize);
1911 } else if (isa<ConstantAggregateZero>(CPV))
1912 AggBuffer->addZeros(Bytes);
1913 else
1914 llvm_unreachable("Unexpected Constant type");
1915 break;
1916 }
1917
1918 default:
1919 llvm_unreachable("unsupported type");
1920 }
1921 }
1922
bufferAggregateConstant(const Constant * CPV,AggBuffer * aggBuffer)1923 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1924 AggBuffer *aggBuffer) {
1925 const DataLayout &DL = getDataLayout();
1926 int Bytes;
1927
1928 // Integers of arbitrary width
1929 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1930 APInt Val = CI->getValue();
1931 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1932 uint8_t Byte = Val.getLoBits(8).getZExtValue();
1933 aggBuffer->addBytes(&Byte, 1, 1);
1934 Val.lshrInPlace(8);
1935 }
1936 return;
1937 }
1938
1939 // Old constants
1940 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1941 if (CPV->getNumOperands())
1942 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1943 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1944 return;
1945 }
1946
1947 if (const ConstantDataSequential *CDS =
1948 dyn_cast<ConstantDataSequential>(CPV)) {
1949 if (CDS->getNumElements())
1950 for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1951 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1952 aggBuffer);
1953 return;
1954 }
1955
1956 if (isa<ConstantStruct>(CPV)) {
1957 if (CPV->getNumOperands()) {
1958 StructType *ST = cast<StructType>(CPV->getType());
1959 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1960 if (i == (e - 1))
1961 Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1962 DL.getTypeAllocSize(ST) -
1963 DL.getStructLayout(ST)->getElementOffset(i);
1964 else
1965 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1966 DL.getStructLayout(ST)->getElementOffset(i);
1967 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1968 }
1969 }
1970 return;
1971 }
1972 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1973 }
1974
1975 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1976 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1977 /// expressions that are representable in PTX and create
1978 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1979 const MCExpr *
lowerConstantForGV(const Constant * CV,bool ProcessingGeneric)1980 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1981 MCContext &Ctx = OutContext;
1982
1983 if (CV->isNullValue() || isa<UndefValue>(CV))
1984 return MCConstantExpr::create(0, Ctx);
1985
1986 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1987 return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1988
1989 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1990 const MCSymbolRefExpr *Expr =
1991 MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1992 if (ProcessingGeneric) {
1993 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1994 } else {
1995 return Expr;
1996 }
1997 }
1998
1999 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
2000 if (!CE) {
2001 llvm_unreachable("Unknown constant value to lower!");
2002 }
2003
2004 switch (CE->getOpcode()) {
2005 default:
2006 break; // Error
2007
2008 case Instruction::AddrSpaceCast: {
2009 // Strip the addrspacecast and pass along the operand
2010 PointerType *DstTy = cast<PointerType>(CE->getType());
2011 if (DstTy->getAddressSpace() == 0)
2012 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2013
2014 break; // Error
2015 }
2016
2017 case Instruction::GetElementPtr: {
2018 const DataLayout &DL = getDataLayout();
2019
2020 // Generate a symbolic expression for the byte address
2021 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2022 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2023
2024 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2025 ProcessingGeneric);
2026 if (!OffsetAI)
2027 return Base;
2028
2029 int64_t Offset = OffsetAI.getSExtValue();
2030 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2031 Ctx);
2032 }
2033
2034 case Instruction::Trunc:
2035 // We emit the value and depend on the assembler to truncate the generated
2036 // expression properly. This is important for differences between
2037 // blockaddress labels. Since the two labels are in the same function, it
2038 // is reasonable to treat their delta as a 32-bit value.
2039 [[fallthrough]];
2040 case Instruction::BitCast:
2041 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2042
2043 case Instruction::IntToPtr: {
2044 const DataLayout &DL = getDataLayout();
2045
2046 // Handle casts to pointers by changing them into casts to the appropriate
2047 // integer type. This promotes constant folding and simplifies this code.
2048 Constant *Op = CE->getOperand(0);
2049 Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2050 /*IsSigned*/ false, DL);
2051 if (Op)
2052 return lowerConstantForGV(Op, ProcessingGeneric);
2053
2054 break; // Error
2055 }
2056
2057 case Instruction::PtrToInt: {
2058 const DataLayout &DL = getDataLayout();
2059
2060 // Support only foldable casts to/from pointers that can be eliminated by
2061 // changing the pointer to the appropriately sized integer type.
2062 Constant *Op = CE->getOperand(0);
2063 Type *Ty = CE->getType();
2064
2065 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2066
2067 // We can emit the pointer value into this slot if the slot is an
2068 // integer slot equal to the size of the pointer.
2069 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2070 return OpExpr;
2071
2072 // Otherwise the pointer is smaller than the resultant integer, mask off
2073 // the high bits so we are sure to get a proper truncation if the input is
2074 // a constant expr.
2075 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2076 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2077 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2078 }
2079
2080 // The MC library also has a right-shift operator, but it isn't consistently
2081 // signed or unsigned between different targets.
2082 case Instruction::Add: {
2083 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2084 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2085 switch (CE->getOpcode()) {
2086 default: llvm_unreachable("Unknown binary operator constant cast expr");
2087 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2088 }
2089 }
2090 }
2091
2092 // If the code isn't optimized, there may be outstanding folding
2093 // opportunities. Attempt to fold the expression using DataLayout as a
2094 // last resort before giving up.
2095 Constant *C = ConstantFoldConstant(CE, getDataLayout());
2096 if (C != CE)
2097 return lowerConstantForGV(C, ProcessingGeneric);
2098
2099 // Otherwise report the problem to the user.
2100 std::string S;
2101 raw_string_ostream OS(S);
2102 OS << "Unsupported expression in static initializer: ";
2103 CE->printAsOperand(OS, /*PrintType=*/false,
2104 !MF ? nullptr : MF->getFunction().getParent());
2105 report_fatal_error(Twine(OS.str()));
2106 }
2107
2108 // Copy of MCExpr::print customized for NVPTX
printMCExpr(const MCExpr & Expr,raw_ostream & OS)2109 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2110 switch (Expr.getKind()) {
2111 case MCExpr::Target:
2112 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2113 case MCExpr::Constant:
2114 OS << cast<MCConstantExpr>(Expr).getValue();
2115 return;
2116
2117 case MCExpr::SymbolRef: {
2118 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2119 const MCSymbol &Sym = SRE.getSymbol();
2120 Sym.print(OS, MAI);
2121 return;
2122 }
2123
2124 case MCExpr::Unary: {
2125 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2126 switch (UE.getOpcode()) {
2127 case MCUnaryExpr::LNot: OS << '!'; break;
2128 case MCUnaryExpr::Minus: OS << '-'; break;
2129 case MCUnaryExpr::Not: OS << '~'; break;
2130 case MCUnaryExpr::Plus: OS << '+'; break;
2131 }
2132 printMCExpr(*UE.getSubExpr(), OS);
2133 return;
2134 }
2135
2136 case MCExpr::Binary: {
2137 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2138
2139 // Only print parens around the LHS if it is non-trivial.
2140 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2141 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2142 printMCExpr(*BE.getLHS(), OS);
2143 } else {
2144 OS << '(';
2145 printMCExpr(*BE.getLHS(), OS);
2146 OS<< ')';
2147 }
2148
2149 switch (BE.getOpcode()) {
2150 case MCBinaryExpr::Add:
2151 // Print "X-42" instead of "X+-42".
2152 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2153 if (RHSC->getValue() < 0) {
2154 OS << RHSC->getValue();
2155 return;
2156 }
2157 }
2158
2159 OS << '+';
2160 break;
2161 default: llvm_unreachable("Unhandled binary operator");
2162 }
2163
2164 // Only print parens around the LHS if it is non-trivial.
2165 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2166 printMCExpr(*BE.getRHS(), OS);
2167 } else {
2168 OS << '(';
2169 printMCExpr(*BE.getRHS(), OS);
2170 OS << ')';
2171 }
2172 return;
2173 }
2174 }
2175
2176 llvm_unreachable("Invalid expression kind!");
2177 }
2178
2179 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2180 ///
PrintAsmOperand(const MachineInstr * MI,unsigned OpNo,const char * ExtraCode,raw_ostream & O)2181 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2182 const char *ExtraCode, raw_ostream &O) {
2183 if (ExtraCode && ExtraCode[0]) {
2184 if (ExtraCode[1] != 0)
2185 return true; // Unknown modifier.
2186
2187 switch (ExtraCode[0]) {
2188 default:
2189 // See if this is a generic print operand
2190 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2191 case 'r':
2192 break;
2193 }
2194 }
2195
2196 printOperand(MI, OpNo, O);
2197
2198 return false;
2199 }
2200
PrintAsmMemoryOperand(const MachineInstr * MI,unsigned OpNo,const char * ExtraCode,raw_ostream & O)2201 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2202 unsigned OpNo,
2203 const char *ExtraCode,
2204 raw_ostream &O) {
2205 if (ExtraCode && ExtraCode[0])
2206 return true; // Unknown modifier
2207
2208 O << '[';
2209 printMemOperand(MI, OpNo, O);
2210 O << ']';
2211
2212 return false;
2213 }
2214
printOperand(const MachineInstr * MI,unsigned OpNum,raw_ostream & O)2215 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2216 raw_ostream &O) {
2217 const MachineOperand &MO = MI->getOperand(OpNum);
2218 switch (MO.getType()) {
2219 case MachineOperand::MO_Register:
2220 if (MO.getReg().isPhysical()) {
2221 if (MO.getReg() == NVPTX::VRDepot)
2222 O << DEPOTNAME << getFunctionNumber();
2223 else
2224 O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2225 } else {
2226 emitVirtualRegister(MO.getReg(), O);
2227 }
2228 break;
2229
2230 case MachineOperand::MO_Immediate:
2231 O << MO.getImm();
2232 break;
2233
2234 case MachineOperand::MO_FPImmediate:
2235 printFPConstant(MO.getFPImm(), O);
2236 break;
2237
2238 case MachineOperand::MO_GlobalAddress:
2239 PrintSymbolOperand(MO, O);
2240 break;
2241
2242 case MachineOperand::MO_MachineBasicBlock:
2243 MO.getMBB()->getSymbol()->print(O, MAI);
2244 break;
2245
2246 default:
2247 llvm_unreachable("Operand type not supported.");
2248 }
2249 }
2250
printMemOperand(const MachineInstr * MI,unsigned OpNum,raw_ostream & O,const char * Modifier)2251 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2252 raw_ostream &O, const char *Modifier) {
2253 printOperand(MI, OpNum, O);
2254
2255 if (Modifier && strcmp(Modifier, "add") == 0) {
2256 O << ", ";
2257 printOperand(MI, OpNum + 1, O);
2258 } else {
2259 if (MI->getOperand(OpNum + 1).isImm() &&
2260 MI->getOperand(OpNum + 1).getImm() == 0)
2261 return; // don't print ',0' or '+0'
2262 O << "+";
2263 printOperand(MI, OpNum + 1, O);
2264 }
2265 }
2266
2267 // Force static initialization.
LLVMInitializeNVPTXAsmPrinter()2268 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2269 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2270 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2271 }
2272