//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a printer that converts from our internal representation
// of machine-dependent LLVM code to NVPTX assembly language.
//
//===----------------------------------------------------------------------===//

#include "NVPTXAsmPrinter.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "MCTargetDesc/NVPTXInstPrinter.h"
#include "MCTargetDesc/NVPTXMCAsmInfo.h"
#include "MCTargetDesc/NVPTXTargetStreamer.h"
#include "NVPTX.h"
#include "NVPTXMCExpr.h"
#include "NVPTXMachineFunctionInfo.h"
#include "NVPTXRegisterInfo.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "TargetInfo/NVPTXTargetInfo.h"
#include "cl_common_defines.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrDesc.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <cassert>
#include <cstdint>
#include <cstring>
#include <new>
#include <string>
#include <utility>
#include <vector>

using namespace llvm;

static cl::opt<bool>
    LowerCtorDtor("nvptx-lower-global-ctor-dtor",
                  cl::desc("Lower GPU ctor / dtors to globals on the device."),
                  cl::init(false), cl::Hidden);

#define DEPOTNAME "__local_depot"

/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
/// depends.
static void
DiscoverDependentGlobals(const Value *V,
                         DenseSet<const GlobalVariable *> &Globals) {
  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
    Globals.insert(GV);
  else {
    if (const User *U = dyn_cast<User>(V)) {
      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
        DiscoverDependentGlobals(U->getOperand(i), Globals);
      }
    }
  }
}

/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
/// instances to be emitted, but only after any dependents have been added
/// first.
static void
VisitGlobalVariableForEmission(const GlobalVariable *GV,
                               SmallVectorImpl<const GlobalVariable *> &Order,
                               DenseSet<const GlobalVariable *> &Visited,
                               DenseSet<const GlobalVariable *> &Visiting) {
  // Have we already visited this one?
  if (Visited.count(GV))
    return;

  // Do we have a circular dependency?
  if (!Visiting.insert(GV).second)
    report_fatal_error("Circular dependency found in global variable set");

  // Make sure we visit all dependents first
  DenseSet<const GlobalVariable *> Others;
  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
    DiscoverDependentGlobals(GV->getOperand(i), Others);

  for (const GlobalVariable *GV : Others)
    VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);

  // Now we can visit ourself
  Order.push_back(GV);
  Visited.insert(GV);
  Visiting.erase(GV);
}

void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
  NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
                                        getSubtargetInfo().getFeatureBits());

  MCInst Inst;
  lowerToMCInst(MI, Inst);
  EmitToStreamer(*OutStreamer, Inst);
}

// Handle symbol backtracking for targets that do not support image handles
bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
                                              unsigned OpNo, MCOperand &MCOp) {
  const MachineOperand &MO = MI->getOperand(OpNo);
  const MCInstrDesc &MCID = MI->getDesc();

  if (MCID.TSFlags & NVPTXII::IsTexFlag) {
    // This is a texture fetch, so operand 4 is a texref and operand 5 is
    // a samplerref
    if (OpNo == 4 && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }
    if (OpNo == 5 && MO.isImm() &&
        !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
    unsigned VecSize =
        1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
              1);

    // For a surface load of vector size N, the Nth operand will be the surfref
    if (OpNo == VecSize && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
    // This is a surface store, so operand 0 is a surfref
    if (OpNo == 0 && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
    // This is a query, so operand 1 is a surfref/texref
    if (OpNo == 1 && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  }

  return false;
}
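// Illustration of the symbol path above: when the target has no image-handle
// support, the matched immediate operand is replaced by a reference to the
// texref/surfref/samplerref global itself, so the printed PTX instruction
// names the symbol (e.g. "my_texref") instead of an opaque handle index.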
void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
  // Ewwww
  LLVMTargetMachine &TM = const_cast<LLVMTargetMachine &>(MF->getTarget());
  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
  const char *Sym = MFI->getImageHandleSymbol(Index);
  StringRef SymName = nvTM.getStrPool().save(Sym);
  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
}

void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
  OutMI.setOpcode(MI->getOpcode());
  // Special: Do not mangle symbol operand of CALL_PROTOTYPE
  if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
    const MachineOperand &MO = MI->getOperand(0);
    OutMI.addOperand(GetSymbolRef(
        OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
    return;
  }

  const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
    const MachineOperand &MO = MI->getOperand(i);

    MCOperand MCOp;
    if (!STI.hasImageHandles()) {
      if (lowerImageHandleOperand(MI, i, MCOp)) {
        OutMI.addOperand(MCOp);
        continue;
      }
    }

    if (lowerOperand(MO, MCOp))
      OutMI.addOperand(MCOp);
  }
}

bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
                                   MCOperand &MCOp) {
  switch (MO.getType()) {
  default: llvm_unreachable("unknown operand type");
  case MachineOperand::MO_Register:
    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
    break;
  case MachineOperand::MO_Immediate:
    MCOp = MCOperand::createImm(MO.getImm());
    break;
  case MachineOperand::MO_MachineBasicBlock:
    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
        MO.getMBB()->getSymbol(), OutContext));
    break;
  case MachineOperand::MO_ExternalSymbol:
    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
    break;
  case MachineOperand::MO_GlobalAddress:
    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
    break;
  case MachineOperand::MO_FPImmediate: {
    const ConstantFP *Cnt = MO.getFPImm();
    const APFloat &Val = Cnt->getValueAPF();

    switch (Cnt->getType()->getTypeID()) {
    default: report_fatal_error("Unsupported FP type"); break;
    case Type::HalfTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
      break;
    case Type::BFloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
      break;
    case Type::FloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
      break;
    case Type::DoubleTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
      break;
    }
    break;
  }
  }
  return true;
}

unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
  if (Register::isVirtualRegister(Reg)) {
    const TargetRegisterClass *RC = MRI->getRegClass(Reg);

    DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
    unsigned RegNum = RegMap[Reg];

    // Encode the register class in the upper 4 bits
    // Must be kept in sync with NVPTXInstPrinter::printRegName
    unsigned Ret = 0;
    if (RC == &NVPTX::Int1RegsRegClass) {
      Ret = (1 << 28);
    } else if (RC == &NVPTX::Int16RegsRegClass) {
      Ret = (2 << 28);
    } else if (RC == &NVPTX::Int32RegsRegClass) {
      Ret = (3 << 28);
    } else if (RC == &NVPTX::Int64RegsRegClass) {
      Ret = (4 << 28);
    } else if (RC == &NVPTX::Float32RegsRegClass) {
      Ret = (5 << 28);
    } else if (RC == &NVPTX::Float64RegsRegClass) {
      Ret = (6 << 28);
    } else if (RC == &NVPTX::Int128RegsRegClass) {
      Ret = (7 << 28);
    } else {
      report_fatal_error("Bad register class");
    }

    // Insert the vreg number
    Ret |= (RegNum & 0x0FFFFFFF);
    return Ret;
  } else {
    // Some special-use registers are actually physical registers.
    // Encode this as the register class ID of 0 and the real register ID.
    return Reg & 0x0FFFFFFF;
  }
}
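// A hand-worked example of the encoding above: an Int32RegsRegClass virtual
// register that received per-class number 3 encodes as (3 << 28) | 3 =
// 0x30000003, which NVPTXInstPrinter::printRegName decodes back into the
// "%r3" PTX register name.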
MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
  const MCExpr *Expr;
  Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
                                 OutContext);
  return MCOperand::createExpr(Expr);
}

static bool ShouldPassAsArray(Type *Ty) {
  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
         Ty->isHalfTy() || Ty->isBFloatTy();
}

void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Type *Ty = F->getReturnType();

  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;
  O << " (";

  if (isABI) {
    if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
        !ShouldPassAsArray(Ty)) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      size = promoteScalarArgumentSize(size);
      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (ShouldPassAsArray(Ty)) {
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      Align RetAlignment = TLI->getFunctionArgumentAlignment(
          F, Ty, AttributeList::ReturnIndex, DL);
      O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
        << totalsz << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger())
          sz = promoteScalarArgumentSize(sz);
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}

void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
                                        raw_ostream &O) {
  const Function &F = MF.getFunction();
  printReturnValStr(&F, O);
}
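// A hand-worked sketch of the ABI path above: an i32 return value prints as
// " (.param .b32 func_retval0) ", while an aggregate return such as
// [4 x i32] (assuming 4-byte return alignment) prints as
// " (.param .align 4 .b8 func_retval0[16]) ".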
// Return true if MBB is the header of a loop marked with
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
    const MachineBasicBlock &MBB) const {
  MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
  // We insert .pragma "nounroll" only to the loop header.
  if (!LI.isLoopHeader(&MBB))
    return false;

  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
  // we iterate through each back edge of the loop with header MBB, and check
  // whether its metadata contains llvm.loop.unroll.disable.
  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
    if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
      // Edges from other loops to MBB are not back edges.
      continue;
    }
    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
      if (MDNode *LoopID =
              PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
        if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
          return true;
        if (MDNode *UnrollCountMD =
                GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
          if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
                  ->isOne())
            return true;
        }
      }
    }
  }
  return false;
}

void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
  AsmPrinter::emitBasicBlockStart(MBB);
  if (isLoopHeaderOfNoUnroll(MBB))
    OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
}

void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  if (!GlobalsEmitted) {
    emitGlobals(*MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(*MF, O);
  }

  CurrentFnSym->print(O, MAI);

  emitFunctionParamList(F, O);
  O << "\n";

  if (isKernelFunction(*F))
    emitKernelFunctionDirectives(*F, O);

  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(O.str());

  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
    assert(SP->getUnit());
    if (!SP->getUnit()->isDebugDirectivesOnly() && MMI && MMI->hasDebugInfo())
      emitInitialRawDwarfLocDirective(*MF);
  }
}

bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
  bool Result = AsmPrinter::runOnMachineFunction(F);
  // Emit the closing brace for the body of function F here, after the base
  // class has run, because additional debug labels/data are emitted after the
  // last basic block and there is no later hook that marks the end of the
  // function body.
  OutStreamer->emitRawText(StringRef("}\n"));
  return Result;
}
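// Taken together, emitFunctionEntryLabel above opens each function and
// runOnMachineFunction closes it, giving output shaped roughly like:
//   .visible .entry foo(...)
//   .maxntid 256, 1, 1
//   {
//   ... body ...
//   }
// (a sketch of the emitted structure, not verified compiler output).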
void NVPTXAsmPrinter::emitFunctionBodyStart() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);
  emitDemotedVars(&MF->getFunction(), O);
  OutStreamer->emitRawText(O.str());
}

void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}

const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
  SmallString<128> Str;
  raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
  return OutContext.getOrCreateSymbol(Str);
}

void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
  Register RegNo = MI->getOperand(0).getReg();
  if (RegNo.isVirtual()) {
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            getVirtualRegisterName(RegNo));
  } else {
    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            STI.getRegisterInfo()->getName(RegNo));
  }
  OutStreamer->addBlankLine();
}

void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of Reqntid* is specified, don't output reqntid directive.
  std::optional<unsigned> Reqntidx = getReqNTIDx(F);
  std::optional<unsigned> Reqntidy = getReqNTIDy(F);
  std::optional<unsigned> Reqntidz = getReqNTIDz(F);

  if (Reqntidx || Reqntidy || Reqntidz)
    O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
      << ", " << Reqntidz.value_or(1) << "\n";

  // If the NVVM IR has some of maxntid* specified, then output
  // the maxntid directive, and set the unspecified ones to 1.
  // If none of maxntid* is specified, don't output maxntid directive.
  std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
  std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
  std::optional<unsigned> Maxntidz = getMaxNTIDz(F);

  if (Maxntidx || Maxntidy || Maxntidz)
    O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
      << ", " << Maxntidz.value_or(1) << "\n";

  unsigned Mincta = 0;
  if (getMinCTASm(F, Mincta))
    O << ".minnctapersm " << Mincta << "\n";

  unsigned Maxnreg = 0;
  if (getMaxNReg(F, Maxnreg))
    O << ".maxnreg " << Maxnreg << "\n";

  // .maxclusterrank directive requires SM_90 or higher, make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI =
      static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  unsigned Maxclusterrank = 0;
  if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
    O << ".maxclusterrank " << Maxclusterrank << "\n";
}
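// Example output of the directive emitter above for a kernel annotated with
// maxntid(256,1,1) and minctasm(2) in the IR:
//   .maxntid 256, 1, 1
//   .minnctapersm 2
// Unspecified ntid components default to 1, per the comments above.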
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
  const TargetRegisterClass *RC = MRI->getRegClass(Reg);

  std::string Name;
  raw_string_ostream NameStr(Name);

  VRegRCMap::const_iterator I = VRegMapping.find(RC);
  assert(I != VRegMapping.end() && "Bad register class");
  const DenseMap<unsigned, unsigned> &RegMap = I->second;

  VRegMap::const_iterator VI = RegMap.find(Reg);
  assert(VI != RegMap.end() && "Bad virtual register");
  unsigned MappedVR = VI->second;

  NameStr << getNVPTXRegClassStr(RC) << MappedVR;

  NameStr.flush();
  return Name;
}

void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
                                          raw_ostream &O) {
  O << getVirtualRegisterName(vr);
}

void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
                                           raw_ostream &O) {
  const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
  if (!F || isKernelFunction(*F) || F->isDeclaration())
    report_fatal_error(
        "NVPTX aliasee must be a non-kernel function definition");

  if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
      GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
    report_fatal_error("NVPTX aliasee must not be '.weak'");

  emitDeclarationWithName(F, getSymbol(GA), O);
}

void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
  emitDeclarationWithName(F, getSymbol(F), O);
}

void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
                                              raw_ostream &O) {
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else
    O << ".func ";
  printReturnValStr(F, O);
  S->print(O, MAI);
  O << "\n";
  emitFunctionParamList(F, O);
  O << "\n";
  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";
  O << ";\n";
}
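// A sketch of what emitDeclarationWithName prints for an external declaration
// "declare i32 @foo(i32)" (parameter names assumed to follow NVPTX's
// "<fn>_param_<n>" convention):
//   .extern .func (.param .b32 func_retval0) foo
//   (
//   .param .b32 foo_param_0
//   )
//   ;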
static bool usedInGlobalVarDef(const Constant *C) {
  if (!C)
    return false;

  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
    return GV->getName() != "llvm.used";
  }

  for (const User *U : C->users())
    if (const Constant *C = dyn_cast<Constant>(U))
      if (usedInGlobalVarDef(C))
        return true;

  return false;
}

static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
    if (othergv->getName() == "llvm.used")
      return true;
  }

  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
    if (instr->getParent() && instr->getParent()->getParent()) {
      const Function *curFunc = instr->getParent()->getParent();
      if (oneFunc && (curFunc != oneFunc))
        return false;
      oneFunc = curFunc;
      return true;
    } else
      return false;
  }

  for (const User *UU : U->users())
    if (!usedInOneFunc(UU, oneFunc))
      return false;

  return true;
}

/* Find out if a global variable can be demoted to local scope.
 * Currently, this is valid for CUDA shared variables, which have local
 * scope and global lifetime. So the conditions to check are :
 * 1. Is the global variable in shared address space?
 * 2. Does it have local linkage?
 * 3. Is the global variable referenced only in one function?
 */
static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
  if (!gv->hasLocalLinkage())
    return false;
  PointerType *Pty = gv->getType();
  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
    return false;

  const Function *oneFunc = nullptr;

  bool flag = usedInOneFunc(gv, oneFunc);
  if (!flag)
    return false;
  if (!oneFunc)
    return false;
  f = oneFunc;
  return true;
}

static bool useFuncSeen(const Constant *C,
                        DenseMap<const Function *, bool> &seenMap) {
  for (const User *U : C->users()) {
    if (const Constant *cu = dyn_cast<Constant>(U)) {
      if (useFuncSeen(cu, seenMap))
        return true;
    } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
      const BasicBlock *bb = I->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;
      if (seenMap.contains(caller))
        return true;
    }
  }
  return false;
}

void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  DenseMap<const Function *, bool> seenMap;
  for (const Function &F : M) {
    if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
      emitDeclaration(&F, O);
      continue;
    }

    if (F.isDeclaration()) {
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(&F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(&F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, seenMap)) {
          emitDeclaration(&F, O);
          break;
        }
      }

      if (!isa<Instruction>(U))
        continue;
      const Instruction *instr = cast<Instruction>(U);
      const BasicBlock *bb = instr->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (seenMap.contains(caller)) {
        emitDeclaration(&F, O);
        break;
      }
    }
    seenMap[&F] = true;
  }
  for (const GlobalAlias &GA : M.aliases())
    emitAliasDeclaration(&GA, O);
}

static bool isEmptyXXStructor(GlobalVariable *GV) {
  if (!GV) return true;
  const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
  if (!InitList) return true;  // Not an array; we don't know how to parse.
  return InitList->getNumOperands() == 0;
}

void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
  // Construct a default subtarget off of the TargetMachine defaults. The
  // rest of NVPTX isn't friendly to change subtargets per function and
  // so the default TargetMachine will have all of the options.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI =
      static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  SmallString<128> Str1;
  raw_svector_ostream OS1(Str1);

  // Emit header before any dwarf directives are emitted below.
  emitHeader(M, OS1, *STI);
  OutStreamer->emitRawText(OS1.str());
}

bool NVPTXAsmPrinter::doInitialization(Module &M) {
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
    report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");

  // OpenMP supports NVPTX global constructors and destructors.
  bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;

  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
      !LowerCtorDtor && !IsOpenMP) {
    report_fatal_error(
        "Module has a nontrivial global ctor, which NVPTX does not support.");
    return true; // error
  }
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
      !LowerCtorDtor && !IsOpenMP) {
    report_fatal_error(
        "Module has a nontrivial global dtor, which NVPTX does not support.");
    return true; // error
  }

  // We need to call the parent's one explicitly.
  bool Result = AsmPrinter::doInitialization(M);

  GlobalsEmitted = false;

  return Result;
}

void NVPTXAsmPrinter::emitGlobals(const Module &M) {
  SmallString<128> Str2;
  raw_svector_ostream OS2(Str2);

  emitDeclarations(M, OS2);

  // As ptxas does not support forward references of globals, we need to first
  // sort the list of module-level globals in def-use order. We visit each
  // global variable in order, and ensure that we emit it *after* its dependent
  // globals. We use a little extra memory maintaining both a set and a list to
  // have fast searches while maintaining a strict ordering.
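  // Example of the constraint being solved here: given
  //   @a = global i32 0
  //   @b = global ptr @a
  // @a must be printed before @b, because ptxas rejects a forward reference
  // from @b's initializer to a not-yet-declared @a.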
  SmallVector<const GlobalVariable *, 8> Globals;
  DenseSet<const GlobalVariable *> GVVisited;
  DenseSet<const GlobalVariable *> GVVisiting;

  // Visit each global variable, in order
  for (const GlobalVariable &I : M.globals())
    VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);

  assert(GVVisited.size() == M.global_size() && "Missed a global variable");
  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  // Print out module-level global variables in proper order
  for (const GlobalVariable *GV : Globals)
    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);

  OS2 << '\n';

  OutStreamer->emitRawText(OS2.str());
}

void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
  SmallString<128> Str;
  raw_svector_ostream OS(Str);

  MCSymbol *Name = getSymbol(&GA);

  OS << ".alias " << Name->getName() << ", "
     << GA.getAliaseeObject()->getName() << ";\n";

  OutStreamer->emitRawText(OS.str());
}
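// For reference, the module header emitted below looks like this for an
// sm_70 target with PTX ISA 6.3 on a 64-bit target:
//   .version 6.3
//   .target sm_70
//   .address_size 64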
void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                 const NVPTXSubtarget &STI) {
  O << "//\n";
  O << "// Generated by LLVM NVPTX Back-End\n";
  O << "//\n";
  O << "\n";

  unsigned PTXVersion = STI.getPTXVersion();
  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";

  O << ".target ";
  O << STI.getTargetName();

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  if (NTM.getDrvInterface() == NVPTX::NVCL)
    O << ", texmode_independent";

  bool HasFullDebugInfo = false;
  for (DICompileUnit *CU : M.debug_compile_units()) {
    switch (CU->getEmissionKind()) {
    case DICompileUnit::NoDebug:
    case DICompileUnit::DebugDirectivesOnly:
      break;
    case DICompileUnit::LineTablesOnly:
    case DICompileUnit::FullDebug:
      HasFullDebugInfo = true;
      break;
    }
    if (HasFullDebugInfo)
      break;
  }
  if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
    O << ", debug";

  O << "\n";

  O << ".address_size ";
  if (NTM.is64Bit())
    O << "64";
  else
    O << "32";
  O << "\n";

  O << "\n";
}

bool NVPTXAsmPrinter::doFinalization(Module &M) {
  bool HasDebugInfo = MMI && MMI->hasDebugInfo();

  // If we did not emit any functions, then the global declarations have not
  // yet been emitted.
  if (!GlobalsEmitted) {
    emitGlobals(M);
    GlobalsEmitted = true;
  }

  // call doFinalization
  bool ret = AsmPrinter::doFinalization(M);

  clearAnnotationCache(&M);

  auto *TS =
      static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
  // Close the last emitted section
  if (HasDebugInfo) {
    TS->closeLastSection();
    // Emit empty .debug_loc section for better support of the empty files.
    OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
  }

  // Output last DWARF .file directives, if any.
  TS->outputDwarfFileDirectives();

  return ret;
}

// This function emits appropriate linkage directives for
// functions and global variables.
//
// extern function declaration -> .extern
// extern function definition -> .visible
// external global variable with init -> .visible
// external without init -> .extern
// appending -> not allowed, assert.
// for any linkage other than
// internal, private, linker_private,
// linker_private_weak, linker_private_weak_def_auto,
// we emit -> .weak.

void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                           raw_ostream &O) {
  if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
    if (V->hasExternalLinkage()) {
      if (isa<GlobalVariable>(V)) {
        const GlobalVariable *GVar = cast<GlobalVariable>(V);
        if (GVar) {
          if (GVar->hasInitializer())
            O << ".visible ";
          else
            O << ".extern ";
        }
      } else if (V->isDeclaration())
        O << ".extern ";
      else
        O << ".visible ";
    } else if (V->hasAppendingLinkage()) {
      std::string msg;
      msg.append("Error: ");
      msg.append("Symbol ");
      if (V->hasName())
        msg.append(std::string(V->getName()));
      msg.append(" has unsupported appending linkage type");
      llvm_unreachable(msg.c_str());
    } else if (!V->hasInternalLinkage() &&
               !V->hasPrivateLinkage()) {
      O << ".weak ";
    }
  }
}

void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool processDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip meta data
  if (GVar->hasSection()) {
    if (GVar->getSection() == "llvm.metadata")
      return;
  }

  // Skip LLVM intrinsic global variables
  if (GVar->getName().starts_with("llvm.") ||
      GVar->getName().starts_with("nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
             GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
    O << ".common ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  if (GVar->hasPrivateLinkage()) {
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    if (localDecls.find(demotedFunc) != localDecls.end())
      localDecls[demotedFunc].push_back(GVar);
    else {
      std::vector<const GlobalVariable *> temp;
      temp.push_back(GVar);
      localDecls[demotedFunc] = temp;
    }
    return;
  }

  O << ".";
  emitPTXAddressSpace(GVar->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // PTX allows variable initialization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(GVar->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    uint64_t ElementSize = 0;

    // Although PTX has direct support for struct and array types, and LLVM IR
    // is very similar to PTX, LLVM CodeGen does not lower these high-level
    // field accesses directly. Structs, arrays and vectors are lowered into
    // arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // PTX allows variable initialization only for constant and
      // global state spaces.
      if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}
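// A worked example for the printer above: the IR global
//   @gv = internal addrspace(1) global i32 42, align 4
// prints as ".global .align 4 .u32 gv = 42;" (a hand-checked sketch, not
// captured compiler output).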
void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
  const Value *v = Symbols[nSym];
  const Value *v0 = SymbolsBeforeStripping[nSym];
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
    MCSymbol *Name = AP.getSymbol(GVar);
    PointerType *PTy = dyn_cast<PointerType>(v0->getType());
    // Is v0 a generic pointer?
    bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
    if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
      os << "generic(";
      Name->print(os, AP.MAI);
      os << ")";
    } else {
      Name->print(os, AP.MAI);
    }
  } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
    AP.printMCExpr(*Expr, os);
  } else
    llvm_unreachable("symbol type unknown");
}

void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  unsigned int InitializerCount = size;
  // TODO: symbols make this harder, but it would still be good to trim
  // trailing 0s for aggs with symbols as well.
  if (numSymbols() == 0)
    while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
      InitializerCount--;

  symbolPosInBuffer.push_back(InitializerCount);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < InitializerCount;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
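// printWords (below) is the word-granular counterpart: when the aggregate
// size is a multiple of the pointer size and all symbols are word-aligned,
// a 64-bit struct { ptr, i64 } initialized to { @foo, 7 } could print as
//   .global .u64 s[2] = {foo, 7};
// (an illustrative sketch; "s" and "@foo" are made-up names).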
void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  assert(nextSymbolPos % ptrSize == 0);
  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
    if (pos)
      os << ", ";
    if (pos == nextSymbolPos) {
      printSymbol(nSym, os);
      nextSymbolPos = symbolPosInBuffer[++nSym];
      assert(nextSymbolPos % ptrSize == 0);
      assert(nextSymbolPos >= pos + ptrSize);
    } else if (ptrSize == 4)
      os << support::endian::read32le(&buffer[pos]);
    else
      os << support::endian::read64le(&buffer[pos]);
  }
}

void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
  if (localDecls.find(f) == localDecls.end())
    return;

  std::vector<const GlobalVariable *> &gvars = localDecls[f];

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  for (const GlobalVariable *GV : gvars) {
    O << "\t// demoted variable\n\t";
    printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
  }
}

void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
                                          raw_ostream &O) const {
  switch (AddressSpace) {
  case ADDRESS_SPACE_LOCAL:
    O << "local";
    break;
  case ADDRESS_SPACE_GLOBAL:
    O << "global";
    break;
  case ADDRESS_SPACE_CONST:
    O << "const";
    break;
  case ADDRESS_SPACE_SHARED:
    O << "shared";
    break;
  default:
    report_fatal_error("Bad address space found while emitting PTX: " +
                       llvm::Twine(AddressSpace));
    break;
  }
}

std::string
NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
  switch (Ty->getTypeID()) {
  case Type::IntegerTyID: {
    unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
    if (NumBits == 1)
      return "pred";
    else if (NumBits <= 64) {
      std::string name = "u";
      return name + utostr(NumBits);
    } else {
      llvm_unreachable("Integer too large");
      break;
    }
    break;
  }
  case Type::BFloatTyID:
  case Type::HalfTyID:
    // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
    // PTX assembly.
    return "b16";
  case Type::FloatTyID:
    return "f32";
  case Type::DoubleTyID:
    return "f64";
  case Type::PointerTyID: {
    unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
    assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");

    if (PtrSize == 64)
      if (useB4PTR)
        return "b64";
      else
        return "u64";
    else if (useB4PTR)
      return "b32";
    else
      return "u32";
  }
  default:
    break;
  }
  llvm_unreachable("unexpected type");
}
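// Quick reference for the mapping above (with useB4PTR = false): i1 -> pred,
// i8..i64 -> u8..u64, half/bfloat -> b16, float -> f32, double -> f64, and
// pointers -> u32/u64 (or b32/b64 when useB4PTR is set).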
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Special case for i128
  if (ETy->isIntegerTy(128)) {
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[16]";
    return;
  }

  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " .";
    O << getPTXFundamentalTypeStr(ETy);
    O << " ";
    getSymbol(GVar)->print(O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct and array types, and LLVM IR
  // is very similar to PTX, LLVM CodeGen does not lower these high-level
  // field accesses directly. Structs and arrays are lowered into arrays of
  // bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(ETy);
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}
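// For emitFunctionParamList (below): a CUDA kernel
//   define void @k(ptr addrspace(1) %p, i32 %n)
// would get a parameter list along the lines of
//   (
//   .param .u64 k_param_0,
//   .param .u32 k_param_1
//   )
// ("k_param_<n>" follows the NVPTX parameter-naming convention; treat the
// exact text as a sketch rather than verified output).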
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  bool isABI = (STI.getSmVersion() >= 20);
  bool hasImageHandles = STI.hasImageHandles();

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunction(*F)) {
      if (isSampler(*I) || isImage(*I)) {
        if (isImage(*I)) {
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            O << TLI->getParamName(F, paramIndex);
          } else { // Default image is read_only
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            O << TLI->getParamName(F, paramIndex);
          }
        } else {
          if (hasImageHandles)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          O << TLI->getParamName(F, paramIndex);
        }
        continue;
      }
    }

    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                    paramIndex](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
      return std::max(TypeAlign, ParamAlign.valueOrOne());
    };

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (ShouldPassAsArray(Ty)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a>  = optimal alignment for the element type; always multiple of
        //        PAL.getParamAlignment
        // size = typeallocsize of element type
        Align OptimalAlign = getOptimalAlignForParam(Ty);

        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
        O << TLI->getParamName(F, paramIndex);
        O << "[" << DL.getTypeAllocSize(Ty) << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      unsigned PTySizeInBits = 0;
      if (PTy) {
        PTySizeInBits =
            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
        assert(PTySizeInBits && "Invalid pointer size");
      }

      if (isKernelFunc) {
        if (PTy) {
          // Special handling for pointer arguments to kernel
          O << "\t.param .u" << PTySizeInBits << " ";

          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
              NVPTX::CUDA) {
            int addrSpace = PTy->getAddressSpace();
            switch (addrSpace) {
            default:
              O << ".ptr ";
              break;
            case ADDRESS_SPACE_CONST:
              O << ".ptr .const ";
              break;
            case ADDRESS_SPACE_SHARED:
              O << ".ptr .shared ";
              break;
            case ADDRESS_SPACE_GLOBAL:
              O << ".ptr .global ";
              break;
            }
            Align ParamAlign = I->getParamAlign().valueOrOne();
            O << ".align " << ParamAlign.value() << " ";
          }
          O << TLI->getParamName(F, paramIndex);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        O << TLI->getParamName(F, paramIndex);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (PTy) {
        assert(PTySizeInBits && "Invalid pointer size");
        sz = PTySizeInBits;
      } else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      O << TLI->getParamName(F, paramIndex);
      continue;
    }

    // param has byVal attribute.
    Type *ETy = PAL.getParamByValType(paramIndex);
    assert(ETy && "Param should have byval type");

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a>  = optimal alignment for the element type; always multiple of
      //        PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign =
          isKernelFunc
              ? getOptimalAlignForParam(ETy)
              : TLI->getFunctionByValParamAlign(
                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);

      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
      O << TLI->getParamName(F, paramIndex);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger())
            sz = promoteScalarArgumentSize(sz);
          O << "\t.reg .b" << sz << " ";
          O << TLI->getParamName(F, paramIndex);
          if (j < je - 1)
            O << ",\n";
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      --paramIndex;
      continue;
    }
  }

  if (F->isVarArg()) {
    if (!first)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment();
    O << " .b8 ";
    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
  }

  O << "\n)";
}
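// setAndEmitFunctionVirtualRegisters (below) emits the per-function preamble.
// For a function with a 16-byte local frame and a handful of 32/64-bit vregs,
// the output would look roughly like:
//   .local .align 8 .b8 __local_depot0[16];
//   .reg .b64 %SP;
//   .reg .b64 %SPL;
//   .reg .b32 %r<4>;
//   .reg .b64 %rd<2>;
// (register-class names and prefixes come from getNVPTXRegClassName/Str).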
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
    const MachineFunction &MF) {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Map the global virtual register number to a register class specific
  // virtual register number starting from 1 within that class.
  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
  //unsigned numRegClasses = TRI->getNumRegClasses();

  // Emit the Fake Stack Object
  const MachineFrameInfo &MFI = MF.getFrameInfo();
  int64_t NumBytes = MFI.getStackSize();
  if (NumBytes) {
    O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
      << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
      O << "\t.reg .b64 \t%SP;\n";
      O << "\t.reg .b64 \t%SPL;\n";
    } else {
      O << "\t.reg .b32 \t%SP;\n";
      O << "\t.reg .b32 \t%SPL;\n";
    }
  }

  // Go through all virtual registers to establish the mapping between the
  // global virtual register number and the per class virtual register number.
  // We use the per class virtual register number in the ptx output.
  unsigned int numVRs = MRI->getNumVirtRegs();
  for (unsigned i = 0; i < numVRs; i++) {
    Register vr = Register::index2VirtReg(i);
    const TargetRegisterClass *RC = MRI->getRegClass(vr);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    int n = regmap.size();
    regmap.insert(std::make_pair(vr, n + 1));
  }

  // Emit register declarations
  // @TODO: Extract out the real register usage
  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";

  // Emit declaration of the virtual registers or 'physical' registers for
  // each register class
  for (unsigned i = 0; i < TRI->getNumRegClasses(); i++) {
    const TargetRegisterClass *RC = TRI->getRegClass(i);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    std::string rcname = getNVPTXRegClassName(RC);
    std::string rcStr = getNVPTXRegClassStr(RC);
    int n = regmap.size();

    // Only declare those registers that may be used.
    if (n) {
      O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n + 1)
        << ">;\n";
    }
  }

  OutStreamer->emitRawText(O.str());
}
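// printFPConstant (below) prints floats as raw, bit-exact hex: for example,
// float 1.0f becomes "0f3F800000" and double 1.0 becomes
// "0d3FF0000000000000".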
void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
  APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
  bool ignored;
  unsigned int numHex;
  const char *lead;

  if (Fp->getType()->getTypeID() == Type::FloatTyID) {
    numHex = 8;
    lead = "0f";
    APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
  } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
    numHex = 16;
    lead = "0d";
    APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
  } else
    llvm_unreachable("unsupported fp type");

  APInt API = APF.bitcastToAPInt();
  O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
}

void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    O << CI->getValue();
    return;
  }
  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
    printFPConstant(CFP, O);
    return;
  }
  if (isa<ConstantPointerNull>(CPV)) {
    O << "0";
    return;
  }
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
    bool IsNonGenericPointer = false;
    if (GVar->getType()->getAddressSpace() != 0) {
      IsNonGenericPointer = true;
    }
    if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
      O << "generic(";
      getSymbol(GVar)->print(O, MAI);
      O << ")";
    } else {
      getSymbol(GVar)->print(O, MAI);
    }
    return;
  }
  if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
    const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
    printMCExpr(*E, O);
    return;
  }
  llvm_unreachable("Not scalar type found in printScalarConstant()");
}
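// bufferLEByte (below) serializes constants little-endian into the aggregate
// buffer: e.g. i16 0x1234 contributes the bytes {0x34, 0x12}. The last byte
// of odd-width integers (such as packed i1 arrays) is extracted separately
// so no out-of-range bits are requested from the APInt.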
void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
  APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
  bool ignored;
  unsigned int numHex;
  const char *lead;

  if (Fp->getType()->getTypeID() == Type::FloatTyID) {
    numHex = 8;
    lead = "0f";
    APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
  } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
    numHex = 16;
    lead = "0d";
    APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
  } else
    llvm_unreachable("unsupported fp type");

  APInt API = APF.bitcastToAPInt();
  O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
}
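// For example, the float 1.0f prints as "0f3F800000" and the double 1.0 as
// "0d3FF0000000000000", matching PTX's hexadecimal floating-point literal
// syntax.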
void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    O << CI->getValue();
    return;
  }
  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
    printFPConstant(CFP, O);
    return;
  }
  if (isa<ConstantPointerNull>(CPV)) {
    O << "0";
    return;
  }
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
    bool IsNonGenericPointer = GVar->getType()->getAddressSpace() != 0;
    if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
      O << "generic(";
      getSymbol(GVar)->print(O, MAI);
      O << ")";
    } else {
      getSymbol(GVar)->print(O, MAI);
    }
    return;
  }
  if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
    const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
    printMCExpr(*E, O);
    return;
  }
  llvm_unreachable("Non-scalar constant found in printScalarConstant()");
}

void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(CPV->getType());
  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
    // A non-zero Bytes value indicates that we need to zero-fill that many
    // bytes; otherwise we fill only the space allocated for CPV.
    AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    // `extractBitsAsZExtValue` does not allow extracting bits beyond the
    // input's bit width, and i1 arrays may not have a length that is a
    // multiple of 8. We handle the last byte separately, so we never request
    // out-of-bounds bits.
    for (unsigned I = 0; I < NumBytes - 1; ++I)
      Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
    size_t LastBytePosition = (NumBytes - 1) * 8;
    size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
    Buf[NumBytes - 1] =
        Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);
    AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto *CI = dyn_cast<ConstantInt>(CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      if (const auto *CI =
              dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        Value *V = Cexpr->getOperand(0)->stripPointerCasts();
        AggBuffer->addSymbol(V, Cexpr->getOperand(0));
        AggBuffer->addZeros(AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::BFloatTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
      AggBuffer->addSymbol(GVar, GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(v, Cexpr);
    }
    AggBuffer->addZeros(AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
      bufferAggregateConstant(CPV, AggBuffer);
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(CPV))
      AggBuffer->addZeros(Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}
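// Worked example for AddIntToBuffer: an i32 with value 0x12345678 is written
// little-endian as the bytes 78 56 34 12, while an i20 value occupies
// NumBytes = (20 + 7) / 8 = 3 bytes and its final byte is extracted with
// LastByteBits = 20 - 16 = 4 to stay within the APInt's bit width.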
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  const DataLayout &DL = getDataLayout();
  int Bytes;

  // Integers of arbitrary width.
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    APInt Val = CI->getValue();
    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
      uint8_t Byte = Val.getLoBits(8).getZExtValue();
      aggBuffer->addBytes(&Byte, 1, 1);
      Val.lshrInPlace(8);
    }
    return;
  }

  // Old-style constant arrays and vectors.
  if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
    if (CPV->getNumOperands())
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
    return;
  }

  if (const ConstantDataSequential *CDS =
          dyn_cast<ConstantDataSequential>(CPV)) {
    if (CDS->getNumElements())
      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
                     aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(CPV->getType());
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
        // Buffer each field with enough bytes to cover the padding up to the
        // next field (or, for the last field, to the end of the struct).
        if (i == (e - 1))
          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
                  DL.getTypeAllocSize(ST) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        else
          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in bufferAggregateConstant()");
}
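// Worked example: for a constant of type { i32, i8 } (alloc size 8, field
// offsets 0 and 4), the i32 field gets Bytes = 4 - 0 = 4 and the trailing i8
// field gets Bytes = 0 + 8 - 4 = 4, so bufferLEByte emits one data byte plus
// three zero bytes of tail padding.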
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is
/// mostly a copy of AsmPrinter::lowerConstant, except customized to only
/// handle expressions that are representable in PTX and to create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
                                    bool ProcessingGeneric) {
  MCContext &Ctx = OutContext;

  if (CV->isNullValue() || isa<UndefValue>(CV))
    return MCConstantExpr::create(0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
    return MCConstantExpr::create(CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
    const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(getSymbol(GV), Ctx);
    if (ProcessingGeneric)
      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
    return Expr;
  }

  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
  if (!CE)
    llvm_unreachable("Unknown constant value to lower!");

  switch (CE->getOpcode()) {
  default:
    break; // Error

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand.
    PointerType *DstTy = cast<PointerType>(CE->getType());
    if (DstTy->getAddressSpace() == 0)
      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);

    break; // Error
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address.
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);

    const MCExpr *Base =
        lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly. This is important for differences between
    // blockaddress labels. Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type. This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(0);
    Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
                                 /*IsSigned*/ false, DL);
    if (Op)
      return lowerConstantForGV(Op, ProcessingGeneric);

    break; // Error
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer; mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
    const MCExpr *MaskExpr =
        MCConstantExpr::create(~0ULL >> (64 - InBits), Ctx);
    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    const MCExpr *LHS =
        lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS =
        lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default:
      llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add:
      return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }

  // If the code isn't optimized, there may be outstanding folding
  // opportunities. Attempt to fold the expression using DataLayout as a
  // last resort before giving up.
  Constant *C = ConstantFoldConstant(CE, getDataLayout());
  if (C != CE)
    return lowerConstantForGV(C, ProcessingGeneric);

  // Otherwise report the problem to the user.
  std::string S;
  raw_string_ostream OS(S);
  OS << "Unsupported expression in static initializer: ";
  CE->printAsOperand(OS, /*PrintType=*/false,
                     !MF ? nullptr : MF->getFunction().getParent());
  report_fatal_error(Twine(OS.str()));
}
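// For instance, a `getelementptr` over global @g with a constant byte offset
// of 12 lowers to the MCExpr `g+12`, and an addrspacecast of @g to the
// generic address space yields an NVPTXGenericMCSymbolRefExpr that prints as
// `generic(g)`.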
// Copy of MCExpr::print customized for NVPTX.
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
  switch (Expr.getKind()) {
  case MCExpr::Target:
    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
  case MCExpr::Constant:
    OS << cast<MCConstantExpr>(Expr).getValue();
    return;

  case MCExpr::SymbolRef: {
    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &Sym = SRE.getSymbol();
    Sym.print(OS, MAI);
    return;
  }

  case MCExpr::Unary: {
    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
    switch (UE.getOpcode()) {
    case MCUnaryExpr::LNot:  OS << '!'; break;
    case MCUnaryExpr::Minus: OS << '-'; break;
    case MCUnaryExpr::Not:   OS << '~'; break;
    case MCUnaryExpr::Plus:  OS << '+'; break;
    }
    printMCExpr(*UE.getSubExpr(), OS);
    return;
  }

  case MCExpr::Binary: {
    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);

    // Only print parens around the LHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
      printMCExpr(*BE.getLHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getLHS(), OS);
      OS << ')';
    }

    switch (BE.getOpcode()) {
    case MCBinaryExpr::Add:
      // Print "X-42" instead of "X+-42".
      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
        if (RHSC->getValue() < 0) {
          OS << RHSC->getValue();
          return;
        }
      }

      OS << '+';
      break;
    default:
      llvm_unreachable("Unhandled binary operator");
    }

    // Only print parens around the RHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
      printMCExpr(*BE.getRHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getRHS(), OS);
      OS << ')';
    }
    return;
  }
  }

  llvm_unreachable("Invalid expression kind!");
}

/// PrintAsmOperand - Print out an operand for an inline asm expression.
bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                                      const char *ExtraCode, raw_ostream &O) {
  if (ExtraCode && ExtraCode[0]) {
    if (ExtraCode[1] != 0)
      return true; // Unknown modifier.

    switch (ExtraCode[0]) {
    default:
      // See if this is a generic print operand.
      return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
    case 'r':
      break;
    }
  }

  printOperand(MI, OpNo, O);

  return false;
}

bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
                                            unsigned OpNo,
                                            const char *ExtraCode,
                                            raw_ostream &O) {
  if (ExtraCode && ExtraCode[0])
    return true; // Unknown modifier

  O << '[';
  printMemOperand(MI, OpNo, O);
  O << ']';

  return false;
}

void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
                                   raw_ostream &O) {
  const MachineOperand &MO = MI->getOperand(OpNum);
  switch (MO.getType()) {
  case MachineOperand::MO_Register:
    if (MO.getReg().isPhysical()) {
      if (MO.getReg() == NVPTX::VRDepot)
        O << DEPOTNAME << getFunctionNumber();
      else
        O << NVPTXInstPrinter::getRegisterName(MO.getReg());
    } else {
      emitVirtualRegister(MO.getReg(), O);
    }
    break;

  case MachineOperand::MO_Immediate:
    O << MO.getImm();
    break;

  case MachineOperand::MO_FPImmediate:
    printFPConstant(MO.getFPImm(), O);
    break;

  case MachineOperand::MO_GlobalAddress:
    PrintSymbolOperand(MO, O);
    break;

  case MachineOperand::MO_MachineBasicBlock:
    MO.getMBB()->getSymbol()->print(O, MAI);
    break;

  default:
    llvm_unreachable("Operand type not supported.");
  }
}

void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
                                      raw_ostream &O, const char *Modifier) {
  printOperand(MI, OpNum, O);

  if (Modifier && strcmp(Modifier, "add") == 0) {
    O << ", ";
    printOperand(MI, OpNum + 1, O);
  } else {
    if (MI->getOperand(OpNum + 1).isImm() &&
        MI->getOperand(OpNum + 1).getImm() == 0)
      return; // don't print ',0' or '+0'
    O << "+";
    printOperand(MI, OpNum + 1, O);
  }
}

// Force static initialization.
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
  RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
  RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
}