1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "NVPTXAsmPrinter.h" 15 #include "MCTargetDesc/NVPTXBaseInfo.h" 16 #include "MCTargetDesc/NVPTXInstPrinter.h" 17 #include "MCTargetDesc/NVPTXMCAsmInfo.h" 18 #include "MCTargetDesc/NVPTXTargetStreamer.h" 19 #include "NVPTX.h" 20 #include "NVPTXMCExpr.h" 21 #include "NVPTXMachineFunctionInfo.h" 22 #include "NVPTXRegisterInfo.h" 23 #include "NVPTXSubtarget.h" 24 #include "NVPTXTargetMachine.h" 25 #include "NVPTXUtilities.h" 26 #include "TargetInfo/NVPTXTargetInfo.h" 27 #include "cl_common_defines.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/DenseSet.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/StringExtras.h" 35 #include "llvm/ADT/StringRef.h" 36 #include "llvm/ADT/Twine.h" 37 #include "llvm/Analysis/ConstantFolding.h" 38 #include "llvm/CodeGen/Analysis.h" 39 #include "llvm/CodeGen/MachineBasicBlock.h" 40 #include "llvm/CodeGen/MachineFrameInfo.h" 41 #include "llvm/CodeGen/MachineFunction.h" 42 #include "llvm/CodeGen/MachineInstr.h" 43 #include "llvm/CodeGen/MachineLoopInfo.h" 44 #include "llvm/CodeGen/MachineModuleInfo.h" 45 #include "llvm/CodeGen/MachineOperand.h" 46 #include "llvm/CodeGen/MachineRegisterInfo.h" 47 #include "llvm/CodeGen/MachineValueType.h" 48 #include "llvm/CodeGen/TargetRegisterInfo.h" 49 
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrDesc.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <cassert>
#include <cstdint>
#include <cstring>
#include <new>
#include <string>
#include <utility>
#include <vector>

using namespace llvm;

// Hidden command-line flag (default: false). doInitialization() rejects
// modules whose llvm.global_ctors / llvm.global_dtors are non-empty unless
// this flag is set.
static cl::opt<bool>
    LowerCtorDtor("nvptx-lower-global-ctor-dtor",
                  cl::desc("Lower GPU ctor / dtors to globals on the device."),
                  cl::init(false), cl::Hidden);

// Name prefix of the per-function frame symbol; getFunctionFrameSymbol()
// appends the function number to it.
#define DEPOTNAME "__local_depot"

