1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "NVPTXAsmPrinter.h" 15 #include "MCTargetDesc/NVPTXBaseInfo.h" 16 #include "MCTargetDesc/NVPTXInstPrinter.h" 17 #include "MCTargetDesc/NVPTXMCAsmInfo.h" 18 #include "MCTargetDesc/NVPTXTargetStreamer.h" 19 #include "NVPTX.h" 20 #include "NVPTXMCExpr.h" 21 #include "NVPTXMachineFunctionInfo.h" 22 #include "NVPTXRegisterInfo.h" 23 #include "NVPTXSubtarget.h" 24 #include "NVPTXTargetMachine.h" 25 #include "NVPTXUtilities.h" 26 #include "TargetInfo/NVPTXTargetInfo.h" 27 #include "cl_common_defines.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/DenseSet.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/StringExtras.h" 35 #include "llvm/ADT/StringRef.h" 36 #include "llvm/ADT/Twine.h" 37 #include "llvm/Analysis/ConstantFolding.h" 38 #include "llvm/CodeGen/Analysis.h" 39 #include "llvm/CodeGen/MachineBasicBlock.h" 40 #include "llvm/CodeGen/MachineFrameInfo.h" 41 #include "llvm/CodeGen/MachineFunction.h" 42 #include "llvm/CodeGen/MachineInstr.h" 43 #include "llvm/CodeGen/MachineLoopInfo.h" 44 #include "llvm/CodeGen/MachineModuleInfo.h" 45 #include "llvm/CodeGen/MachineOperand.h" 46 #include "llvm/CodeGen/MachineRegisterInfo.h" 47 #include "llvm/CodeGen/MachineValueType.h" 48 #include "llvm/CodeGen/TargetRegisterInfo.h" 49 #include "llvm/CodeGen/ValueTypes.h" 50 #include "llvm/IR/Attributes.h" 51 #include "llvm/IR/BasicBlock.h" 52 #include "llvm/IR/Constant.h" 53 #include "llvm/IR/Constants.h" 54 #include "llvm/IR/DataLayout.h" 55 #include "llvm/IR/DebugInfo.h" 56 #include "llvm/IR/DebugInfoMetadata.h" 57 #include "llvm/IR/DebugLoc.h" 58 #include "llvm/IR/DerivedTypes.h" 59 #include "llvm/IR/Function.h" 60 #include "llvm/IR/GlobalValue.h" 61 #include "llvm/IR/GlobalVariable.h" 62 #include "llvm/IR/Instruction.h" 63 #include "llvm/IR/LLVMContext.h" 64 #include "llvm/IR/Module.h" 65 #include "llvm/IR/Operator.h" 66 #include "llvm/IR/Type.h" 67 #include "llvm/IR/User.h" 68 #include "llvm/MC/MCExpr.h" 69 #include "llvm/MC/MCInst.h" 70 #include "llvm/MC/MCInstrDesc.h" 71 #include "llvm/MC/MCStreamer.h" 72 #include "llvm/MC/MCSymbol.h" 73 #include "llvm/MC/TargetRegistry.h" 74 #include "llvm/Support/Casting.h" 75 #include "llvm/Support/CommandLine.h" 76 #include "llvm/Support/Endian.h" 77 #include "llvm/Support/ErrorHandling.h" 78 #include "llvm/Support/NativeFormatting.h" 79 #include "llvm/Support/Path.h" 80 #include "llvm/Support/raw_ostream.h" 81 #include "llvm/Target/TargetLoweringObjectFile.h" 82 #include "llvm/Target/TargetMachine.h" 83 #include "llvm/TargetParser/Triple.h" 84 #include "llvm/Transforms/Utils/UnrollLoop.h" 85 #include <cassert> 86 #include <cstdint> 87 #include <cstring> 88 #include <new> 89 #include <string> 90 #include <utility> 91 #include <vector> 92 93 using namespace llvm; 94 95 static cl::opt<bool> 96 LowerCtorDtor("nvptx-lower-global-ctor-dtor", 97 cl::desc("Lower GPU ctor / dtors to globals on the device."), 98 cl::init(false), cl::Hidden); 99 100 #define DEPOTNAME "__local_depot" 101 102 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V 103 /// depends. 104 static void 105 DiscoverDependentGlobals(const Value *V, 106 DenseSet<const GlobalVariable *> &Globals) { 107 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 108 Globals.insert(GV); 109 else { 110 if (const User *U = dyn_cast<User>(V)) { 111 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) { 112 DiscoverDependentGlobals(U->getOperand(i), Globals); 113 } 114 } 115 } 116 } 117 118 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable 119 /// instances to be emitted, but only after any dependents have been added 120 /// first.s 121 static void 122 VisitGlobalVariableForEmission(const GlobalVariable *GV, 123 SmallVectorImpl<const GlobalVariable *> &Order, 124 DenseSet<const GlobalVariable *> &Visited, 125 DenseSet<const GlobalVariable *> &Visiting) { 126 // Have we already visited this one? 127 if (Visited.count(GV)) 128 return; 129 130 // Do we have a circular dependency? 131 if (!Visiting.insert(GV).second) 132 report_fatal_error("Circular dependency found in global variable set"); 133 134 // Make sure we visit all dependents first 135 DenseSet<const GlobalVariable *> Others; 136 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i) 137 DiscoverDependentGlobals(GV->getOperand(i), Others); 138 139 for (const GlobalVariable *GV : Others) 140 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting); 141 142 // Now we can visit ourself 143 Order.push_back(GV); 144 Visited.insert(GV); 145 Visiting.erase(GV); 146 } 147 148 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) { 149 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(), 150 getSubtargetInfo().getFeatureBits()); 151 152 MCInst Inst; 153 lowerToMCInst(MI, Inst); 154 EmitToStreamer(*OutStreamer, Inst); 155 } 156 157 // Handle symbol backtracking for targets that do not support image handles 158 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI, 159 unsigned OpNo, MCOperand &MCOp) { 160 const MachineOperand &MO = MI->getOperand(OpNo); 161 const MCInstrDesc &MCID = MI->getDesc(); 162 163 if (MCID.TSFlags & NVPTXII::IsTexFlag) { 164 // This is a texture fetch, so operand 4 is a texref and operand 5 is 165 // a samplerref 166 if (OpNo == 4 && MO.isImm()) { 167 lowerImageHandleSymbol(MO.getImm(), MCOp); 168 return true; 169 } 170 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) { 171 lowerImageHandleSymbol(MO.getImm(), MCOp); 172 return true; 173 } 174 175 return false; 176 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { 177 unsigned VecSize = 178 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); 179 180 // For a surface load of vector size N, the Nth operand will be the surfref 181 if (OpNo == VecSize && MO.isImm()) { 182 lowerImageHandleSymbol(MO.getImm(), MCOp); 183 return true; 184 } 185 186 return false; 187 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { 188 // This is a surface store, so operand 0 is a surfref 189 if (OpNo == 0 && MO.isImm()) { 190 lowerImageHandleSymbol(MO.getImm(), MCOp); 191 return true; 192 } 193 194 return false; 195 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { 196 // This is a query, so operand 1 is a surfref/texref 197 if (OpNo == 1 && MO.isImm()) { 198 lowerImageHandleSymbol(MO.getImm(), MCOp); 199 return true; 200 } 201 202 return false; 203 } 204 205 return false; 206 } 207 208 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) { 209 // Ewwww 210 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget()); 211 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM); 212 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>(); 213 const char *Sym = MFI->getImageHandleSymbol(Index); 214 StringRef SymName = nvTM.getStrPool().save(Sym); 215 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName)); 216 } 217 218 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { 219 OutMI.setOpcode(MI->getOpcode()); 220 // Special: Do not mangle symbol operand of CALL_PROTOTYPE 221 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) { 222 const MachineOperand &MO = MI->getOperand(0); 223 OutMI.addOperand(GetSymbolRef( 224 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName())))); 225 return; 226 } 227 228 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 229 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { 230 const MachineOperand &MO = MI->getOperand(i); 231 232 MCOperand MCOp; 233 if (!STI.hasImageHandles()) { 234 if (lowerImageHandleOperand(MI, i, MCOp)) { 235 OutMI.addOperand(MCOp); 236 continue; 237 } 238 } 239 240 if (lowerOperand(MO, MCOp)) 241 OutMI.addOperand(MCOp); 242 } 243 } 244 245 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, 246 MCOperand &MCOp) { 247 switch (MO.getType()) { 248 default: llvm_unreachable("unknown operand type"); 249 case MachineOperand::MO_Register: 250 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg())); 251 break; 252 case MachineOperand::MO_Immediate: 253 MCOp = MCOperand::createImm(MO.getImm()); 254 break; 255 case MachineOperand::MO_MachineBasicBlock: 256 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create( 257 MO.getMBB()->getSymbol(), OutContext)); 258 break; 259 case MachineOperand::MO_ExternalSymbol: 260 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName())); 261 break; 262 case MachineOperand::MO_GlobalAddress: 263 MCOp = GetSymbolRef(getSymbol(MO.getGlobal())); 264 break; 265 case MachineOperand::MO_FPImmediate: { 266 const ConstantFP *Cnt = MO.getFPImm(); 267 const APFloat &Val = Cnt->getValueAPF(); 268 269 switch (Cnt->getType()->getTypeID()) { 270 default: report_fatal_error("Unsupported FP type"); break; 271 case Type::HalfTyID: 272 MCOp = MCOperand::createExpr( 273 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); 274 break; 275 case Type::BFloatTyID: 276 MCOp = MCOperand::createExpr( 277 NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext)); 278 break; 279 case Type::FloatTyID: 280 MCOp = MCOperand::createExpr( 281 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); 282 break; 283 case Type::DoubleTyID: 284 MCOp = MCOperand::createExpr( 285 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext)); 286 break; 287 } 288 break; 289 } 290 } 291 return true; 292 } 293 294 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { 295 if (Register::isVirtualRegister(Reg)) { 296 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 297 298 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC]; 299 unsigned RegNum = RegMap[Reg]; 300 301 // Encode the register class in the upper 4 bits 302 // Must be kept in sync with NVPTXInstPrinter::printRegName 303 unsigned Ret = 0; 304 if (RC == &NVPTX::Int1RegsRegClass) { 305 Ret = (1 << 28); 306 } else if (RC == &NVPTX::Int16RegsRegClass) { 307 Ret = (2 << 28); 308 } else if (RC == &NVPTX::Int32RegsRegClass) { 309 Ret = (3 << 28); 310 } else if (RC == &NVPTX::Int64RegsRegClass) { 311 Ret = (4 << 28); 312 } else if (RC == &NVPTX::Float32RegsRegClass) { 313 Ret = (5 << 28); 314 } else if (RC == &NVPTX::Float64RegsRegClass) { 315 Ret = (6 << 28); 316 } else { 317 report_fatal_error("Bad register class"); 318 } 319 320 // Insert the vreg number 321 Ret |= (RegNum & 0x0FFFFFFF); 322 return Ret; 323 } else { 324 // Some special-use registers are actually physical registers. 325 // Encode this as the register class ID of 0 and the real register ID. 326 return Reg & 0x0FFFFFFF; 327 } 328 } 329 330 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) { 331 const MCExpr *Expr; 332 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None, 333 OutContext); 334 return MCOperand::createExpr(Expr); 335 } 336 337 static bool ShouldPassAsArray(Type *Ty) { 338 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) || 339 Ty->isHalfTy() || Ty->isBFloatTy(); 340 } 341 342 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { 343 const DataLayout &DL = getDataLayout(); 344 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 345 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering()); 346 347 Type *Ty = F->getReturnType(); 348 349 bool isABI = (STI.getSmVersion() >= 20); 350 351 if (Ty->getTypeID() == Type::VoidTyID) 352 return; 353 O << " ("; 354 355 if (isABI) { 356 if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) && 357 !ShouldPassAsArray(Ty)) { 358 unsigned size = 0; 359 if (auto *ITy = dyn_cast<IntegerType>(Ty)) { 360 size = ITy->getBitWidth(); 361 } else { 362 assert(Ty->isFloatingPointTy() && "Floating point type expected here"); 363 size = Ty->getPrimitiveSizeInBits(); 364 } 365 size = promoteScalarArgumentSize(size); 366 O << ".param .b" << size << " func_retval0"; 367 } else if (isa<PointerType>(Ty)) { 368 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits() 369 << " func_retval0"; 370 } else if (ShouldPassAsArray(Ty)) { 371 unsigned totalsz = DL.getTypeAllocSize(Ty); 372 unsigned retAlignment = 0; 373 if (!getAlign(*F, 0, retAlignment)) 374 retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value(); 375 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz 376 << "]"; 377 } else 378 llvm_unreachable("Unknown return type"); 379 } else { 380 SmallVector<EVT, 16> vtparts; 381 ComputeValueVTs(*TLI, DL, Ty, vtparts); 382 unsigned idx = 0; 383 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 384 unsigned elems = 1; 385 EVT elemtype = vtparts[i]; 386 if (vtparts[i].isVector()) { 387 elems = vtparts[i].getVectorNumElements(); 388 elemtype = vtparts[i].getVectorElementType(); 389 } 390 391 for (unsigned j = 0, je = elems; j != je; ++j) { 392 unsigned sz = elemtype.getSizeInBits(); 393 if (elemtype.isInteger()) 394 sz = promoteScalarArgumentSize(sz); 395 O << ".reg .b" << sz << " func_retval" << idx; 396 if (j < je - 1) 397 O << ", "; 398 ++idx; 399 } 400 if (i < e - 1) 401 O << ", "; 402 } 403 } 404 O << ") "; 405 } 406 407 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF, 408 raw_ostream &O) { 409 const Function &F = MF.getFunction(); 410 printReturnValStr(&F, O); 411 } 412 413 // Return true if MBB is the header of a loop marked with 414 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1. 415 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll( 416 const MachineBasicBlock &MBB) const { 417 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>(); 418 // We insert .pragma "nounroll" only to the loop header. 419 if (!LI.isLoopHeader(&MBB)) 420 return false; 421 422 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore, 423 // we iterate through each back edge of the loop with header MBB, and check 424 // whether its metadata contains llvm.loop.unroll.disable. 425 for (const MachineBasicBlock *PMBB : MBB.predecessors()) { 426 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) { 427 // Edges from other loops to MBB are not back edges. 428 continue; 429 } 430 if (const BasicBlock *PBB = PMBB->getBasicBlock()) { 431 if (MDNode *LoopID = 432 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) { 433 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable")) 434 return true; 435 if (MDNode *UnrollCountMD = 436 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) { 437 if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1)) 438 ->isOne()) 439 return true; 440 } 441 } 442 } 443 } 444 return false; 445 } 446 447 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { 448 AsmPrinter::emitBasicBlockStart(MBB); 449 if (isLoopHeaderOfNoUnroll(MBB)) 450 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n")); 451 } 452 453 void NVPTXAsmPrinter::emitFunctionEntryLabel() { 454 SmallString<128> Str; 455 raw_svector_ostream O(Str); 456 457 if (!GlobalsEmitted) { 458 emitGlobals(*MF->getFunction().getParent()); 459 GlobalsEmitted = true; 460 } 461 462 // Set up 463 MRI = &MF->getRegInfo(); 464 F = &MF->getFunction(); 465 emitLinkageDirective(F, O); 466 if (isKernelFunction(*F)) 467 O << ".entry "; 468 else { 469 O << ".func "; 470 printReturnValStr(*MF, O); 471 } 472 473 CurrentFnSym->print(O, MAI); 474 475 emitFunctionParamList(F, O); 476 O << "\n"; 477 478 if (isKernelFunction(*F)) 479 emitKernelFunctionDirectives(*F, O); 480 481 if (shouldEmitPTXNoReturn(F, TM)) 482 O << ".noreturn"; 483 484 OutStreamer->emitRawText(O.str()); 485 486 VRegMapping.clear(); 487 // Emit open brace for function body. 488 OutStreamer->emitRawText(StringRef("{\n")); 489 setAndEmitFunctionVirtualRegisters(*MF); 490 // Emit initial .loc debug directive for correct relocation symbol data. 491 if (MMI && MMI->hasDebugInfo()) 492 emitInitialRawDwarfLocDirective(*MF); 493 } 494 495 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) { 496 bool Result = AsmPrinter::runOnMachineFunction(F); 497 // Emit closing brace for the body of function F. 498 // The closing brace must be emitted here because we need to emit additional 499 // debug labels/data after the last basic block. 500 // We need to emit the closing brace here because we don't have function that 501 // finished emission of the function body. 502 OutStreamer->emitRawText(StringRef("}\n")); 503 return Result; 504 } 505 506 void NVPTXAsmPrinter::emitFunctionBodyStart() { 507 SmallString<128> Str; 508 raw_svector_ostream O(Str); 509 emitDemotedVars(&MF->getFunction(), O); 510 OutStreamer->emitRawText(O.str()); 511 } 512 513 void NVPTXAsmPrinter::emitFunctionBodyEnd() { 514 VRegMapping.clear(); 515 } 516 517 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const { 518 SmallString<128> Str; 519 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber(); 520 return OutContext.getOrCreateSymbol(Str); 521 } 522 523 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const { 524 Register RegNo = MI->getOperand(0).getReg(); 525 if (RegNo.isVirtual()) { 526 OutStreamer->AddComment(Twine("implicit-def: ") + 527 getVirtualRegisterName(RegNo)); 528 } else { 529 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 530 OutStreamer->AddComment(Twine("implicit-def: ") + 531 STI.getRegisterInfo()->getName(RegNo)); 532 } 533 OutStreamer->addBlankLine(); 534 } 535 536 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, 537 raw_ostream &O) const { 538 // If the NVVM IR has some of reqntid* specified, then output 539 // the reqntid directive, and set the unspecified ones to 1. 540 // If none of Reqntid* is specified, don't output reqntid directive. 541 unsigned Reqntidx, Reqntidy, Reqntidz; 542 Reqntidx = Reqntidy = Reqntidz = 1; 543 bool ReqSpecified = false; 544 ReqSpecified |= getReqNTIDx(F, Reqntidx); 545 ReqSpecified |= getReqNTIDy(F, Reqntidy); 546 ReqSpecified |= getReqNTIDz(F, Reqntidz); 547 548 if (ReqSpecified) 549 O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz 550 << "\n"; 551 552 // If the NVVM IR has some of maxntid* specified, then output 553 // the maxntid directive, and set the unspecified ones to 1. 554 // If none of maxntid* is specified, don't output maxntid directive. 555 unsigned Maxntidx, Maxntidy, Maxntidz; 556 Maxntidx = Maxntidy = Maxntidz = 1; 557 bool MaxSpecified = false; 558 MaxSpecified |= getMaxNTIDx(F, Maxntidx); 559 MaxSpecified |= getMaxNTIDy(F, Maxntidy); 560 MaxSpecified |= getMaxNTIDz(F, Maxntidz); 561 562 if (MaxSpecified) 563 O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz 564 << "\n"; 565 566 unsigned Mincta = 0; 567 if (getMinCTASm(F, Mincta)) 568 O << ".minnctapersm " << Mincta << "\n"; 569 570 unsigned Maxnreg = 0; 571 if (getMaxNReg(F, Maxnreg)) 572 O << ".maxnreg " << Maxnreg << "\n"; 573 574 // .maxclusterrank directive requires SM_90 or higher, make sure that we 575 // filter it out for lower SM versions, as it causes a hard ptxas crash. 576 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 577 const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 578 unsigned Maxclusterrank = 0; 579 if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90) 580 O << ".maxclusterrank " << Maxclusterrank << "\n"; 581 } 582 583 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const { 584 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 585 586 std::string Name; 587 raw_string_ostream NameStr(Name); 588 589 VRegRCMap::const_iterator I = VRegMapping.find(RC); 590 assert(I != VRegMapping.end() && "Bad register class"); 591 const DenseMap<unsigned, unsigned> &RegMap = I->second; 592 593 VRegMap::const_iterator VI = RegMap.find(Reg); 594 assert(VI != RegMap.end() && "Bad virtual register"); 595 unsigned MappedVR = VI->second; 596 597 NameStr << getNVPTXRegClassStr(RC) << MappedVR; 598 599 NameStr.flush(); 600 return Name; 601 } 602 603 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr, 604 raw_ostream &O) { 605 O << getVirtualRegisterName(vr); 606 } 607 608 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { 609 emitLinkageDirective(F, O); 610 if (isKernelFunction(*F)) 611 O << ".entry "; 612 else 613 O << ".func "; 614 printReturnValStr(F, O); 615 getSymbol(F)->print(O, MAI); 616 O << "\n"; 617 emitFunctionParamList(F, O); 618 O << "\n"; 619 if (shouldEmitPTXNoReturn(F, TM)) 620 O << ".noreturn"; 621 O << ";\n"; 622 } 623 624 static bool usedInGlobalVarDef(const Constant *C) { 625 if (!C) 626 return false; 627 628 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) { 629 return GV->getName() != "llvm.used"; 630 } 631 632 for (const User *U : C->users()) 633 if (const Constant *C = dyn_cast<Constant>(U)) 634 if (usedInGlobalVarDef(C)) 635 return true; 636 637 return false; 638 } 639 640 static bool usedInOneFunc(const User *U, Function const *&oneFunc) { 641 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) { 642 if (othergv->getName() == "llvm.used") 643 return true; 644 } 645 646 if (const Instruction *instr = dyn_cast<Instruction>(U)) { 647 if (instr->getParent() && instr->getParent()->getParent()) { 648 const Function *curFunc = instr->getParent()->getParent(); 649 if (oneFunc && (curFunc != oneFunc)) 650 return false; 651 oneFunc = curFunc; 652 return true; 653 } else 654 return false; 655 } 656 657 for (const User *UU : U->users()) 658 if (!usedInOneFunc(UU, oneFunc)) 659 return false; 660 661 return true; 662 } 663 664 /* Find out if a global variable can be demoted to local scope. 665 * Currently, this is valid for CUDA shared variables, which have local 666 * scope and global lifetime. So the conditions to check are : 667 * 1. Is the global variable in shared address space? 668 * 2. Does it have local linkage? 669 * 3. Is the global variable referenced only in one function? 670 */ 671 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) { 672 if (!gv->hasLocalLinkage()) 673 return false; 674 PointerType *Pty = gv->getType(); 675 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED) 676 return false; 677 678 const Function *oneFunc = nullptr; 679 680 bool flag = usedInOneFunc(gv, oneFunc); 681 if (!flag) 682 return false; 683 if (!oneFunc) 684 return false; 685 f = oneFunc; 686 return true; 687 } 688 689 static bool useFuncSeen(const Constant *C, 690 DenseMap<const Function *, bool> &seenMap) { 691 for (const User *U : C->users()) { 692 if (const Constant *cu = dyn_cast<Constant>(U)) { 693 if (useFuncSeen(cu, seenMap)) 694 return true; 695 } else if (const Instruction *I = dyn_cast<Instruction>(U)) { 696 const BasicBlock *bb = I->getParent(); 697 if (!bb) 698 continue; 699 const Function *caller = bb->getParent(); 700 if (!caller) 701 continue; 702 if (seenMap.contains(caller)) 703 return true; 704 } 705 } 706 return false; 707 } 708 709 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { 710 DenseMap<const Function *, bool> seenMap; 711 for (const Function &F : M) { 712 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) { 713 emitDeclaration(&F, O); 714 continue; 715 } 716 717 if (F.isDeclaration()) { 718 if (F.use_empty()) 719 continue; 720 if (F.getIntrinsicID()) 721 continue; 722 emitDeclaration(&F, O); 723 continue; 724 } 725 for (const User *U : F.users()) { 726 if (const Constant *C = dyn_cast<Constant>(U)) { 727 if (usedInGlobalVarDef(C)) { 728 // The use is in the initialization of a global variable 729 // that is a function pointer, so print a declaration 730 // for the original function 731 emitDeclaration(&F, O); 732 break; 733 } 734 // Emit a declaration of this function if the function that 735 // uses this constant expr has already been seen. 736 if (useFuncSeen(C, seenMap)) { 737 emitDeclaration(&F, O); 738 break; 739 } 740 } 741 742 if (!isa<Instruction>(U)) 743 continue; 744 const Instruction *instr = cast<Instruction>(U); 745 const BasicBlock *bb = instr->getParent(); 746 if (!bb) 747 continue; 748 const Function *caller = bb->getParent(); 749 if (!caller) 750 continue; 751 752 // If a caller has already been seen, then the caller is 753 // appearing in the module before the callee. so print out 754 // a declaration for the callee. 755 if (seenMap.contains(caller)) { 756 emitDeclaration(&F, O); 757 break; 758 } 759 } 760 seenMap[&F] = true; 761 } 762 } 763 764 static bool isEmptyXXStructor(GlobalVariable *GV) { 765 if (!GV) return true; 766 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer()); 767 if (!InitList) return true; // Not an array; we don't know how to parse. 768 return InitList->getNumOperands() == 0; 769 } 770 771 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) { 772 // Construct a default subtarget off of the TargetMachine defaults. The 773 // rest of NVPTX isn't friendly to change subtargets per function and 774 // so the default TargetMachine will have all of the options. 775 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 776 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl()); 777 SmallString<128> Str1; 778 raw_svector_ostream OS1(Str1); 779 780 // Emit header before any dwarf directives are emitted below. 781 emitHeader(M, OS1, *STI); 782 OutStreamer->emitRawText(OS1.str()); 783 } 784 785 bool NVPTXAsmPrinter::doInitialization(Module &M) { 786 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 787 const NVPTXSubtarget &STI = 788 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 789 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30)) 790 report_fatal_error(".alias requires PTX version >= 6.3 and sm_30"); 791 792 // OpenMP supports NVPTX global constructors and destructors. 793 bool IsOpenMP = M.getModuleFlag("openmp") != nullptr; 794 795 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) && 796 !LowerCtorDtor && !IsOpenMP) { 797 report_fatal_error( 798 "Module has a nontrivial global ctor, which NVPTX does not support."); 799 return true; // error 800 } 801 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) && 802 !LowerCtorDtor && !IsOpenMP) { 803 report_fatal_error( 804 "Module has a nontrivial global dtor, which NVPTX does not support."); 805 return true; // error 806 } 807 808 // We need to call the parent's one explicitly. 809 bool Result = AsmPrinter::doInitialization(M); 810 811 GlobalsEmitted = false; 812 813 return Result; 814 } 815 816 void NVPTXAsmPrinter::emitGlobals(const Module &M) { 817 SmallString<128> Str2; 818 raw_svector_ostream OS2(Str2); 819 820 emitDeclarations(M, OS2); 821 822 // As ptxas does not support forward references of globals, we need to first 823 // sort the list of module-level globals in def-use order. We visit each 824 // global variable in order, and ensure that we emit it *after* its dependent 825 // globals. We use a little extra memory maintaining both a set and a list to 826 // have fast searches while maintaining a strict ordering. 827 SmallVector<const GlobalVariable *, 8> Globals; 828 DenseSet<const GlobalVariable *> GVVisited; 829 DenseSet<const GlobalVariable *> GVVisiting; 830 831 // Visit each global variable, in order 832 for (const GlobalVariable &I : M.globals()) 833 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting); 834 835 assert(GVVisited.size() == M.global_size() && "Missed a global variable"); 836 assert(GVVisiting.size() == 0 && "Did not fully process a global variable"); 837 838 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 839 const NVPTXSubtarget &STI = 840 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 841 842 // Print out module-level global variables in proper order 843 for (unsigned i = 0, e = Globals.size(); i != e; ++i) 844 printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI); 845 846 OS2 << '\n'; 847 848 OutStreamer->emitRawText(OS2.str()); 849 } 850 851 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) { 852 SmallString<128> Str; 853 raw_svector_ostream OS(Str); 854 855 MCSymbol *Name = getSymbol(&GA); 856 const Function *F = dyn_cast<Function>(GA.getAliasee()); 857 if (!F || isKernelFunction(*F)) 858 report_fatal_error("NVPTX aliasee must be a non-kernel function"); 859 860 if (GA.hasLinkOnceLinkage() || GA.hasWeakLinkage() || 861 GA.hasAvailableExternallyLinkage() || GA.hasCommonLinkage()) 862 report_fatal_error("NVPTX aliasee must not be '.weak'"); 863 864 OS << "\n"; 865 emitLinkageDirective(F, OS); 866 OS << ".func "; 867 printReturnValStr(F, OS); 868 OS << Name->getName(); 869 emitFunctionParamList(F, OS); 870 if (shouldEmitPTXNoReturn(F, TM)) 871 OS << "\n.noreturn"; 872 OS << ";\n"; 873 874 OS << ".alias " << Name->getName() << ", " << F->getName() << ";\n"; 875 876 OutStreamer->emitRawText(OS.str()); 877 } 878 879 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O, 880 const NVPTXSubtarget &STI) { 881 O << "//\n"; 882 O << "// Generated by LLVM NVPTX Back-End\n"; 883 O << "//\n"; 884 O << "\n"; 885 886 unsigned PTXVersion = STI.getPTXVersion(); 887 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"; 888 889 O << ".target "; 890 O << STI.getTargetName(); 891 892 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 893 if (NTM.getDrvInterface() == NVPTX::NVCL) 894 O << ", texmode_independent"; 895 896 bool HasFullDebugInfo = false; 897 for (DICompileUnit *CU : M.debug_compile_units()) { 898 switch(CU->getEmissionKind()) { 899 case DICompileUnit::NoDebug: 900 case DICompileUnit::DebugDirectivesOnly: 901 break; 902 case DICompileUnit::LineTablesOnly: 903 case DICompileUnit::FullDebug: 904 HasFullDebugInfo = true; 905 break; 906 } 907 if (HasFullDebugInfo) 908 break; 909 } 910 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo) 911 O << ", debug"; 912 913 O << "\n"; 914 915 O << ".address_size "; 916 if (NTM.is64Bit()) 917 O << "64"; 918 else 919 O << "32"; 920 O << "\n"; 921 922 O << "\n"; 923 } 924 925 bool NVPTXAsmPrinter::doFinalization(Module &M) { 926 bool HasDebugInfo = MMI && MMI->hasDebugInfo(); 927 928 // If we did not emit any functions, then the global declarations have not 929 // yet been emitted. 930 if (!GlobalsEmitted) { 931 emitGlobals(M); 932 GlobalsEmitted = true; 933 } 934 935 // If we have any aliases we emit them at the end. 936 SmallVector<GlobalAlias *> AliasesToRemove; 937 for (GlobalAlias &Alias : M.aliases()) { 938 emitGlobalAlias(M, Alias); 939 AliasesToRemove.push_back(&Alias); 940 } 941 942 for (GlobalAlias *A : AliasesToRemove) 943 A->eraseFromParent(); 944 945 // call doFinalization 946 bool ret = AsmPrinter::doFinalization(M); 947 948 clearAnnotationCache(&M); 949 950 auto *TS = 951 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer()); 952 // Close the last emitted section 953 if (HasDebugInfo) { 954 TS->closeLastSection(); 955 // Emit empty .debug_loc section for better support of the empty files. 956 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}"); 957 } 958 959 // Output last DWARF .file directives, if any. 960 TS->outputDwarfFileDirectives(); 961 962 return ret; 963 } 964 965 // This function emits appropriate linkage directives for 966 // functions and global variables. 967 // 968 // extern function declaration -> .extern 969 // extern function definition -> .visible 970 // external global variable with init -> .visible 971 // external without init -> .extern 972 // appending -> not allowed, assert. 973 // for any linkage other than 974 // internal, private, linker_private, 975 // linker_private_weak, linker_private_weak_def_auto, 976 // we emit -> .weak. 977 978 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, 979 raw_ostream &O) { 980 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) { 981 if (V->hasExternalLinkage()) { 982 if (isa<GlobalVariable>(V)) { 983 const GlobalVariable *GVar = cast<GlobalVariable>(V); 984 if (GVar) { 985 if (GVar->hasInitializer()) 986 O << ".visible "; 987 else 988 O << ".extern "; 989 } 990 } else if (V->isDeclaration()) 991 O << ".extern "; 992 else 993 O << ".visible "; 994 } else if (V->hasAppendingLinkage()) { 995 std::string msg; 996 msg.append("Error: "); 997 msg.append("Symbol "); 998 if (V->hasName()) 999 msg.append(std::string(V->getName())); 1000 msg.append("has unsupported appending linkage type"); 1001 llvm_unreachable(msg.c_str()); 1002 } else if (!V->hasInternalLinkage() && 1003 !V->hasPrivateLinkage()) { 1004 O << ".weak "; 1005 } 1006 } 1007 } 1008 1009 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, 1010 raw_ostream &O, bool processDemoted, 1011 const NVPTXSubtarget &STI) { 1012 // Skip meta data 1013 if (GVar->hasSection()) { 1014 if (GVar->getSection() == "llvm.metadata") 1015 return; 1016 } 1017 1018 // Skip LLVM intrinsic global variables 1019 if (GVar->getName().starts_with("llvm.") || 1020 GVar->getName().starts_with("nvvm.")) 1021 return; 1022 1023 const DataLayout &DL = getDataLayout(); 1024 1025 // GlobalVariables are always constant pointers themselves. 1026 PointerType *PTy = GVar->getType(); 1027 Type *ETy = GVar->getValueType(); 1028 1029 if (GVar->hasExternalLinkage()) { 1030 if (GVar->hasInitializer()) 1031 O << ".visible "; 1032 else 1033 O << ".extern "; 1034 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() || 1035 GVar->hasAvailableExternallyLinkage() || 1036 GVar->hasCommonLinkage()) { 1037 O << ".weak "; 1038 } 1039 1040 if (isTexture(*GVar)) { 1041 O << ".global .texref " << getTextureName(*GVar) << ";\n"; 1042 return; 1043 } 1044 1045 if (isSurface(*GVar)) { 1046 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n"; 1047 return; 1048 } 1049 1050 if (GVar->isDeclaration()) { 1051 // (extern) declarations, no definition or initializer 1052 // Currently the only known declaration is for an automatic __local 1053 // (.shared) promoted to global. 1054 emitPTXGlobalVariable(GVar, O, STI); 1055 O << ";\n"; 1056 return; 1057 } 1058 1059 if (isSampler(*GVar)) { 1060 O << ".global .samplerref " << getSamplerName(*GVar); 1061 1062 const Constant *Initializer = nullptr; 1063 if (GVar->hasInitializer()) 1064 Initializer = GVar->getInitializer(); 1065 const ConstantInt *CI = nullptr; 1066 if (Initializer) 1067 CI = dyn_cast<ConstantInt>(Initializer); 1068 if (CI) { 1069 unsigned sample = CI->getZExtValue(); 1070 1071 O << " = { "; 1072 1073 for (int i = 0, 1074 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE); 1075 i < 3; i++) { 1076 O << "addr_mode_" << i << " = "; 1077 switch (addr) { 1078 case 0: 1079 O << "wrap"; 1080 break; 1081 case 1: 1082 O << "clamp_to_border"; 1083 break; 1084 case 2: 1085 O << "clamp_to_edge"; 1086 break; 1087 case 3: 1088 O << "wrap"; 1089 break; 1090 case 4: 1091 O << "mirror"; 1092 break; 1093 } 1094 O << ", "; 1095 } 1096 O << "filter_mode = "; 1097 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) { 1098 case 0: 1099 O << "nearest"; 1100 break; 1101 case 1: 1102 O << "linear"; 1103 break; 1104 case 2: 1105 llvm_unreachable("Anisotropic filtering is not supported"); 1106 default: 1107 O << "nearest"; 1108 break; 1109 } 1110 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) { 1111 O << ", force_unnormalized_coords = 1"; 1112 } 1113 O << " }"; 1114 } 1115 1116 O << ";\n"; 1117 return; 1118 } 1119 1120 if (GVar->hasPrivateLinkage()) { 1121 if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0) 1122 return; 1123 1124 // FIXME - need better way (e.g. Metadata) to avoid generating this global 1125 if (strncmp(GVar->getName().data(), "filename", 8) == 0) 1126 return; 1127 if (GVar->use_empty()) 1128 return; 1129 } 1130 1131 const Function *demotedFunc = nullptr; 1132 if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) { 1133 O << "// " << GVar->getName() << " has been demoted\n"; 1134 if (localDecls.find(demotedFunc) != localDecls.end()) 1135 localDecls[demotedFunc].push_back(GVar); 1136 else { 1137 std::vector<const GlobalVariable *> temp; 1138 temp.push_back(GVar); 1139 localDecls[demotedFunc] = temp; 1140 } 1141 return; 1142 } 1143 1144 O << "."; 1145 emitPTXAddressSpace(PTy->getAddressSpace(), O); 1146 1147 if (isManaged(*GVar)) { 1148 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { 1149 report_fatal_error( 1150 ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); 1151 } 1152 O << " .attribute(.managed)"; 1153 } 1154 1155 if (MaybeAlign A = GVar->getAlign()) 1156 O << " .align " << A->value(); 1157 else 1158 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value(); 1159 1160 if (ETy->isFloatingPointTy() || ETy->isPointerTy() || 1161 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) { 1162 O << " ."; 1163 // Special case: ABI requires that we use .u8 for predicates 1164 if (ETy->isIntegerTy(1)) 1165 O << "u8"; 1166 else 1167 O << getPTXFundamentalTypeStr(ETy, false); 1168 O << " "; 1169 getSymbol(GVar)->print(O, MAI); 1170 1171 // Ptx allows variable initilization only for constant and global state 1172 // spaces. 1173 if (GVar->hasInitializer()) { 1174 if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || 1175 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) { 1176 const Constant *Initializer = GVar->getInitializer(); 1177 // 'undef' is treated as there is no value specified. 1178 if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) { 1179 O << " = "; 1180 printScalarConstant(Initializer, O); 1181 } 1182 } else { 1183 // The frontend adds zero-initializer to device and constant variables 1184 // that don't have an initial value, and UndefValue to shared 1185 // variables, so skip warning for this case. 1186 if (!GVar->getInitializer()->isNullValue() && 1187 !isa<UndefValue>(GVar->getInitializer())) { 1188 report_fatal_error("initial value of '" + GVar->getName() + 1189 "' is not allowed in addrspace(" + 1190 Twine(PTy->getAddressSpace()) + ")"); 1191 } 1192 } 1193 } 1194 } else { 1195 uint64_t ElementSize = 0; 1196 1197 // Although PTX has direct support for struct type and array type and 1198 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for 1199 // targets that support these high level field accesses. Structs, arrays 1200 // and vectors are lowered into arrays of bytes. 1201 switch (ETy->getTypeID()) { 1202 case Type::IntegerTyID: // Integers larger than 64 bits 1203 case Type::StructTyID: 1204 case Type::ArrayTyID: 1205 case Type::FixedVectorTyID: 1206 ElementSize = DL.getTypeStoreSize(ETy); 1207 // Ptx allows variable initilization only for constant and 1208 // global state spaces. 1209 if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || 1210 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) && 1211 GVar->hasInitializer()) { 1212 const Constant *Initializer = GVar->getInitializer(); 1213 if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) { 1214 AggBuffer aggBuffer(ElementSize, *this); 1215 bufferAggregateConstant(Initializer, &aggBuffer); 1216 if (aggBuffer.numSymbols()) { 1217 unsigned int ptrSize = MAI->getCodePointerSize(); 1218 if (ElementSize % ptrSize || 1219 !aggBuffer.allSymbolsAligned(ptrSize)) { 1220 // Print in bytes and use the mask() operator for pointers. 1221 if (!STI.hasMaskOperator()) 1222 report_fatal_error( 1223 "initialized packed aggregate with pointers '" + 1224 GVar->getName() + 1225 "' requires at least PTX ISA version 7.1"); 1226 O << " .u8 "; 1227 getSymbol(GVar)->print(O, MAI); 1228 O << "[" << ElementSize << "] = {"; 1229 aggBuffer.printBytes(O); 1230 O << "}"; 1231 } else { 1232 O << " .u" << ptrSize * 8 << " "; 1233 getSymbol(GVar)->print(O, MAI); 1234 O << "[" << ElementSize / ptrSize << "] = {"; 1235 aggBuffer.printWords(O); 1236 O << "}"; 1237 } 1238 } else { 1239 O << " .b8 "; 1240 getSymbol(GVar)->print(O, MAI); 1241 O << "[" << ElementSize << "] = {"; 1242 aggBuffer.printBytes(O); 1243 O << "}"; 1244 } 1245 } else { 1246 O << " .b8 "; 1247 getSymbol(GVar)->print(O, MAI); 1248 if (ElementSize) { 1249 O << "["; 1250 O << ElementSize; 1251 O << "]"; 1252 } 1253 } 1254 } else { 1255 O << " .b8 "; 1256 getSymbol(GVar)->print(O, MAI); 1257 if (ElementSize) { 1258 O << "["; 1259 O << ElementSize; 1260 O << "]"; 1261 } 1262 } 1263 break; 1264 default: 1265 llvm_unreachable("type not supported yet"); 1266 } 1267 } 1268 O << ";\n"; 1269 } 1270 1271 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) { 1272 const Value *v = Symbols[nSym]; 1273 const Value *v0 = SymbolsBeforeStripping[nSym]; 1274 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { 1275 MCSymbol *Name = AP.getSymbol(GVar); 1276 PointerType *PTy = dyn_cast<PointerType>(v0->getType()); 1277 // Is v0 a generic pointer? 1278 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0; 1279 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) { 1280 os << "generic("; 1281 Name->print(os, AP.MAI); 1282 os << ")"; 1283 } else { 1284 Name->print(os, AP.MAI); 1285 } 1286 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) { 1287 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false); 1288 AP.printMCExpr(*Expr, os); 1289 } else 1290 llvm_unreachable("symbol type unknown"); 1291 } 1292 1293 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) { 1294 unsigned int ptrSize = AP.MAI->getCodePointerSize(); 1295 symbolPosInBuffer.push_back(size); 1296 unsigned int nSym = 0; 1297 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 1298 for (unsigned int pos = 0; pos < size;) { 1299 if (pos) 1300 os << ", "; 1301 if (pos != nextSymbolPos) { 1302 os << (unsigned int)buffer[pos]; 1303 ++pos; 1304 continue; 1305 } 1306 // Generate a per-byte mask() operator for the symbol, which looks like: 1307 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...}; 1308 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers 1309 std::string symText; 1310 llvm::raw_string_ostream oss(symText); 1311 printSymbol(nSym, oss); 1312 for (unsigned i = 0; i < ptrSize; ++i) { 1313 if (i) 1314 os << ", "; 1315 llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper); 1316 os << "(" << symText << ")"; 1317 } 1318 pos += ptrSize; 1319 nextSymbolPos = symbolPosInBuffer[++nSym]; 1320 assert(nextSymbolPos >= pos); 1321 } 1322 } 1323 1324 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) { 1325 unsigned int ptrSize = AP.MAI->getCodePointerSize(); 1326 symbolPosInBuffer.push_back(size); 1327 unsigned int nSym = 0; 1328 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 1329 assert(nextSymbolPos % ptrSize == 0); 1330 for (unsigned int pos = 0; pos < size; pos += ptrSize) { 1331 if (pos) 1332 os << ", "; 1333 if (pos == nextSymbolPos) { 1334 printSymbol(nSym, os); 1335 nextSymbolPos = symbolPosInBuffer[++nSym]; 1336 assert(nextSymbolPos % ptrSize == 0); 1337 assert(nextSymbolPos >= pos + ptrSize); 1338 } else if (ptrSize == 4) 1339 os << support::endian::read32le(&buffer[pos]); 1340 else 1341 os << support::endian::read64le(&buffer[pos]); 1342 } 1343 } 1344 1345 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) { 1346 if (localDecls.find(f) == localDecls.end()) 1347 return; 1348 1349 std::vector<const GlobalVariable *> &gvars = localDecls[f]; 1350 1351 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 1352 const NVPTXSubtarget &STI = 1353 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 1354 1355 for (const GlobalVariable *GV : gvars) { 1356 O << "\t// demoted variable\n\t"; 1357 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI); 1358 } 1359 } 1360 1361 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace, 1362 raw_ostream &O) const { 1363 switch (AddressSpace) { 1364 case ADDRESS_SPACE_LOCAL: 1365 O << "local"; 1366 break; 1367 case ADDRESS_SPACE_GLOBAL: 1368 O << "global"; 1369 break; 1370 case ADDRESS_SPACE_CONST: 1371 O << "const"; 1372 break; 1373 case ADDRESS_SPACE_SHARED: 1374 O << "shared"; 1375 break; 1376 default: 1377 report_fatal_error("Bad address space found while emitting PTX: " + 1378 llvm::Twine(AddressSpace)); 1379 break; 1380 } 1381 } 1382 1383 std::string 1384 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const { 1385 switch (Ty->getTypeID()) { 1386 case Type::IntegerTyID: { 1387 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth(); 1388 if (NumBits == 1) 1389 return "pred"; 1390 else if (NumBits <= 64) { 1391 std::string name = "u"; 1392 return name + utostr(NumBits); 1393 } else { 1394 llvm_unreachable("Integer too large"); 1395 break; 1396 } 1397 break; 1398 } 1399 case Type::BFloatTyID: 1400 case Type::HalfTyID: 1401 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53 1402 // PTX assembly. 1403 return "b16"; 1404 case Type::FloatTyID: 1405 return "f32"; 1406 case Type::DoubleTyID: 1407 return "f64"; 1408 case Type::PointerTyID: { 1409 unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace()); 1410 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size"); 1411 1412 if (PtrSize == 64) 1413 if (useB4PTR) 1414 return "b64"; 1415 else 1416 return "u64"; 1417 else if (useB4PTR) 1418 return "b32"; 1419 else 1420 return "u32"; 1421 } 1422 default: 1423 break; 1424 } 1425 llvm_unreachable("unexpected type"); 1426 } 1427 1428 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, 1429 raw_ostream &O, 1430 const NVPTXSubtarget &STI) { 1431 const DataLayout &DL = getDataLayout(); 1432 1433 // GlobalVariables are always constant pointers themselves. 1434 Type *ETy = GVar->getValueType(); 1435 1436 O << "."; 1437 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O); 1438 if (isManaged(*GVar)) { 1439 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { 1440 report_fatal_error( 1441 ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); 1442 } 1443 O << " .attribute(.managed)"; 1444 } 1445 if (MaybeAlign A = GVar->getAlign()) 1446 O << " .align " << A->value(); 1447 else 1448 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value(); 1449 1450 // Special case for i128 1451 if (ETy->isIntegerTy(128)) { 1452 O << " .b8 "; 1453 getSymbol(GVar)->print(O, MAI); 1454 O << "[16]"; 1455 return; 1456 } 1457 1458 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) { 1459 O << " ."; 1460 O << getPTXFundamentalTypeStr(ETy); 1461 O << " "; 1462 getSymbol(GVar)->print(O, MAI); 1463 return; 1464 } 1465 1466 int64_t ElementSize = 0; 1467 1468 // Although PTX has direct support for struct type and array type and LLVM IR 1469 // is very similar to PTX, the LLVM CodeGen does not support for targets that 1470 // support these high level field accesses. Structs and arrays are lowered 1471 // into arrays of bytes. 1472 switch (ETy->getTypeID()) { 1473 case Type::StructTyID: 1474 case Type::ArrayTyID: 1475 case Type::FixedVectorTyID: 1476 ElementSize = DL.getTypeStoreSize(ETy); 1477 O << " .b8 "; 1478 getSymbol(GVar)->print(O, MAI); 1479 O << "["; 1480 if (ElementSize) { 1481 O << ElementSize; 1482 } 1483 O << "]"; 1484 break; 1485 default: 1486 llvm_unreachable("type not supported yet"); 1487 } 1488 } 1489 1490 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { 1491 const DataLayout &DL = getDataLayout(); 1492 const AttributeList &PAL = F->getAttributes(); 1493 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 1494 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering()); 1495 1496 Function::const_arg_iterator I, E; 1497 unsigned paramIndex = 0; 1498 bool first = true; 1499 bool isKernelFunc = isKernelFunction(*F); 1500 bool isABI = (STI.getSmVersion() >= 20); 1501 bool hasImageHandles = STI.hasImageHandles(); 1502 1503 if (F->arg_empty() && !F->isVarArg()) { 1504 O << "()"; 1505 return; 1506 } 1507 1508 O << "(\n"; 1509 1510 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) { 1511 Type *Ty = I->getType(); 1512 1513 if (!first) 1514 O << ",\n"; 1515 1516 first = false; 1517 1518 // Handle image/sampler parameters 1519 if (isKernelFunction(*F)) { 1520 if (isSampler(*I) || isImage(*I)) { 1521 if (isImage(*I)) { 1522 std::string sname = std::string(I->getName()); 1523 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) { 1524 if (hasImageHandles) 1525 O << "\t.param .u64 .ptr .surfref "; 1526 else 1527 O << "\t.param .surfref "; 1528 O << TLI->getParamName(F, paramIndex); 1529 } 1530 else { // Default image is read_only 1531 if (hasImageHandles) 1532 O << "\t.param .u64 .ptr .texref "; 1533 else 1534 O << "\t.param .texref "; 1535 O << TLI->getParamName(F, paramIndex); 1536 } 1537 } else { 1538 if (hasImageHandles) 1539 O << "\t.param .u64 .ptr .samplerref "; 1540 else 1541 O << "\t.param .samplerref "; 1542 O << TLI->getParamName(F, paramIndex); 1543 } 1544 continue; 1545 } 1546 } 1547 1548 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F, 1549 paramIndex](Type *Ty) -> Align { 1550 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL); 1551 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex); 1552 return std::max(TypeAlign, ParamAlign.valueOrOne()); 1553 }; 1554 1555 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) { 1556 if (ShouldPassAsArray(Ty)) { 1557 // Just print .param .align <a> .b8 .param[size]; 1558 // <a> = optimal alignment for the element type; always multiple of 1559 // PAL.getParamAlignment 1560 // size = typeallocsize of element type 1561 Align OptimalAlign = getOptimalAlignForParam(Ty); 1562 1563 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1564 O << TLI->getParamName(F, paramIndex); 1565 O << "[" << DL.getTypeAllocSize(Ty) << "]"; 1566 1567 continue; 1568 } 1569 // Just a scalar 1570 auto *PTy = dyn_cast<PointerType>(Ty); 1571 unsigned PTySizeInBits = 0; 1572 if (PTy) { 1573 PTySizeInBits = 1574 TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits(); 1575 assert(PTySizeInBits && "Invalid pointer size"); 1576 } 1577 1578 if (isKernelFunc) { 1579 if (PTy) { 1580 // Special handling for pointer arguments to kernel 1581 O << "\t.param .u" << PTySizeInBits << " "; 1582 1583 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() != 1584 NVPTX::CUDA) { 1585 int addrSpace = PTy->getAddressSpace(); 1586 switch (addrSpace) { 1587 default: 1588 O << ".ptr "; 1589 break; 1590 case ADDRESS_SPACE_CONST: 1591 O << ".ptr .const "; 1592 break; 1593 case ADDRESS_SPACE_SHARED: 1594 O << ".ptr .shared "; 1595 break; 1596 case ADDRESS_SPACE_GLOBAL: 1597 O << ".ptr .global "; 1598 break; 1599 } 1600 Align ParamAlign = I->getParamAlign().valueOrOne(); 1601 O << ".align " << ParamAlign.value() << " "; 1602 } 1603 O << TLI->getParamName(F, paramIndex); 1604 continue; 1605 } 1606 1607 // non-pointer scalar to kernel func 1608 O << "\t.param ."; 1609 // Special case: predicate operands become .u8 types 1610 if (Ty->isIntegerTy(1)) 1611 O << "u8"; 1612 else 1613 O << getPTXFundamentalTypeStr(Ty); 1614 O << " "; 1615 O << TLI->getParamName(F, paramIndex); 1616 continue; 1617 } 1618 // Non-kernel function, just print .param .b<size> for ABI 1619 // and .reg .b<size> for non-ABI 1620 unsigned sz = 0; 1621 if (isa<IntegerType>(Ty)) { 1622 sz = cast<IntegerType>(Ty)->getBitWidth(); 1623 sz = promoteScalarArgumentSize(sz); 1624 } else if (PTy) { 1625 assert(PTySizeInBits && "Invalid pointer size"); 1626 sz = PTySizeInBits; 1627 } else 1628 sz = Ty->getPrimitiveSizeInBits(); 1629 if (isABI) 1630 O << "\t.param .b" << sz << " "; 1631 else 1632 O << "\t.reg .b" << sz << " "; 1633 O << TLI->getParamName(F, paramIndex); 1634 continue; 1635 } 1636 1637 // param has byVal attribute. 1638 Type *ETy = PAL.getParamByValType(paramIndex); 1639 assert(ETy && "Param should have byval type"); 1640 1641 if (isABI || isKernelFunc) { 1642 // Just print .param .align <a> .b8 .param[size]; 1643 // <a> = optimal alignment for the element type; always multiple of 1644 // PAL.getParamAlignment 1645 // size = typeallocsize of element type 1646 Align OptimalAlign = 1647 isKernelFunc 1648 ? getOptimalAlignForParam(ETy) 1649 : TLI->getFunctionByValParamAlign( 1650 F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL); 1651 1652 unsigned sz = DL.getTypeAllocSize(ETy); 1653 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1654 O << TLI->getParamName(F, paramIndex); 1655 O << "[" << sz << "]"; 1656 continue; 1657 } else { 1658 // Split the ETy into constituent parts and 1659 // print .param .b<size> <name> for each part. 1660 // Further, if a part is vector, print the above for 1661 // each vector element. 1662 SmallVector<EVT, 16> vtparts; 1663 ComputeValueVTs(*TLI, DL, ETy, vtparts); 1664 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 1665 unsigned elems = 1; 1666 EVT elemtype = vtparts[i]; 1667 if (vtparts[i].isVector()) { 1668 elems = vtparts[i].getVectorNumElements(); 1669 elemtype = vtparts[i].getVectorElementType(); 1670 } 1671 1672 for (unsigned j = 0, je = elems; j != je; ++j) { 1673 unsigned sz = elemtype.getSizeInBits(); 1674 if (elemtype.isInteger()) 1675 sz = promoteScalarArgumentSize(sz); 1676 O << "\t.reg .b" << sz << " "; 1677 O << TLI->getParamName(F, paramIndex); 1678 if (j < je - 1) 1679 O << ",\n"; 1680 ++paramIndex; 1681 } 1682 if (i < e - 1) 1683 O << ",\n"; 1684 } 1685 --paramIndex; 1686 continue; 1687 } 1688 } 1689 1690 if (F->isVarArg()) { 1691 if (!first) 1692 O << ",\n"; 1693 O << "\t.param .align " << STI.getMaxRequiredAlignment(); 1694 O << " .b8 "; 1695 O << TLI->getParamName(F, /* vararg */ -1) << "[]"; 1696 } 1697 1698 O << "\n)"; 1699 } 1700 1701 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( 1702 const MachineFunction &MF) { 1703 SmallString<128> Str; 1704 raw_svector_ostream O(Str); 1705 1706 // Map the global virtual register number to a register class specific 1707 // virtual register number starting from 1 with that class. 1708 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); 1709 //unsigned numRegClasses = TRI->getNumRegClasses(); 1710 1711 // Emit the Fake Stack Object 1712 const MachineFrameInfo &MFI = MF.getFrameInfo(); 1713 int NumBytes = (int) MFI.getStackSize(); 1714 if (NumBytes) { 1715 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t" 1716 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n"; 1717 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) { 1718 O << "\t.reg .b64 \t%SP;\n"; 1719 O << "\t.reg .b64 \t%SPL;\n"; 1720 } else { 1721 O << "\t.reg .b32 \t%SP;\n"; 1722 O << "\t.reg .b32 \t%SPL;\n"; 1723 } 1724 } 1725 1726 // Go through all virtual registers to establish the mapping between the 1727 // global virtual 1728 // register number and the per class virtual register number. 1729 // We use the per class virtual register number in the ptx output. 1730 unsigned int numVRs = MRI->getNumVirtRegs(); 1731 for (unsigned i = 0; i < numVRs; i++) { 1732 Register vr = Register::index2VirtReg(i); 1733 const TargetRegisterClass *RC = MRI->getRegClass(vr); 1734 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1735 int n = regmap.size(); 1736 regmap.insert(std::make_pair(vr, n + 1)); 1737 } 1738 1739 // Emit register declarations 1740 // @TODO: Extract out the real register usage 1741 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; 1742 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; 1743 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; 1744 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; 1745 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n"; 1746 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; 1747 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n"; 1748 1749 // Emit declaration of the virtual registers or 'physical' registers for 1750 // each register class 1751 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) { 1752 const TargetRegisterClass *RC = TRI->getRegClass(i); 1753 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1754 std::string rcname = getNVPTXRegClassName(RC); 1755 std::string rcStr = getNVPTXRegClassStr(RC); 1756 int n = regmap.size(); 1757 1758 // Only declare those registers that may be used. 1759 if (n) { 1760 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1) 1761 << ">;\n"; 1762 } 1763 } 1764 1765 OutStreamer->emitRawText(O.str()); 1766 } 1767 1768 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) { 1769 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy 1770 bool ignored; 1771 unsigned int numHex; 1772 const char *lead; 1773 1774 if (Fp->getType()->getTypeID() == Type::FloatTyID) { 1775 numHex = 8; 1776 lead = "0f"; 1777 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored); 1778 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) { 1779 numHex = 16; 1780 lead = "0d"; 1781 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored); 1782 } else 1783 llvm_unreachable("unsupported fp type"); 1784 1785 APInt API = APF.bitcastToAPInt(); 1786 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true); 1787 } 1788 1789 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { 1790 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1791 O << CI->getValue(); 1792 return; 1793 } 1794 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) { 1795 printFPConstant(CFP, O); 1796 return; 1797 } 1798 if (isa<ConstantPointerNull>(CPV)) { 1799 O << "0"; 1800 return; 1801 } 1802 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1803 bool IsNonGenericPointer = false; 1804 if (GVar->getType()->getAddressSpace() != 0) { 1805 IsNonGenericPointer = true; 1806 } 1807 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) { 1808 O << "generic("; 1809 getSymbol(GVar)->print(O, MAI); 1810 O << ")"; 1811 } else { 1812 getSymbol(GVar)->print(O, MAI); 1813 } 1814 return; 1815 } 1816 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1817 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false); 1818 printMCExpr(*E, O); 1819 return; 1820 } 1821 llvm_unreachable("Not scalar type found in printScalarConstant()"); 1822 } 1823 1824 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes, 1825 AggBuffer *AggBuffer) { 1826 const DataLayout &DL = getDataLayout(); 1827 int AllocSize = DL.getTypeAllocSize(CPV->getType()); 1828 if (isa<UndefValue>(CPV) || CPV->isNullValue()) { 1829 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise, 1830 // only the space allocated by CPV. 1831 AggBuffer->addZeros(Bytes ? Bytes : AllocSize); 1832 return; 1833 } 1834 1835 // Helper for filling AggBuffer with APInts. 1836 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) { 1837 size_t NumBytes = (Val.getBitWidth() + 7) / 8; 1838 SmallVector<unsigned char, 16> Buf(NumBytes); 1839 for (unsigned I = 0; I < NumBytes; ++I) { 1840 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8); 1841 } 1842 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes); 1843 }; 1844 1845 switch (CPV->getType()->getTypeID()) { 1846 case Type::IntegerTyID: 1847 if (const auto CI = dyn_cast<ConstantInt>(CPV)) { 1848 AddIntToBuffer(CI->getValue()); 1849 break; 1850 } 1851 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1852 if (const auto *CI = 1853 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) { 1854 AddIntToBuffer(CI->getValue()); 1855 break; 1856 } 1857 if (Cexpr->getOpcode() == Instruction::PtrToInt) { 1858 Value *V = Cexpr->getOperand(0)->stripPointerCasts(); 1859 AggBuffer->addSymbol(V, Cexpr->getOperand(0)); 1860 AggBuffer->addZeros(AllocSize); 1861 break; 1862 } 1863 } 1864 llvm_unreachable("unsupported integer const type"); 1865 break; 1866 1867 case Type::HalfTyID: 1868 case Type::BFloatTyID: 1869 case Type::FloatTyID: 1870 case Type::DoubleTyID: 1871 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt()); 1872 break; 1873 1874 case Type::PointerTyID: { 1875 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1876 AggBuffer->addSymbol(GVar, GVar); 1877 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1878 const Value *v = Cexpr->stripPointerCasts(); 1879 AggBuffer->addSymbol(v, Cexpr); 1880 } 1881 AggBuffer->addZeros(AllocSize); 1882 break; 1883 } 1884 1885 case Type::ArrayTyID: 1886 case Type::FixedVectorTyID: 1887 case Type::StructTyID: { 1888 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) { 1889 bufferAggregateConstant(CPV, AggBuffer); 1890 if (Bytes > AllocSize) 1891 AggBuffer->addZeros(Bytes - AllocSize); 1892 } else if (isa<ConstantAggregateZero>(CPV)) 1893 AggBuffer->addZeros(Bytes); 1894 else 1895 llvm_unreachable("Unexpected Constant type"); 1896 break; 1897 } 1898 1899 default: 1900 llvm_unreachable("unsupported type"); 1901 } 1902 } 1903 1904 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV, 1905 AggBuffer *aggBuffer) { 1906 const DataLayout &DL = getDataLayout(); 1907 int Bytes; 1908 1909 // Integers of arbitrary width 1910 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1911 APInt Val = CI->getValue(); 1912 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) { 1913 uint8_t Byte = Val.getLoBits(8).getZExtValue(); 1914 aggBuffer->addBytes(&Byte, 1, 1); 1915 Val.lshrInPlace(8); 1916 } 1917 return; 1918 } 1919 1920 // Old constants 1921 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) { 1922 if (CPV->getNumOperands()) 1923 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) 1924 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer); 1925 return; 1926 } 1927 1928 if (const ConstantDataSequential *CDS = 1929 dyn_cast<ConstantDataSequential>(CPV)) { 1930 if (CDS->getNumElements()) 1931 for (unsigned i = 0; i < CDS->getNumElements(); ++i) 1932 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0, 1933 aggBuffer); 1934 return; 1935 } 1936 1937 if (isa<ConstantStruct>(CPV)) { 1938 if (CPV->getNumOperands()) { 1939 StructType *ST = cast<StructType>(CPV->getType()); 1940 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) { 1941 if (i == (e - 1)) 1942 Bytes = DL.getStructLayout(ST)->getElementOffset(0) + 1943 DL.getTypeAllocSize(ST) - 1944 DL.getStructLayout(ST)->getElementOffset(i); 1945 else 1946 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) - 1947 DL.getStructLayout(ST)->getElementOffset(i); 1948 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer); 1949 } 1950 } 1951 return; 1952 } 1953 llvm_unreachable("unsupported constant type in printAggregateConstant()"); 1954 } 1955 1956 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly 1957 /// a copy from AsmPrinter::lowerConstant, except customized to only handle 1958 /// expressions that are representable in PTX and create 1959 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions. 1960 const MCExpr * 1961 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) { 1962 MCContext &Ctx = OutContext; 1963 1964 if (CV->isNullValue() || isa<UndefValue>(CV)) 1965 return MCConstantExpr::create(0, Ctx); 1966 1967 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) 1968 return MCConstantExpr::create(CI->getZExtValue(), Ctx); 1969 1970 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) { 1971 const MCSymbolRefExpr *Expr = 1972 MCSymbolRefExpr::create(getSymbol(GV), Ctx); 1973 if (ProcessingGeneric) { 1974 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx); 1975 } else { 1976 return Expr; 1977 } 1978 } 1979 1980 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV); 1981 if (!CE) { 1982 llvm_unreachable("Unknown constant value to lower!"); 1983 } 1984 1985 switch (CE->getOpcode()) { 1986 default: 1987 break; // Error 1988 1989 case Instruction::AddrSpaceCast: { 1990 // Strip the addrspacecast and pass along the operand 1991 PointerType *DstTy = cast<PointerType>(CE->getType()); 1992 if (DstTy->getAddressSpace() == 0) 1993 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true); 1994 1995 break; // Error 1996 } 1997 1998 case Instruction::GetElementPtr: { 1999 const DataLayout &DL = getDataLayout(); 2000 2001 // Generate a symbolic expression for the byte address 2002 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0); 2003 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI); 2004 2005 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0), 2006 ProcessingGeneric); 2007 if (!OffsetAI) 2008 return Base; 2009 2010 int64_t Offset = OffsetAI.getSExtValue(); 2011 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx), 2012 Ctx); 2013 } 2014 2015 case Instruction::Trunc: 2016 // We emit the value and depend on the assembler to truncate the generated 2017 // expression properly. This is important for differences between 2018 // blockaddress labels. Since the two labels are in the same function, it 2019 // is reasonable to treat their delta as a 32-bit value. 2020 [[fallthrough]]; 2021 case Instruction::BitCast: 2022 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2023 2024 case Instruction::IntToPtr: { 2025 const DataLayout &DL = getDataLayout(); 2026 2027 // Handle casts to pointers by changing them into casts to the appropriate 2028 // integer type. This promotes constant folding and simplifies this code. 2029 Constant *Op = CE->getOperand(0); 2030 Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()), 2031 /*IsSigned*/ false, DL); 2032 if (Op) 2033 return lowerConstantForGV(Op, ProcessingGeneric); 2034 2035 break; // Error 2036 } 2037 2038 case Instruction::PtrToInt: { 2039 const DataLayout &DL = getDataLayout(); 2040 2041 // Support only foldable casts to/from pointers that can be eliminated by 2042 // changing the pointer to the appropriately sized integer type. 2043 Constant *Op = CE->getOperand(0); 2044 Type *Ty = CE->getType(); 2045 2046 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric); 2047 2048 // We can emit the pointer value into this slot if the slot is an 2049 // integer slot equal to the size of the pointer. 2050 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType())) 2051 return OpExpr; 2052 2053 // Otherwise the pointer is smaller than the resultant integer, mask off 2054 // the high bits so we are sure to get a proper truncation if the input is 2055 // a constant expr. 2056 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType()); 2057 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx); 2058 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx); 2059 } 2060 2061 // The MC library also has a right-shift operator, but it isn't consistently 2062 // signed or unsigned between different targets. 2063 case Instruction::Add: { 2064 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2065 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric); 2066 switch (CE->getOpcode()) { 2067 default: llvm_unreachable("Unknown binary operator constant cast expr"); 2068 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx); 2069 } 2070 } 2071 } 2072 2073 // If the code isn't optimized, there may be outstanding folding 2074 // opportunities. Attempt to fold the expression using DataLayout as a 2075 // last resort before giving up. 2076 Constant *C = ConstantFoldConstant(CE, getDataLayout()); 2077 if (C != CE) 2078 return lowerConstantForGV(C, ProcessingGeneric); 2079 2080 // Otherwise report the problem to the user. 2081 std::string S; 2082 raw_string_ostream OS(S); 2083 OS << "Unsupported expression in static initializer: "; 2084 CE->printAsOperand(OS, /*PrintType=*/false, 2085 !MF ? nullptr : MF->getFunction().getParent()); 2086 report_fatal_error(Twine(OS.str())); 2087 } 2088 2089 // Copy of MCExpr::print customized for NVPTX 2090 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) { 2091 switch (Expr.getKind()) { 2092 case MCExpr::Target: 2093 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI); 2094 case MCExpr::Constant: 2095 OS << cast<MCConstantExpr>(Expr).getValue(); 2096 return; 2097 2098 case MCExpr::SymbolRef: { 2099 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr); 2100 const MCSymbol &Sym = SRE.getSymbol(); 2101 Sym.print(OS, MAI); 2102 return; 2103 } 2104 2105 case MCExpr::Unary: { 2106 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr); 2107 switch (UE.getOpcode()) { 2108 case MCUnaryExpr::LNot: OS << '!'; break; 2109 case MCUnaryExpr::Minus: OS << '-'; break; 2110 case MCUnaryExpr::Not: OS << '~'; break; 2111 case MCUnaryExpr::Plus: OS << '+'; break; 2112 } 2113 printMCExpr(*UE.getSubExpr(), OS); 2114 return; 2115 } 2116 2117 case MCExpr::Binary: { 2118 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr); 2119 2120 // Only print parens around the LHS if it is non-trivial. 2121 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) || 2122 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) { 2123 printMCExpr(*BE.getLHS(), OS); 2124 } else { 2125 OS << '('; 2126 printMCExpr(*BE.getLHS(), OS); 2127 OS<< ')'; 2128 } 2129 2130 switch (BE.getOpcode()) { 2131 case MCBinaryExpr::Add: 2132 // Print "X-42" instead of "X+-42". 2133 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) { 2134 if (RHSC->getValue() < 0) { 2135 OS << RHSC->getValue(); 2136 return; 2137 } 2138 } 2139 2140 OS << '+'; 2141 break; 2142 default: llvm_unreachable("Unhandled binary operator"); 2143 } 2144 2145 // Only print parens around the LHS if it is non-trivial. 2146 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) { 2147 printMCExpr(*BE.getRHS(), OS); 2148 } else { 2149 OS << '('; 2150 printMCExpr(*BE.getRHS(), OS); 2151 OS << ')'; 2152 } 2153 return; 2154 } 2155 } 2156 2157 llvm_unreachable("Invalid expression kind!"); 2158 } 2159 2160 /// PrintAsmOperand - Print out an operand for an inline asm expression. 2161 /// 2162 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 2163 const char *ExtraCode, raw_ostream &O) { 2164 if (ExtraCode && ExtraCode[0]) { 2165 if (ExtraCode[1] != 0) 2166 return true; // Unknown modifier. 2167 2168 switch (ExtraCode[0]) { 2169 default: 2170 // See if this is a generic print operand 2171 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O); 2172 case 'r': 2173 break; 2174 } 2175 } 2176 2177 printOperand(MI, OpNo, O); 2178 2179 return false; 2180 } 2181 2182 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI, 2183 unsigned OpNo, 2184 const char *ExtraCode, 2185 raw_ostream &O) { 2186 if (ExtraCode && ExtraCode[0]) 2187 return true; // Unknown modifier 2188 2189 O << '['; 2190 printMemOperand(MI, OpNo, O); 2191 O << ']'; 2192 2193 return false; 2194 } 2195 2196 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum, 2197 raw_ostream &O) { 2198 const MachineOperand &MO = MI->getOperand(OpNum); 2199 switch (MO.getType()) { 2200 case MachineOperand::MO_Register: 2201 if (MO.getReg().isPhysical()) { 2202 if (MO.getReg() == NVPTX::VRDepot) 2203 O << DEPOTNAME << getFunctionNumber(); 2204 else 2205 O << NVPTXInstPrinter::getRegisterName(MO.getReg()); 2206 } else { 2207 emitVirtualRegister(MO.getReg(), O); 2208 } 2209 break; 2210 2211 case MachineOperand::MO_Immediate: 2212 O << MO.getImm(); 2213 break; 2214 2215 case MachineOperand::MO_FPImmediate: 2216 printFPConstant(MO.getFPImm(), O); 2217 break; 2218 2219 case MachineOperand::MO_GlobalAddress: 2220 PrintSymbolOperand(MO, O); 2221 break; 2222 2223 case MachineOperand::MO_MachineBasicBlock: 2224 MO.getMBB()->getSymbol()->print(O, MAI); 2225 break; 2226 2227 default: 2228 llvm_unreachable("Operand type not supported."); 2229 } 2230 } 2231 2232 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum, 2233 raw_ostream &O, const char *Modifier) { 2234 printOperand(MI, OpNum, O); 2235 2236 if (Modifier && strcmp(Modifier, "add") == 0) { 2237 O << ", "; 2238 printOperand(MI, OpNum + 1, O); 2239 } else { 2240 if (MI->getOperand(OpNum + 1).isImm() && 2241 MI->getOperand(OpNum + 1).getImm() == 0) 2242 return; // don't print ',0' or '+0' 2243 O << "+"; 2244 printOperand(MI, OpNum + 1, O); 2245 } 2246 } 2247 2248 // Force static initialization. 2249 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() { 2250 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32()); 2251 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64()); 2252 } 2253