/// DiscoverDependentGlobals - Collect into \p Globals the GlobalVariables on
/// which \p V depends. A GlobalVariable is recorded itself and not looked
/// through; any other User is searched recursively through its operands.
104 static void 105 DiscoverDependentGlobals(const Value *V, 106 DenseSet<const GlobalVariable *> &Globals) { 107 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 108 Globals.insert(GV); 109 else { 110 if (const User *U = dyn_cast<User>(V)) { 111 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) { 112 DiscoverDependentGlobals(U->getOperand(i), Globals); 113 } 114 } 115 } 116 } 117 118 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable 119 /// instances to be emitted, but only after any dependents have been added 120 /// first.s 121 static void 122 VisitGlobalVariableForEmission(const GlobalVariable *GV, 123 SmallVectorImpl<const GlobalVariable *> &Order, 124 DenseSet<const GlobalVariable *> &Visited, 125 DenseSet<const GlobalVariable *> &Visiting) { 126 // Have we already visited this one? 127 if (Visited.count(GV)) 128 return; 129 130 // Do we have a circular dependency? 131 if (!Visiting.insert(GV).second) 132 report_fatal_error("Circular dependency found in global variable set"); 133 134 // Make sure we visit all dependents first 135 DenseSet<const GlobalVariable *> Others; 136 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i) 137 DiscoverDependentGlobals(GV->getOperand(i), Others); 138 139 for (const GlobalVariable *GV : Others) 140 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting); 141 142 // Now we can visit ourself 143 Order.push_back(GV); 144 Visited.insert(GV); 145 Visiting.erase(GV); 146 } 147 148 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) { 149 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(), 150 getSubtargetInfo().getFeatureBits()); 151 152 MCInst Inst; 153 lowerToMCInst(MI, Inst); 154 EmitToStreamer(*OutStreamer, Inst); 155 } 156 157 // Handle symbol backtracking for targets that do not support image handles 158 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI, 159 unsigned OpNo, MCOperand &MCOp) { 160 const MachineOperand &MO = 
MI->getOperand(OpNo); 161 const MCInstrDesc &MCID = MI->getDesc(); 162 163 if (MCID.TSFlags & NVPTXII::IsTexFlag) { 164 // This is a texture fetch, so operand 4 is a texref and operand 5 is 165 // a samplerref 166 if (OpNo == 4 && MO.isImm()) { 167 lowerImageHandleSymbol(MO.getImm(), MCOp); 168 return true; 169 } 170 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) { 171 lowerImageHandleSymbol(MO.getImm(), MCOp); 172 return true; 173 } 174 175 return false; 176 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { 177 unsigned VecSize = 178 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); 179 180 // For a surface load of vector size N, the Nth operand will be the surfref 181 if (OpNo == VecSize && MO.isImm()) { 182 lowerImageHandleSymbol(MO.getImm(), MCOp); 183 return true; 184 } 185 186 return false; 187 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { 188 // This is a surface store, so operand 0 is a surfref 189 if (OpNo == 0 && MO.isImm()) { 190 lowerImageHandleSymbol(MO.getImm(), MCOp); 191 return true; 192 } 193 194 return false; 195 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { 196 // This is a query, so operand 1 is a surfref/texref 197 if (OpNo == 1 && MO.isImm()) { 198 lowerImageHandleSymbol(MO.getImm(), MCOp); 199 return true; 200 } 201 202 return false; 203 } 204 205 return false; 206 } 207 208 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) { 209 // Ewwww 210 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget()); 211 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM); 212 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>(); 213 const char *Sym = MFI->getImageHandleSymbol(Index); 214 StringRef SymName = nvTM.getStrPool().save(Sym); 215 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName)); 216 } 217 218 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { 219 
OutMI.setOpcode(MI->getOpcode()); 220 // Special: Do not mangle symbol operand of CALL_PROTOTYPE 221 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) { 222 const MachineOperand &MO = MI->getOperand(0); 223 OutMI.addOperand(GetSymbolRef( 224 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName())))); 225 return; 226 } 227 228 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 229 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { 230 const MachineOperand &MO = MI->getOperand(i); 231 232 MCOperand MCOp; 233 if (!STI.hasImageHandles()) { 234 if (lowerImageHandleOperand(MI, i, MCOp)) { 235 OutMI.addOperand(MCOp); 236 continue; 237 } 238 } 239 240 if (lowerOperand(MO, MCOp)) 241 OutMI.addOperand(MCOp); 242 } 243 } 244 245 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, 246 MCOperand &MCOp) { 247 switch (MO.getType()) { 248 default: llvm_unreachable("unknown operand type"); 249 case MachineOperand::MO_Register: 250 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg())); 251 break; 252 case MachineOperand::MO_Immediate: 253 MCOp = MCOperand::createImm(MO.getImm()); 254 break; 255 case MachineOperand::MO_MachineBasicBlock: 256 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create( 257 MO.getMBB()->getSymbol(), OutContext)); 258 break; 259 case MachineOperand::MO_ExternalSymbol: 260 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName())); 261 break; 262 case MachineOperand::MO_GlobalAddress: 263 MCOp = GetSymbolRef(getSymbol(MO.getGlobal())); 264 break; 265 case MachineOperand::MO_FPImmediate: { 266 const ConstantFP *Cnt = MO.getFPImm(); 267 const APFloat &Val = Cnt->getValueAPF(); 268 269 switch (Cnt->getType()->getTypeID()) { 270 default: report_fatal_error("Unsupported FP type"); break; 271 case Type::HalfTyID: 272 MCOp = MCOperand::createExpr( 273 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); 274 break; 275 case Type::BFloatTyID: 276 MCOp = MCOperand::createExpr( 277 
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext)); 278 break; 279 case Type::FloatTyID: 280 MCOp = MCOperand::createExpr( 281 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); 282 break; 283 case Type::DoubleTyID: 284 MCOp = MCOperand::createExpr( 285 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext)); 286 break; 287 } 288 break; 289 } 290 } 291 return true; 292 } 293 294 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { 295 if (Register::isVirtualRegister(Reg)) { 296 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 297 298 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC]; 299 unsigned RegNum = RegMap[Reg]; 300 301 // Encode the register class in the upper 4 bits 302 // Must be kept in sync with NVPTXInstPrinter::printRegName 303 unsigned Ret = 0; 304 if (RC == &NVPTX::Int1RegsRegClass) { 305 Ret = (1 << 28); 306 } else if (RC == &NVPTX::Int16RegsRegClass) { 307 Ret = (2 << 28); 308 } else if (RC == &NVPTX::Int32RegsRegClass) { 309 Ret = (3 << 28); 310 } else if (RC == &NVPTX::Int64RegsRegClass) { 311 Ret = (4 << 28); 312 } else if (RC == &NVPTX::Float32RegsRegClass) { 313 Ret = (5 << 28); 314 } else if (RC == &NVPTX::Float64RegsRegClass) { 315 Ret = (6 << 28); 316 } else { 317 report_fatal_error("Bad register class"); 318 } 319 320 // Insert the vreg number 321 Ret |= (RegNum & 0x0FFFFFFF); 322 return Ret; 323 } else { 324 // Some special-use registers are actually physical registers. 325 // Encode this as the register class ID of 0 and the real register ID. 
326 return Reg & 0x0FFFFFFF; 327 } 328 } 329 330 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) { 331 const MCExpr *Expr; 332 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None, 333 OutContext); 334 return MCOperand::createExpr(Expr); 335 } 336 337 static bool ShouldPassAsArray(Type *Ty) { 338 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) || 339 Ty->isHalfTy() || Ty->isBFloatTy(); 340 } 341 342 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { 343 const DataLayout &DL = getDataLayout(); 344 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 345 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering()); 346 347 Type *Ty = F->getReturnType(); 348 349 bool isABI = (STI.getSmVersion() >= 20); 350 351 if (Ty->getTypeID() == Type::VoidTyID) 352 return; 353 O << " ("; 354 355 if (isABI) { 356 if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) && 357 !ShouldPassAsArray(Ty)) { 358 unsigned size = 0; 359 if (auto *ITy = dyn_cast<IntegerType>(Ty)) { 360 size = ITy->getBitWidth(); 361 } else { 362 assert(Ty->isFloatingPointTy() && "Floating point type expected here"); 363 size = Ty->getPrimitiveSizeInBits(); 364 } 365 size = promoteScalarArgumentSize(size); 366 O << ".param .b" << size << " func_retval0"; 367 } else if (isa<PointerType>(Ty)) { 368 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits() 369 << " func_retval0"; 370 } else if (ShouldPassAsArray(Ty)) { 371 unsigned totalsz = DL.getTypeAllocSize(Ty); 372 unsigned retAlignment = 0; 373 if (!getAlign(*F, 0, retAlignment)) 374 retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value(); 375 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz 376 << "]"; 377 } else 378 llvm_unreachable("Unknown return type"); 379 } else { 380 SmallVector<EVT, 16> vtparts; 381 ComputeValueVTs(*TLI, DL, Ty, vtparts); 382 unsigned idx = 0; 383 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 
384 unsigned elems = 1; 385 EVT elemtype = vtparts[i]; 386 if (vtparts[i].isVector()) { 387 elems = vtparts[i].getVectorNumElements(); 388 elemtype = vtparts[i].getVectorElementType(); 389 } 390 391 for (unsigned j = 0, je = elems; j != je; ++j) { 392 unsigned sz = elemtype.getSizeInBits(); 393 if (elemtype.isInteger()) 394 sz = promoteScalarArgumentSize(sz); 395 O << ".reg .b" << sz << " func_retval" << idx; 396 if (j < je - 1) 397 O << ", "; 398 ++idx; 399 } 400 if (i < e - 1) 401 O << ", "; 402 } 403 } 404 O << ") "; 405 } 406 407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF, 408 raw_ostream &O) { 409 const Function &F = MF.getFunction(); 410 printReturnValStr(&F, O); 411 } 412 413 // Return true if MBB is the header of a loop marked with 414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1. 415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll( 416 const MachineBasicBlock &MBB) const { 417 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>(); 418 // We insert .pragma "nounroll" only to the loop header. 419 if (!LI.isLoopHeader(&MBB)) 420 return false; 421 422 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore, 423 // we iterate through each back edge of the loop with header MBB, and check 424 // whether its metadata contains llvm.loop.unroll.disable. 425 for (const MachineBasicBlock *PMBB : MBB.predecessors()) { 426 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) { 427 // Edges from other loops to MBB are not back edges. 
428 continue; 429 } 430 if (const BasicBlock *PBB = PMBB->getBasicBlock()) { 431 if (MDNode *LoopID = 432 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) { 433 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable")) 434 return true; 435 if (MDNode *UnrollCountMD = 436 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) { 437 if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1)) 438 ->isOne()) 439 return true; 440 } 441 } 442 } 443 } 444 return false; 445 } 446 447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { 448 AsmPrinter::emitBasicBlockStart(MBB); 449 if (isLoopHeaderOfNoUnroll(MBB)) 450 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n")); 451 } 452 453 void NVPTXAsmPrinter::emitFunctionEntryLabel() { 454 SmallString<128> Str; 455 raw_svector_ostream O(Str); 456 457 if (!GlobalsEmitted) { 458 emitGlobals(*MF->getFunction().getParent()); 459 GlobalsEmitted = true; 460 } 461 462 // Set up 463 MRI = &MF->getRegInfo(); 464 F = &MF->getFunction(); 465 emitLinkageDirective(F, O); 466 if (isKernelFunction(*F)) 467 O << ".entry "; 468 else { 469 O << ".func "; 470 printReturnValStr(*MF, O); 471 } 472 473 CurrentFnSym->print(O, MAI); 474 475 emitFunctionParamList(F, O); 476 O << "\n"; 477 478 if (isKernelFunction(*F)) 479 emitKernelFunctionDirectives(*F, O); 480 481 if (shouldEmitPTXNoReturn(F, TM)) 482 O << ".noreturn"; 483 484 OutStreamer->emitRawText(O.str()); 485 486 VRegMapping.clear(); 487 // Emit open brace for function body. 488 OutStreamer->emitRawText(StringRef("{\n")); 489 setAndEmitFunctionVirtualRegisters(*MF); 490 // Emit initial .loc debug directive for correct relocation symbol data. 491 if (MMI && MMI->hasDebugInfo()) 492 emitInitialRawDwarfLocDirective(*MF); 493 } 494 495 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) { 496 bool Result = AsmPrinter::runOnMachineFunction(F); 497 // Emit closing brace for the body of function F. 
498 // The closing brace must be emitted here because we need to emit additional 499 // debug labels/data after the last basic block. 500 // We need to emit the closing brace here because we don't have function that 501 // finished emission of the function body. 502 OutStreamer->emitRawText(StringRef("}\n")); 503 return Result; 504 } 505 506 void NVPTXAsmPrinter::emitFunctionBodyStart() { 507 SmallString<128> Str; 508 raw_svector_ostream O(Str); 509 emitDemotedVars(&MF->getFunction(), O); 510 OutStreamer->emitRawText(O.str()); 511 } 512 513 void NVPTXAsmPrinter::emitFunctionBodyEnd() { 514 VRegMapping.clear(); 515 } 516 517 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const { 518 SmallString<128> Str; 519 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber(); 520 return OutContext.getOrCreateSymbol(Str); 521 } 522 523 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const { 524 Register RegNo = MI->getOperand(0).getReg(); 525 if (RegNo.isVirtual()) { 526 OutStreamer->AddComment(Twine("implicit-def: ") + 527 getVirtualRegisterName(RegNo)); 528 } else { 529 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 530 OutStreamer->AddComment(Twine("implicit-def: ") + 531 STI.getRegisterInfo()->getName(RegNo)); 532 } 533 OutStreamer->addBlankLine(); 534 } 535 536 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, 537 raw_ostream &O) const { 538 // If the NVVM IR has some of reqntid* specified, then output 539 // the reqntid directive, and set the unspecified ones to 1. 540 // If none of reqntid* is specified, don't output reqntid directive. 
541 unsigned reqntidx, reqntidy, reqntidz; 542 bool specified = false; 543 if (!getReqNTIDx(F, reqntidx)) 544 reqntidx = 1; 545 else 546 specified = true; 547 if (!getReqNTIDy(F, reqntidy)) 548 reqntidy = 1; 549 else 550 specified = true; 551 if (!getReqNTIDz(F, reqntidz)) 552 reqntidz = 1; 553 else 554 specified = true; 555 556 if (specified) 557 O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz 558 << "\n"; 559 560 // If the NVVM IR has some of maxntid* specified, then output 561 // the maxntid directive, and set the unspecified ones to 1. 562 // If none of maxntid* is specified, don't output maxntid directive. 563 unsigned maxntidx, maxntidy, maxntidz; 564 specified = false; 565 if (!getMaxNTIDx(F, maxntidx)) 566 maxntidx = 1; 567 else 568 specified = true; 569 if (!getMaxNTIDy(F, maxntidy)) 570 maxntidy = 1; 571 else 572 specified = true; 573 if (!getMaxNTIDz(F, maxntidz)) 574 maxntidz = 1; 575 else 576 specified = true; 577 578 if (specified) 579 O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz 580 << "\n"; 581 582 unsigned mincta; 583 if (getMinCTASm(F, mincta)) 584 O << ".minnctapersm " << mincta << "\n"; 585 586 unsigned maxnreg; 587 if (getMaxNReg(F, maxnreg)) 588 O << ".maxnreg " << maxnreg << "\n"; 589 } 590 591 std::string 592 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const { 593 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 594 595 std::string Name; 596 raw_string_ostream NameStr(Name); 597 598 VRegRCMap::const_iterator I = VRegMapping.find(RC); 599 assert(I != VRegMapping.end() && "Bad register class"); 600 const DenseMap<unsigned, unsigned> &RegMap = I->second; 601 602 VRegMap::const_iterator VI = RegMap.find(Reg); 603 assert(VI != RegMap.end() && "Bad virtual register"); 604 unsigned MappedVR = VI->second; 605 606 NameStr << getNVPTXRegClassStr(RC) << MappedVR; 607 608 NameStr.flush(); 609 return Name; 610 } 611 612 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr, 613 raw_ostream 
&O) { 614 O << getVirtualRegisterName(vr); 615 } 616 617 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { 618 emitLinkageDirective(F, O); 619 if (isKernelFunction(*F)) 620 O << ".entry "; 621 else 622 O << ".func "; 623 printReturnValStr(F, O); 624 getSymbol(F)->print(O, MAI); 625 O << "\n"; 626 emitFunctionParamList(F, O); 627 O << "\n"; 628 if (shouldEmitPTXNoReturn(F, TM)) 629 O << ".noreturn"; 630 O << ";\n"; 631 } 632 633 static bool usedInGlobalVarDef(const Constant *C) { 634 if (!C) 635 return false; 636 637 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) { 638 return GV->getName() != "llvm.used"; 639 } 640 641 for (const User *U : C->users()) 642 if (const Constant *C = dyn_cast<Constant>(U)) 643 if (usedInGlobalVarDef(C)) 644 return true; 645 646 return false; 647 } 648 649 static bool usedInOneFunc(const User *U, Function const *&oneFunc) { 650 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) { 651 if (othergv->getName() == "llvm.used") 652 return true; 653 } 654 655 if (const Instruction *instr = dyn_cast<Instruction>(U)) { 656 if (instr->getParent() && instr->getParent()->getParent()) { 657 const Function *curFunc = instr->getParent()->getParent(); 658 if (oneFunc && (curFunc != oneFunc)) 659 return false; 660 oneFunc = curFunc; 661 return true; 662 } else 663 return false; 664 } 665 666 for (const User *UU : U->users()) 667 if (!usedInOneFunc(UU, oneFunc)) 668 return false; 669 670 return true; 671 } 672 673 /* Find out if a global variable can be demoted to local scope. 674 * Currently, this is valid for CUDA shared variables, which have local 675 * scope and global lifetime. So the conditions to check are : 676 * 1. Is the global variable in shared address space? 677 * 2. Does it have internal linkage? 678 * 3. Is the global variable referenced only in one function? 
679 */ 680 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) { 681 if (!gv->hasInternalLinkage()) 682 return false; 683 PointerType *Pty = gv->getType(); 684 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED) 685 return false; 686 687 const Function *oneFunc = nullptr; 688 689 bool flag = usedInOneFunc(gv, oneFunc); 690 if (!flag) 691 return false; 692 if (!oneFunc) 693 return false; 694 f = oneFunc; 695 return true; 696 } 697 698 static bool useFuncSeen(const Constant *C, 699 DenseMap<const Function *, bool> &seenMap) { 700 for (const User *U : C->users()) { 701 if (const Constant *cu = dyn_cast<Constant>(U)) { 702 if (useFuncSeen(cu, seenMap)) 703 return true; 704 } else if (const Instruction *I = dyn_cast<Instruction>(U)) { 705 const BasicBlock *bb = I->getParent(); 706 if (!bb) 707 continue; 708 const Function *caller = bb->getParent(); 709 if (!caller) 710 continue; 711 if (seenMap.contains(caller)) 712 return true; 713 } 714 } 715 return false; 716 } 717 718 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { 719 DenseMap<const Function *, bool> seenMap; 720 for (const Function &F : M) { 721 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) { 722 emitDeclaration(&F, O); 723 continue; 724 } 725 726 if (F.isDeclaration()) { 727 if (F.use_empty()) 728 continue; 729 if (F.getIntrinsicID()) 730 continue; 731 emitDeclaration(&F, O); 732 continue; 733 } 734 for (const User *U : F.users()) { 735 if (const Constant *C = dyn_cast<Constant>(U)) { 736 if (usedInGlobalVarDef(C)) { 737 // The use is in the initialization of a global variable 738 // that is a function pointer, so print a declaration 739 // for the original function 740 emitDeclaration(&F, O); 741 break; 742 } 743 // Emit a declaration of this function if the function that 744 // uses this constant expr has already been seen. 
745 if (useFuncSeen(C, seenMap)) { 746 emitDeclaration(&F, O); 747 break; 748 } 749 } 750 751 if (!isa<Instruction>(U)) 752 continue; 753 const Instruction *instr = cast<Instruction>(U); 754 const BasicBlock *bb = instr->getParent(); 755 if (!bb) 756 continue; 757 const Function *caller = bb->getParent(); 758 if (!caller) 759 continue; 760 761 // If a caller has already been seen, then the caller is 762 // appearing in the module before the callee. so print out 763 // a declaration for the callee. 764 if (seenMap.contains(caller)) { 765 emitDeclaration(&F, O); 766 break; 767 } 768 } 769 seenMap[&F] = true; 770 } 771 } 772 773 static bool isEmptyXXStructor(GlobalVariable *GV) { 774 if (!GV) return true; 775 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer()); 776 if (!InitList) return true; // Not an array; we don't know how to parse. 777 return InitList->getNumOperands() == 0; 778 } 779 780 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) { 781 // Construct a default subtarget off of the TargetMachine defaults. The 782 // rest of NVPTX isn't friendly to change subtargets per function and 783 // so the default TargetMachine will have all of the options. 784 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 785 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl()); 786 SmallString<128> Str1; 787 raw_svector_ostream OS1(Str1); 788 789 // Emit header before any dwarf directives are emitted below. 
790 emitHeader(M, OS1, *STI); 791 OutStreamer->emitRawText(OS1.str()); 792 } 793 794 bool NVPTXAsmPrinter::doInitialization(Module &M) { 795 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 796 const NVPTXSubtarget &STI = 797 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 798 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30)) 799 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30"); 800 801 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) && 802 !LowerCtorDtor) { 803 report_fatal_error( 804 "Module has a nontrivial global ctor, which NVPTX does not support."); 805 return true; // error 806 } 807 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) && 808 !LowerCtorDtor) { 809 report_fatal_error( 810 "Module has a nontrivial global dtor, which NVPTX does not support."); 811 return true; // error 812 } 813 814 // We need to call the parent's one explicitly. 815 bool Result = AsmPrinter::doInitialization(M); 816 817 GlobalsEmitted = false; 818 819 return Result; 820 } 821 822 void NVPTXAsmPrinter::emitGlobals(const Module &M) { 823 SmallString<128> Str2; 824 raw_svector_ostream OS2(Str2); 825 826 emitDeclarations(M, OS2); 827 828 // As ptxas does not support forward references of globals, we need to first 829 // sort the list of module-level globals in def-use order. We visit each 830 // global variable in order, and ensure that we emit it *after* its dependent 831 // globals. We use a little extra memory maintaining both a set and a list to 832 // have fast searches while maintaining a strict ordering. 
833 SmallVector<const GlobalVariable *, 8> Globals; 834 DenseSet<const GlobalVariable *> GVVisited; 835 DenseSet<const GlobalVariable *> GVVisiting; 836 837 // Visit each global variable, in order 838 for (const GlobalVariable &I : M.globals()) 839 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting); 840 841 assert(GVVisited.size() == M.global_size() && "Missed a global variable"); 842 assert(GVVisiting.size() == 0 && "Did not fully process a global variable"); 843 844 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 845 const NVPTXSubtarget &STI = 846 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 847 848 // Print out module-level global variables in proper order 849 for (unsigned i = 0, e = Globals.size(); i != e; ++i) 850 printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI); 851 852 OS2 << '\n'; 853 854 OutStreamer->emitRawText(OS2.str()); 855 } 856 857 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) { 858 SmallString<128> Str; 859 raw_svector_ostream OS(Str); 860 861 MCSymbol *Name = getSymbol(&GA); 862 const Function *F = dyn_cast<Function>(GA.getAliasee()); 863 if (!F || isKernelFunction(*F)) 864 report_fatal_error("NVPTX aliasee must be a non-kernel function"); 865 866 if (GA.hasLinkOnceLinkage() || GA.hasWeakLinkage() || 867 GA.hasAvailableExternallyLinkage() || GA.hasCommonLinkage()) 868 report_fatal_error("NVPTX aliasee must not be '.weak'"); 869 870 OS << "\n"; 871 emitLinkageDirective(F, OS); 872 OS << ".func "; 873 printReturnValStr(F, OS); 874 OS << Name->getName(); 875 emitFunctionParamList(F, OS); 876 if (shouldEmitPTXNoReturn(F, TM)) 877 OS << "\n.noreturn"; 878 OS << ";\n"; 879 880 OS << ".alias " << Name->getName() << ", " << F->getName() << ";\n"; 881 882 OutStreamer->emitRawText(OS.str()); 883 } 884 885 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O, 886 const NVPTXSubtarget &STI) { 887 O << "//\n"; 888 O << "// Generated by LLVM 
NVPTX Back-End\n"; 889 O << "//\n"; 890 O << "\n"; 891 892 unsigned PTXVersion = STI.getPTXVersion(); 893 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"; 894 895 O << ".target "; 896 O << STI.getTargetName(); 897 898 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 899 if (NTM.getDrvInterface() == NVPTX::NVCL) 900 O << ", texmode_independent"; 901 902 bool HasFullDebugInfo = false; 903 for (DICompileUnit *CU : M.debug_compile_units()) { 904 switch(CU->getEmissionKind()) { 905 case DICompileUnit::NoDebug: 906 case DICompileUnit::DebugDirectivesOnly: 907 break; 908 case DICompileUnit::LineTablesOnly: 909 case DICompileUnit::FullDebug: 910 HasFullDebugInfo = true; 911 break; 912 } 913 if (HasFullDebugInfo) 914 break; 915 } 916 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo) 917 O << ", debug"; 918 919 O << "\n"; 920 921 O << ".address_size "; 922 if (NTM.is64Bit()) 923 O << "64"; 924 else 925 O << "32"; 926 O << "\n"; 927 928 O << "\n"; 929 } 930 931 bool NVPTXAsmPrinter::doFinalization(Module &M) { 932 bool HasDebugInfo = MMI && MMI->hasDebugInfo(); 933 934 // If we did not emit any functions, then the global declarations have not 935 // yet been emitted. 936 if (!GlobalsEmitted) { 937 emitGlobals(M); 938 GlobalsEmitted = true; 939 } 940 941 // If we have any aliases we emit them at the end. 942 SmallVector<GlobalAlias *> AliasesToRemove; 943 for (GlobalAlias &Alias : M.aliases()) { 944 emitGlobalAlias(M, Alias); 945 AliasesToRemove.push_back(&Alias); 946 } 947 948 for (GlobalAlias *A : AliasesToRemove) 949 A->eraseFromParent(); 950 951 // call doFinalization 952 bool ret = AsmPrinter::doFinalization(M); 953 954 clearAnnotationCache(&M); 955 956 auto *TS = 957 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer()); 958 // Close the last emitted section 959 if (HasDebugInfo) { 960 TS->closeLastSection(); 961 // Emit empty .debug_loc section for better support of the empty files. 
962 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}"); 963 } 964 965 // Output last DWARF .file directives, if any. 966 TS->outputDwarfFileDirectives(); 967 968 return ret; 969 } 970 971 // This function emits appropriate linkage directives for 972 // functions and global variables. 973 // 974 // extern function declaration -> .extern 975 // extern function definition -> .visible 976 // external global variable with init -> .visible 977 // external without init -> .extern 978 // appending -> not allowed, assert. 979 // for any linkage other than 980 // internal, private, linker_private, 981 // linker_private_weak, linker_private_weak_def_auto, 982 // we emit -> .weak. 983 984 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, 985 raw_ostream &O) { 986 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) { 987 if (V->hasExternalLinkage()) { 988 if (isa<GlobalVariable>(V)) { 989 const GlobalVariable *GVar = cast<GlobalVariable>(V); 990 if (GVar) { 991 if (GVar->hasInitializer()) 992 O << ".visible "; 993 else 994 O << ".extern "; 995 } 996 } else if (V->isDeclaration()) 997 O << ".extern "; 998 else 999 O << ".visible "; 1000 } else if (V->hasAppendingLinkage()) { 1001 std::string msg; 1002 msg.append("Error: "); 1003 msg.append("Symbol "); 1004 if (V->hasName()) 1005 msg.append(std::string(V->getName())); 1006 msg.append("has unsupported appending linkage type"); 1007 llvm_unreachable(msg.c_str()); 1008 } else if (!V->hasInternalLinkage() && 1009 !V->hasPrivateLinkage()) { 1010 O << ".weak "; 1011 } 1012 } 1013 } 1014 1015 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, 1016 raw_ostream &O, bool processDemoted, 1017 const NVPTXSubtarget &STI) { 1018 // Skip meta data 1019 if (GVar->hasSection()) { 1020 if (GVar->getSection() == "llvm.metadata") 1021 return; 1022 } 1023 1024 // Skip LLVM intrinsic global variables 1025 if (GVar->getName().startswith("llvm.") || 1026 GVar->getName().startswith("nvvm.")) 
1027 return; 1028 1029 const DataLayout &DL = getDataLayout(); 1030 1031 // GlobalVariables are always constant pointers themselves. 1032 PointerType *PTy = GVar->getType(); 1033 Type *ETy = GVar->getValueType(); 1034 1035 if (GVar->hasExternalLinkage()) { 1036 if (GVar->hasInitializer()) 1037 O << ".visible "; 1038 else 1039 O << ".extern "; 1040 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() || 1041 GVar->hasAvailableExternallyLinkage() || 1042 GVar->hasCommonLinkage()) { 1043 O << ".weak "; 1044 } 1045 1046 if (isTexture(*GVar)) { 1047 O << ".global .texref " << getTextureName(*GVar) << ";\n"; 1048 return; 1049 } 1050 1051 if (isSurface(*GVar)) { 1052 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n"; 1053 return; 1054 } 1055 1056 if (GVar->isDeclaration()) { 1057 // (extern) declarations, no definition or initializer 1058 // Currently the only known declaration is for an automatic __local 1059 // (.shared) promoted to global. 1060 emitPTXGlobalVariable(GVar, O, STI); 1061 O << ";\n"; 1062 return; 1063 } 1064 1065 if (isSampler(*GVar)) { 1066 O << ".global .samplerref " << getSamplerName(*GVar); 1067 1068 const Constant *Initializer = nullptr; 1069 if (GVar->hasInitializer()) 1070 Initializer = GVar->getInitializer(); 1071 const ConstantInt *CI = nullptr; 1072 if (Initializer) 1073 CI = dyn_cast<ConstantInt>(Initializer); 1074 if (CI) { 1075 unsigned sample = CI->getZExtValue(); 1076 1077 O << " = { "; 1078 1079 for (int i = 0, 1080 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE); 1081 i < 3; i++) { 1082 O << "addr_mode_" << i << " = "; 1083 switch (addr) { 1084 case 0: 1085 O << "wrap"; 1086 break; 1087 case 1: 1088 O << "clamp_to_border"; 1089 break; 1090 case 2: 1091 O << "clamp_to_edge"; 1092 break; 1093 case 3: 1094 O << "wrap"; 1095 break; 1096 case 4: 1097 O << "mirror"; 1098 break; 1099 } 1100 O << ", "; 1101 } 1102 O << "filter_mode = "; 1103 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) { 1104 case 0: 
1105 O << "nearest"; 1106 break; 1107 case 1: 1108 O << "linear"; 1109 break; 1110 case 2: 1111 llvm_unreachable("Anisotropic filtering is not supported"); 1112 default: 1113 O << "nearest"; 1114 break; 1115 } 1116 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) { 1117 O << ", force_unnormalized_coords = 1"; 1118 } 1119 O << " }"; 1120 } 1121 1122 O << ";\n"; 1123 return; 1124 } 1125 1126 if (GVar->hasPrivateLinkage()) { 1127 if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0) 1128 return; 1129 1130 // FIXME - need better way (e.g. Metadata) to avoid generating this global 1131 if (strncmp(GVar->getName().data(), "filename", 8) == 0) 1132 return; 1133 if (GVar->use_empty()) 1134 return; 1135 } 1136 1137 const Function *demotedFunc = nullptr; 1138 if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) { 1139 O << "// " << GVar->getName() << " has been demoted\n"; 1140 if (localDecls.find(demotedFunc) != localDecls.end()) 1141 localDecls[demotedFunc].push_back(GVar); 1142 else { 1143 std::vector<const GlobalVariable *> temp; 1144 temp.push_back(GVar); 1145 localDecls[demotedFunc] = temp; 1146 } 1147 return; 1148 } 1149 1150 O << "."; 1151 emitPTXAddressSpace(PTy->getAddressSpace(), O); 1152 1153 if (isManaged(*GVar)) { 1154 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { 1155 report_fatal_error( 1156 ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); 1157 } 1158 O << " .attribute(.managed)"; 1159 } 1160 1161 if (MaybeAlign A = GVar->getAlign()) 1162 O << " .align " << A->value(); 1163 else 1164 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value(); 1165 1166 if (ETy->isFloatingPointTy() || ETy->isPointerTy() || 1167 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) { 1168 O << " ."; 1169 // Special case: ABI requires that we use .u8 for predicates 1170 if (ETy->isIntegerTy(1)) 1171 O << "u8"; 1172 else 1173 O << getPTXFundamentalTypeStr(ETy, false); 1174 O << " "; 1175 getSymbol(GVar)->print(O, 
MAI); 1176 1177 // Ptx allows variable initilization only for constant and global state 1178 // spaces. 1179 if (GVar->hasInitializer()) { 1180 if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || 1181 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) { 1182 const Constant *Initializer = GVar->getInitializer(); 1183 // 'undef' is treated as there is no value specified. 1184 if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) { 1185 O << " = "; 1186 printScalarConstant(Initializer, O); 1187 } 1188 } else { 1189 // The frontend adds zero-initializer to device and constant variables 1190 // that don't have an initial value, and UndefValue to shared 1191 // variables, so skip warning for this case. 1192 if (!GVar->getInitializer()->isNullValue() && 1193 !isa<UndefValue>(GVar->getInitializer())) { 1194 report_fatal_error("initial value of '" + GVar->getName() + 1195 "' is not allowed in addrspace(" + 1196 Twine(PTy->getAddressSpace()) + ")"); 1197 } 1198 } 1199 } 1200 } else { 1201 uint64_t ElementSize = 0; 1202 1203 // Although PTX has direct support for struct type and array type and 1204 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for 1205 // targets that support these high level field accesses. Structs, arrays 1206 // and vectors are lowered into arrays of bytes. 1207 switch (ETy->getTypeID()) { 1208 case Type::IntegerTyID: // Integers larger than 64 bits 1209 case Type::StructTyID: 1210 case Type::ArrayTyID: 1211 case Type::FixedVectorTyID: 1212 ElementSize = DL.getTypeStoreSize(ETy); 1213 // Ptx allows variable initilization only for constant and 1214 // global state spaces. 
1215 if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || 1216 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) && 1217 GVar->hasInitializer()) { 1218 const Constant *Initializer = GVar->getInitializer(); 1219 if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) { 1220 AggBuffer aggBuffer(ElementSize, *this); 1221 bufferAggregateConstant(Initializer, &aggBuffer); 1222 if (aggBuffer.numSymbols()) { 1223 unsigned int ptrSize = MAI->getCodePointerSize(); 1224 if (ElementSize % ptrSize || 1225 !aggBuffer.allSymbolsAligned(ptrSize)) { 1226 // Print in bytes and use the mask() operator for pointers. 1227 if (!STI.hasMaskOperator()) 1228 report_fatal_error( 1229 "initialized packed aggregate with pointers '" + 1230 GVar->getName() + 1231 "' requires at least PTX ISA version 7.1"); 1232 O << " .u8 "; 1233 getSymbol(GVar)->print(O, MAI); 1234 O << "[" << ElementSize << "] = {"; 1235 aggBuffer.printBytes(O); 1236 O << "}"; 1237 } else { 1238 O << " .u" << ptrSize * 8 << " "; 1239 getSymbol(GVar)->print(O, MAI); 1240 O << "[" << ElementSize / ptrSize << "] = {"; 1241 aggBuffer.printWords(O); 1242 O << "}"; 1243 } 1244 } else { 1245 O << " .b8 "; 1246 getSymbol(GVar)->print(O, MAI); 1247 O << "[" << ElementSize << "] = {"; 1248 aggBuffer.printBytes(O); 1249 O << "}"; 1250 } 1251 } else { 1252 O << " .b8 "; 1253 getSymbol(GVar)->print(O, MAI); 1254 if (ElementSize) { 1255 O << "["; 1256 O << ElementSize; 1257 O << "]"; 1258 } 1259 } 1260 } else { 1261 O << " .b8 "; 1262 getSymbol(GVar)->print(O, MAI); 1263 if (ElementSize) { 1264 O << "["; 1265 O << ElementSize; 1266 O << "]"; 1267 } 1268 } 1269 break; 1270 default: 1271 llvm_unreachable("type not supported yet"); 1272 } 1273 } 1274 O << ";\n"; 1275 } 1276 1277 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) { 1278 const Value *v = Symbols[nSym]; 1279 const Value *v0 = SymbolsBeforeStripping[nSym]; 1280 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { 1281 MCSymbol *Name = 
AP.getSymbol(GVar); 1282 PointerType *PTy = dyn_cast<PointerType>(v0->getType()); 1283 // Is v0 a generic pointer? 1284 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0; 1285 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) { 1286 os << "generic("; 1287 Name->print(os, AP.MAI); 1288 os << ")"; 1289 } else { 1290 Name->print(os, AP.MAI); 1291 } 1292 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) { 1293 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false); 1294 AP.printMCExpr(*Expr, os); 1295 } else 1296 llvm_unreachable("symbol type unknown"); 1297 } 1298 1299 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) { 1300 unsigned int ptrSize = AP.MAI->getCodePointerSize(); 1301 symbolPosInBuffer.push_back(size); 1302 unsigned int nSym = 0; 1303 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 1304 for (unsigned int pos = 0; pos < size;) { 1305 if (pos) 1306 os << ", "; 1307 if (pos != nextSymbolPos) { 1308 os << (unsigned int)buffer[pos]; 1309 ++pos; 1310 continue; 1311 } 1312 // Generate a per-byte mask() operator for the symbol, which looks like: 1313 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...}; 1314 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers 1315 std::string symText; 1316 llvm::raw_string_ostream oss(symText); 1317 printSymbol(nSym, oss); 1318 for (unsigned i = 0; i < ptrSize; ++i) { 1319 if (i) 1320 os << ", "; 1321 llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper); 1322 os << "(" << symText << ")"; 1323 } 1324 pos += ptrSize; 1325 nextSymbolPos = symbolPosInBuffer[++nSym]; 1326 assert(nextSymbolPos >= pos); 1327 } 1328 } 1329 1330 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) { 1331 unsigned int ptrSize = AP.MAI->getCodePointerSize(); 1332 symbolPosInBuffer.push_back(size); 1333 unsigned int nSym = 0; 1334 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 1335 assert(nextSymbolPos % ptrSize 
== 0); 1336 for (unsigned int pos = 0; pos < size; pos += ptrSize) { 1337 if (pos) 1338 os << ", "; 1339 if (pos == nextSymbolPos) { 1340 printSymbol(nSym, os); 1341 nextSymbolPos = symbolPosInBuffer[++nSym]; 1342 assert(nextSymbolPos % ptrSize == 0); 1343 assert(nextSymbolPos >= pos + ptrSize); 1344 } else if (ptrSize == 4) 1345 os << support::endian::read32le(&buffer[pos]); 1346 else 1347 os << support::endian::read64le(&buffer[pos]); 1348 } 1349 } 1350 1351 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) { 1352 if (localDecls.find(f) == localDecls.end()) 1353 return; 1354 1355 std::vector<const GlobalVariable *> &gvars = localDecls[f]; 1356 1357 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 1358 const NVPTXSubtarget &STI = 1359 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 1360 1361 for (const GlobalVariable *GV : gvars) { 1362 O << "\t// demoted variable\n\t"; 1363 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI); 1364 } 1365 } 1366 1367 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace, 1368 raw_ostream &O) const { 1369 switch (AddressSpace) { 1370 case ADDRESS_SPACE_LOCAL: 1371 O << "local"; 1372 break; 1373 case ADDRESS_SPACE_GLOBAL: 1374 O << "global"; 1375 break; 1376 case ADDRESS_SPACE_CONST: 1377 O << "const"; 1378 break; 1379 case ADDRESS_SPACE_SHARED: 1380 O << "shared"; 1381 break; 1382 default: 1383 report_fatal_error("Bad address space found while emitting PTX: " + 1384 llvm::Twine(AddressSpace)); 1385 break; 1386 } 1387 } 1388 1389 std::string 1390 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const { 1391 switch (Ty->getTypeID()) { 1392 case Type::IntegerTyID: { 1393 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth(); 1394 if (NumBits == 1) 1395 return "pred"; 1396 else if (NumBits <= 64) { 1397 std::string name = "u"; 1398 return name + utostr(NumBits); 1399 } else { 1400 llvm_unreachable("Integer too large"); 1401 break; 
1402 } 1403 break; 1404 } 1405 case Type::BFloatTyID: 1406 case Type::HalfTyID: 1407 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53 1408 // PTX assembly. 1409 return "b16"; 1410 case Type::FloatTyID: 1411 return "f32"; 1412 case Type::DoubleTyID: 1413 return "f64"; 1414 case Type::PointerTyID: { 1415 unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace()); 1416 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size"); 1417 1418 if (PtrSize == 64) 1419 if (useB4PTR) 1420 return "b64"; 1421 else 1422 return "u64"; 1423 else if (useB4PTR) 1424 return "b32"; 1425 else 1426 return "u32"; 1427 } 1428 default: 1429 break; 1430 } 1431 llvm_unreachable("unexpected type"); 1432 } 1433 1434 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, 1435 raw_ostream &O, 1436 const NVPTXSubtarget &STI) { 1437 const DataLayout &DL = getDataLayout(); 1438 1439 // GlobalVariables are always constant pointers themselves. 1440 Type *ETy = GVar->getValueType(); 1441 1442 O << "."; 1443 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O); 1444 if (isManaged(*GVar)) { 1445 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { 1446 report_fatal_error( 1447 ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); 1448 } 1449 O << " .attribute(.managed)"; 1450 } 1451 if (MaybeAlign A = GVar->getAlign()) 1452 O << " .align " << A->value(); 1453 else 1454 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value(); 1455 1456 // Special case for i128 1457 if (ETy->isIntegerTy(128)) { 1458 O << " .b8 "; 1459 getSymbol(GVar)->print(O, MAI); 1460 O << "[16]"; 1461 return; 1462 } 1463 1464 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) { 1465 O << " ."; 1466 O << getPTXFundamentalTypeStr(ETy); 1467 O << " "; 1468 getSymbol(GVar)->print(O, MAI); 1469 return; 1470 } 1471 1472 int64_t ElementSize = 0; 1473 1474 // Although PTX has direct support for struct type and array type and LLVM IR 1475 // is very similar to 
PTX, the LLVM CodeGen does not support for targets that 1476 // support these high level field accesses. Structs and arrays are lowered 1477 // into arrays of bytes. 1478 switch (ETy->getTypeID()) { 1479 case Type::StructTyID: 1480 case Type::ArrayTyID: 1481 case Type::FixedVectorTyID: 1482 ElementSize = DL.getTypeStoreSize(ETy); 1483 O << " .b8 "; 1484 getSymbol(GVar)->print(O, MAI); 1485 O << "["; 1486 if (ElementSize) { 1487 O << ElementSize; 1488 } 1489 O << "]"; 1490 break; 1491 default: 1492 llvm_unreachable("type not supported yet"); 1493 } 1494 } 1495 1496 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { 1497 const DataLayout &DL = getDataLayout(); 1498 const AttributeList &PAL = F->getAttributes(); 1499 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 1500 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering()); 1501 1502 Function::const_arg_iterator I, E; 1503 unsigned paramIndex = 0; 1504 bool first = true; 1505 bool isKernelFunc = isKernelFunction(*F); 1506 bool isABI = (STI.getSmVersion() >= 20); 1507 bool hasImageHandles = STI.hasImageHandles(); 1508 1509 if (F->arg_empty() && !F->isVarArg()) { 1510 O << "()"; 1511 return; 1512 } 1513 1514 O << "(\n"; 1515 1516 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) { 1517 Type *Ty = I->getType(); 1518 1519 if (!first) 1520 O << ",\n"; 1521 1522 first = false; 1523 1524 // Handle image/sampler parameters 1525 if (isKernelFunction(*F)) { 1526 if (isSampler(*I) || isImage(*I)) { 1527 if (isImage(*I)) { 1528 std::string sname = std::string(I->getName()); 1529 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) { 1530 if (hasImageHandles) 1531 O << "\t.param .u64 .ptr .surfref "; 1532 else 1533 O << "\t.param .surfref "; 1534 O << TLI->getParamName(F, paramIndex); 1535 } 1536 else { // Default image is read_only 1537 if (hasImageHandles) 1538 O << "\t.param .u64 .ptr .texref "; 1539 else 1540 O << "\t.param .texref "; 1541 O << 
TLI->getParamName(F, paramIndex); 1542 } 1543 } else { 1544 if (hasImageHandles) 1545 O << "\t.param .u64 .ptr .samplerref "; 1546 else 1547 O << "\t.param .samplerref "; 1548 O << TLI->getParamName(F, paramIndex); 1549 } 1550 continue; 1551 } 1552 } 1553 1554 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F, 1555 paramIndex](Type *Ty) -> Align { 1556 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL); 1557 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex); 1558 return std::max(TypeAlign, ParamAlign.valueOrOne()); 1559 }; 1560 1561 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) { 1562 if (ShouldPassAsArray(Ty)) { 1563 // Just print .param .align <a> .b8 .param[size]; 1564 // <a> = optimal alignment for the element type; always multiple of 1565 // PAL.getParamAlignment 1566 // size = typeallocsize of element type 1567 Align OptimalAlign = getOptimalAlignForParam(Ty); 1568 1569 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1570 O << TLI->getParamName(F, paramIndex); 1571 O << "[" << DL.getTypeAllocSize(Ty) << "]"; 1572 1573 continue; 1574 } 1575 // Just a scalar 1576 auto *PTy = dyn_cast<PointerType>(Ty); 1577 unsigned PTySizeInBits = 0; 1578 if (PTy) { 1579 PTySizeInBits = 1580 TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits(); 1581 assert(PTySizeInBits && "Invalid pointer size"); 1582 } 1583 1584 if (isKernelFunc) { 1585 if (PTy) { 1586 // Special handling for pointer arguments to kernel 1587 O << "\t.param .u" << PTySizeInBits << " "; 1588 1589 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() != 1590 NVPTX::CUDA) { 1591 int addrSpace = PTy->getAddressSpace(); 1592 switch (addrSpace) { 1593 default: 1594 O << ".ptr "; 1595 break; 1596 case ADDRESS_SPACE_CONST: 1597 O << ".ptr .const "; 1598 break; 1599 case ADDRESS_SPACE_SHARED: 1600 O << ".ptr .shared "; 1601 break; 1602 case ADDRESS_SPACE_GLOBAL: 1603 O << ".ptr .global "; 1604 break; 1605 } 1606 Align ParamAlign = I->getParamAlign().valueOrOne(); 
1607 O << ".align " << ParamAlign.value() << " "; 1608 } 1609 O << TLI->getParamName(F, paramIndex); 1610 continue; 1611 } 1612 1613 // non-pointer scalar to kernel func 1614 O << "\t.param ."; 1615 // Special case: predicate operands become .u8 types 1616 if (Ty->isIntegerTy(1)) 1617 O << "u8"; 1618 else 1619 O << getPTXFundamentalTypeStr(Ty); 1620 O << " "; 1621 O << TLI->getParamName(F, paramIndex); 1622 continue; 1623 } 1624 // Non-kernel function, just print .param .b<size> for ABI 1625 // and .reg .b<size> for non-ABI 1626 unsigned sz = 0; 1627 if (isa<IntegerType>(Ty)) { 1628 sz = cast<IntegerType>(Ty)->getBitWidth(); 1629 sz = promoteScalarArgumentSize(sz); 1630 } else if (PTy) { 1631 assert(PTySizeInBits && "Invalid pointer size"); 1632 sz = PTySizeInBits; 1633 } else 1634 sz = Ty->getPrimitiveSizeInBits(); 1635 if (isABI) 1636 O << "\t.param .b" << sz << " "; 1637 else 1638 O << "\t.reg .b" << sz << " "; 1639 O << TLI->getParamName(F, paramIndex); 1640 continue; 1641 } 1642 1643 // param has byVal attribute. 1644 Type *ETy = PAL.getParamByValType(paramIndex); 1645 assert(ETy && "Param should have byval type"); 1646 1647 if (isABI || isKernelFunc) { 1648 // Just print .param .align <a> .b8 .param[size]; 1649 // <a> = optimal alignment for the element type; always multiple of 1650 // PAL.getParamAlignment 1651 // size = typeallocsize of element type 1652 Align OptimalAlign = 1653 isKernelFunc 1654 ? getOptimalAlignForParam(ETy) 1655 : TLI->getFunctionByValParamAlign( 1656 F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL); 1657 1658 unsigned sz = DL.getTypeAllocSize(ETy); 1659 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1660 O << TLI->getParamName(F, paramIndex); 1661 O << "[" << sz << "]"; 1662 continue; 1663 } else { 1664 // Split the ETy into constituent parts and 1665 // print .param .b<size> <name> for each part. 1666 // Further, if a part is vector, print the above for 1667 // each vector element. 
1668 SmallVector<EVT, 16> vtparts; 1669 ComputeValueVTs(*TLI, DL, ETy, vtparts); 1670 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 1671 unsigned elems = 1; 1672 EVT elemtype = vtparts[i]; 1673 if (vtparts[i].isVector()) { 1674 elems = vtparts[i].getVectorNumElements(); 1675 elemtype = vtparts[i].getVectorElementType(); 1676 } 1677 1678 for (unsigned j = 0, je = elems; j != je; ++j) { 1679 unsigned sz = elemtype.getSizeInBits(); 1680 if (elemtype.isInteger()) 1681 sz = promoteScalarArgumentSize(sz); 1682 O << "\t.reg .b" << sz << " "; 1683 O << TLI->getParamName(F, paramIndex); 1684 if (j < je - 1) 1685 O << ",\n"; 1686 ++paramIndex; 1687 } 1688 if (i < e - 1) 1689 O << ",\n"; 1690 } 1691 --paramIndex; 1692 continue; 1693 } 1694 } 1695 1696 if (F->isVarArg()) { 1697 if (!first) 1698 O << ",\n"; 1699 O << "\t.param .align " << STI.getMaxRequiredAlignment(); 1700 O << " .b8 "; 1701 O << TLI->getParamName(F, /* vararg */ -1) << "[]"; 1702 } 1703 1704 O << "\n)"; 1705 } 1706 1707 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( 1708 const MachineFunction &MF) { 1709 SmallString<128> Str; 1710 raw_svector_ostream O(Str); 1711 1712 // Map the global virtual register number to a register class specific 1713 // virtual register number starting from 1 with that class. 
1714 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); 1715 //unsigned numRegClasses = TRI->getNumRegClasses(); 1716 1717 // Emit the Fake Stack Object 1718 const MachineFrameInfo &MFI = MF.getFrameInfo(); 1719 int NumBytes = (int) MFI.getStackSize(); 1720 if (NumBytes) { 1721 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t" 1722 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n"; 1723 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) { 1724 O << "\t.reg .b64 \t%SP;\n"; 1725 O << "\t.reg .b64 \t%SPL;\n"; 1726 } else { 1727 O << "\t.reg .b32 \t%SP;\n"; 1728 O << "\t.reg .b32 \t%SPL;\n"; 1729 } 1730 } 1731 1732 // Go through all virtual registers to establish the mapping between the 1733 // global virtual 1734 // register number and the per class virtual register number. 1735 // We use the per class virtual register number in the ptx output. 1736 unsigned int numVRs = MRI->getNumVirtRegs(); 1737 for (unsigned i = 0; i < numVRs; i++) { 1738 Register vr = Register::index2VirtReg(i); 1739 const TargetRegisterClass *RC = MRI->getRegClass(vr); 1740 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1741 int n = regmap.size(); 1742 regmap.insert(std::make_pair(vr, n + 1)); 1743 } 1744 1745 // Emit register declarations 1746 // @TODO: Extract out the real register usage 1747 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; 1748 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; 1749 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; 1750 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; 1751 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n"; 1752 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; 1753 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n"; 1754 1755 // Emit declaration of the virtual registers or 'physical' registers for 1756 // each register class 1757 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) { 1758 const 
TargetRegisterClass *RC = TRI->getRegClass(i); 1759 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1760 std::string rcname = getNVPTXRegClassName(RC); 1761 std::string rcStr = getNVPTXRegClassStr(RC); 1762 int n = regmap.size(); 1763 1764 // Only declare those registers that may be used. 1765 if (n) { 1766 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1) 1767 << ">;\n"; 1768 } 1769 } 1770 1771 OutStreamer->emitRawText(O.str()); 1772 } 1773 1774 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) { 1775 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy 1776 bool ignored; 1777 unsigned int numHex; 1778 const char *lead; 1779 1780 if (Fp->getType()->getTypeID() == Type::FloatTyID) { 1781 numHex = 8; 1782 lead = "0f"; 1783 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored); 1784 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) { 1785 numHex = 16; 1786 lead = "0d"; 1787 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored); 1788 } else 1789 llvm_unreachable("unsupported fp type"); 1790 1791 APInt API = APF.bitcastToAPInt(); 1792 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true); 1793 } 1794 1795 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { 1796 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1797 O << CI->getValue(); 1798 return; 1799 } 1800 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) { 1801 printFPConstant(CFP, O); 1802 return; 1803 } 1804 if (isa<ConstantPointerNull>(CPV)) { 1805 O << "0"; 1806 return; 1807 } 1808 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1809 bool IsNonGenericPointer = false; 1810 if (GVar->getType()->getAddressSpace() != 0) { 1811 IsNonGenericPointer = true; 1812 } 1813 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) { 1814 O << "generic("; 1815 getSymbol(GVar)->print(O, MAI); 1816 O << ")"; 1817 } else { 1818 
getSymbol(GVar)->print(O, MAI); 1819 } 1820 return; 1821 } 1822 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1823 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false); 1824 printMCExpr(*E, O); 1825 return; 1826 } 1827 llvm_unreachable("Not scalar type found in printScalarConstant()"); 1828 } 1829 1830 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes, 1831 AggBuffer *AggBuffer) { 1832 const DataLayout &DL = getDataLayout(); 1833 int AllocSize = DL.getTypeAllocSize(CPV->getType()); 1834 if (isa<UndefValue>(CPV) || CPV->isNullValue()) { 1835 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise, 1836 // only the space allocated by CPV. 1837 AggBuffer->addZeros(Bytes ? Bytes : AllocSize); 1838 return; 1839 } 1840 1841 // Helper for filling AggBuffer with APInts. 1842 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) { 1843 size_t NumBytes = (Val.getBitWidth() + 7) / 8; 1844 SmallVector<unsigned char, 16> Buf(NumBytes); 1845 for (unsigned I = 0; I < NumBytes; ++I) { 1846 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8); 1847 } 1848 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes); 1849 }; 1850 1851 switch (CPV->getType()->getTypeID()) { 1852 case Type::IntegerTyID: 1853 if (const auto CI = dyn_cast<ConstantInt>(CPV)) { 1854 AddIntToBuffer(CI->getValue()); 1855 break; 1856 } 1857 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1858 if (const auto *CI = 1859 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) { 1860 AddIntToBuffer(CI->getValue()); 1861 break; 1862 } 1863 if (Cexpr->getOpcode() == Instruction::PtrToInt) { 1864 Value *V = Cexpr->getOperand(0)->stripPointerCasts(); 1865 AggBuffer->addSymbol(V, Cexpr->getOperand(0)); 1866 AggBuffer->addZeros(AllocSize); 1867 break; 1868 } 1869 } 1870 llvm_unreachable("unsupported integer const type"); 1871 break; 1872 1873 case Type::HalfTyID: 1874 case Type::BFloatTyID: 1875 case Type::FloatTyID: 1876 case Type::DoubleTyID: 1877 
AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt()); 1878 break; 1879 1880 case Type::PointerTyID: { 1881 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1882 AggBuffer->addSymbol(GVar, GVar); 1883 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1884 const Value *v = Cexpr->stripPointerCasts(); 1885 AggBuffer->addSymbol(v, Cexpr); 1886 } 1887 AggBuffer->addZeros(AllocSize); 1888 break; 1889 } 1890 1891 case Type::ArrayTyID: 1892 case Type::FixedVectorTyID: 1893 case Type::StructTyID: { 1894 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) { 1895 bufferAggregateConstant(CPV, AggBuffer); 1896 if (Bytes > AllocSize) 1897 AggBuffer->addZeros(Bytes - AllocSize); 1898 } else if (isa<ConstantAggregateZero>(CPV)) 1899 AggBuffer->addZeros(Bytes); 1900 else 1901 llvm_unreachable("Unexpected Constant type"); 1902 break; 1903 } 1904 1905 default: 1906 llvm_unreachable("unsupported type"); 1907 } 1908 } 1909 1910 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV, 1911 AggBuffer *aggBuffer) { 1912 const DataLayout &DL = getDataLayout(); 1913 int Bytes; 1914 1915 // Integers of arbitrary width 1916 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1917 APInt Val = CI->getValue(); 1918 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) { 1919 uint8_t Byte = Val.getLoBits(8).getZExtValue(); 1920 aggBuffer->addBytes(&Byte, 1, 1); 1921 Val.lshrInPlace(8); 1922 } 1923 return; 1924 } 1925 1926 // Old constants 1927 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) { 1928 if (CPV->getNumOperands()) 1929 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) 1930 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer); 1931 return; 1932 } 1933 1934 if (const ConstantDataSequential *CDS = 1935 dyn_cast<ConstantDataSequential>(CPV)) { 1936 if (CDS->getNumElements()) 1937 for (unsigned i = 0; i < CDS->getNumElements(); ++i) 1938 
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0, 1939 aggBuffer); 1940 return; 1941 } 1942 1943 if (isa<ConstantStruct>(CPV)) { 1944 if (CPV->getNumOperands()) { 1945 StructType *ST = cast<StructType>(CPV->getType()); 1946 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) { 1947 if (i == (e - 1)) 1948 Bytes = DL.getStructLayout(ST)->getElementOffset(0) + 1949 DL.getTypeAllocSize(ST) - 1950 DL.getStructLayout(ST)->getElementOffset(i); 1951 else 1952 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) - 1953 DL.getStructLayout(ST)->getElementOffset(i); 1954 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer); 1955 } 1956 } 1957 return; 1958 } 1959 llvm_unreachable("unsupported constant type in printAggregateConstant()"); 1960 } 1961 1962 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly 1963 /// a copy from AsmPrinter::lowerConstant, except customized to only handle 1964 /// expressions that are representable in PTX and create 1965 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions. 
1966 const MCExpr * 1967 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) { 1968 MCContext &Ctx = OutContext; 1969 1970 if (CV->isNullValue() || isa<UndefValue>(CV)) 1971 return MCConstantExpr::create(0, Ctx); 1972 1973 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) 1974 return MCConstantExpr::create(CI->getZExtValue(), Ctx); 1975 1976 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) { 1977 const MCSymbolRefExpr *Expr = 1978 MCSymbolRefExpr::create(getSymbol(GV), Ctx); 1979 if (ProcessingGeneric) { 1980 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx); 1981 } else { 1982 return Expr; 1983 } 1984 } 1985 1986 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV); 1987 if (!CE) { 1988 llvm_unreachable("Unknown constant value to lower!"); 1989 } 1990 1991 switch (CE->getOpcode()) { 1992 default: { 1993 // If the code isn't optimized, there may be outstanding folding 1994 // opportunities. Attempt to fold the expression using DataLayout as a 1995 // last resort before giving up. 1996 Constant *C = ConstantFoldConstant(CE, getDataLayout()); 1997 if (C != CE) 1998 return lowerConstantForGV(C, ProcessingGeneric); 1999 2000 // Otherwise report the problem to the user. 2001 std::string S; 2002 raw_string_ostream OS(S); 2003 OS << "Unsupported expression in static initializer: "; 2004 CE->printAsOperand(OS, /*PrintType=*/false, 2005 !MF ? nullptr : MF->getFunction().getParent()); 2006 report_fatal_error(Twine(OS.str())); 2007 } 2008 2009 case Instruction::AddrSpaceCast: { 2010 // Strip the addrspacecast and pass along the operand 2011 PointerType *DstTy = cast<PointerType>(CE->getType()); 2012 if (DstTy->getAddressSpace() == 0) { 2013 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true); 2014 } 2015 std::string S; 2016 raw_string_ostream OS(S); 2017 OS << "Unsupported expression in static initializer: "; 2018 CE->printAsOperand(OS, /*PrintType=*/ false, 2019 !MF ? 
nullptr : MF->getFunction().getParent()); 2020 report_fatal_error(Twine(OS.str())); 2021 } 2022 2023 case Instruction::GetElementPtr: { 2024 const DataLayout &DL = getDataLayout(); 2025 2026 // Generate a symbolic expression for the byte address 2027 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0); 2028 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI); 2029 2030 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0), 2031 ProcessingGeneric); 2032 if (!OffsetAI) 2033 return Base; 2034 2035 int64_t Offset = OffsetAI.getSExtValue(); 2036 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx), 2037 Ctx); 2038 } 2039 2040 case Instruction::Trunc: 2041 // We emit the value and depend on the assembler to truncate the generated 2042 // expression properly. This is important for differences between 2043 // blockaddress labels. Since the two labels are in the same function, it 2044 // is reasonable to treat their delta as a 32-bit value. 2045 [[fallthrough]]; 2046 case Instruction::BitCast: 2047 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2048 2049 case Instruction::IntToPtr: { 2050 const DataLayout &DL = getDataLayout(); 2051 2052 // Handle casts to pointers by changing them into casts to the appropriate 2053 // integer type. This promotes constant folding and simplifies this code. 2054 Constant *Op = CE->getOperand(0); 2055 Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()), 2056 false/*ZExt*/); 2057 return lowerConstantForGV(Op, ProcessingGeneric); 2058 } 2059 2060 case Instruction::PtrToInt: { 2061 const DataLayout &DL = getDataLayout(); 2062 2063 // Support only foldable casts to/from pointers that can be eliminated by 2064 // changing the pointer to the appropriately sized integer type. 
2065 Constant *Op = CE->getOperand(0); 2066 Type *Ty = CE->getType(); 2067 2068 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric); 2069 2070 // We can emit the pointer value into this slot if the slot is an 2071 // integer slot equal to the size of the pointer. 2072 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType())) 2073 return OpExpr; 2074 2075 // Otherwise the pointer is smaller than the resultant integer, mask off 2076 // the high bits so we are sure to get a proper truncation if the input is 2077 // a constant expr. 2078 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType()); 2079 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx); 2080 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx); 2081 } 2082 2083 // The MC library also has a right-shift operator, but it isn't consistently 2084 // signed or unsigned between different targets. 2085 case Instruction::Add: { 2086 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2087 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric); 2088 switch (CE->getOpcode()) { 2089 default: llvm_unreachable("Unknown binary operator constant cast expr"); 2090 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx); 2091 } 2092 } 2093 } 2094 } 2095 2096 // Copy of MCExpr::print customized for NVPTX 2097 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) { 2098 switch (Expr.getKind()) { 2099 case MCExpr::Target: 2100 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI); 2101 case MCExpr::Constant: 2102 OS << cast<MCConstantExpr>(Expr).getValue(); 2103 return; 2104 2105 case MCExpr::SymbolRef: { 2106 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr); 2107 const MCSymbol &Sym = SRE.getSymbol(); 2108 Sym.print(OS, MAI); 2109 return; 2110 } 2111 2112 case MCExpr::Unary: { 2113 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr); 2114 switch (UE.getOpcode()) { 2115 case MCUnaryExpr::LNot: OS 
<< '!'; break; 2116 case MCUnaryExpr::Minus: OS << '-'; break; 2117 case MCUnaryExpr::Not: OS << '~'; break; 2118 case MCUnaryExpr::Plus: OS << '+'; break; 2119 } 2120 printMCExpr(*UE.getSubExpr(), OS); 2121 return; 2122 } 2123 2124 case MCExpr::Binary: { 2125 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr); 2126 2127 // Only print parens around the LHS if it is non-trivial. 2128 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) || 2129 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) { 2130 printMCExpr(*BE.getLHS(), OS); 2131 } else { 2132 OS << '('; 2133 printMCExpr(*BE.getLHS(), OS); 2134 OS<< ')'; 2135 } 2136 2137 switch (BE.getOpcode()) { 2138 case MCBinaryExpr::Add: 2139 // Print "X-42" instead of "X+-42". 2140 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) { 2141 if (RHSC->getValue() < 0) { 2142 OS << RHSC->getValue(); 2143 return; 2144 } 2145 } 2146 2147 OS << '+'; 2148 break; 2149 default: llvm_unreachable("Unhandled binary operator"); 2150 } 2151 2152 // Only print parens around the LHS if it is non-trivial. 2153 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) { 2154 printMCExpr(*BE.getRHS(), OS); 2155 } else { 2156 OS << '('; 2157 printMCExpr(*BE.getRHS(), OS); 2158 OS << ')'; 2159 } 2160 return; 2161 } 2162 } 2163 2164 llvm_unreachable("Invalid expression kind!"); 2165 } 2166 2167 /// PrintAsmOperand - Print out an operand for an inline asm expression. 2168 /// 2169 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 2170 const char *ExtraCode, raw_ostream &O) { 2171 if (ExtraCode && ExtraCode[0]) { 2172 if (ExtraCode[1] != 0) 2173 return true; // Unknown modifier. 
2174 2175 switch (ExtraCode[0]) { 2176 default: 2177 // See if this is a generic print operand 2178 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O); 2179 case 'r': 2180 break; 2181 } 2182 } 2183 2184 printOperand(MI, OpNo, O); 2185 2186 return false; 2187 } 2188 2189 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI, 2190 unsigned OpNo, 2191 const char *ExtraCode, 2192 raw_ostream &O) { 2193 if (ExtraCode && ExtraCode[0]) 2194 return true; // Unknown modifier 2195 2196 O << '['; 2197 printMemOperand(MI, OpNo, O); 2198 O << ']'; 2199 2200 return false; 2201 } 2202 2203 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, 2204 raw_ostream &O) { 2205 const MachineOperand &MO = MI->getOperand(opNum); 2206 switch (MO.getType()) { 2207 case MachineOperand::MO_Register: 2208 if (MO.getReg().isPhysical()) { 2209 if (MO.getReg() == NVPTX::VRDepot) 2210 O << DEPOTNAME << getFunctionNumber(); 2211 else 2212 O << NVPTXInstPrinter::getRegisterName(MO.getReg()); 2213 } else { 2214 emitVirtualRegister(MO.getReg(), O); 2215 } 2216 break; 2217 2218 case MachineOperand::MO_Immediate: 2219 O << MO.getImm(); 2220 break; 2221 2222 case MachineOperand::MO_FPImmediate: 2223 printFPConstant(MO.getFPImm(), O); 2224 break; 2225 2226 case MachineOperand::MO_GlobalAddress: 2227 PrintSymbolOperand(MO, O); 2228 break; 2229 2230 case MachineOperand::MO_MachineBasicBlock: 2231 MO.getMBB()->getSymbol()->print(O, MAI); 2232 break; 2233 2234 default: 2235 llvm_unreachable("Operand type not supported."); 2236 } 2237 } 2238 2239 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum, 2240 raw_ostream &O, const char *Modifier) { 2241 printOperand(MI, opNum, O); 2242 2243 if (Modifier && strcmp(Modifier, "add") == 0) { 2244 O << ", "; 2245 printOperand(MI, opNum + 1, O); 2246 } else { 2247 if (MI->getOperand(opNum + 1).isImm() && 2248 MI->getOperand(opNum + 1).getImm() == 0) 2249 return; // don't print ',0' or '+0' 2250 O << "+"; 2251 
printOperand(MI, opNum + 1, O); 2252 } 2253 } 2254 2255 // Force static initialization. 2256 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() { 2257 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32()); 2258 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64()); 2259 } 2260