1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "NVPTXAsmPrinter.h" 15 #include "MCTargetDesc/NVPTXBaseInfo.h" 16 #include "MCTargetDesc/NVPTXInstPrinter.h" 17 #include "MCTargetDesc/NVPTXMCAsmInfo.h" 18 #include "MCTargetDesc/NVPTXTargetStreamer.h" 19 #include "NVPTX.h" 20 #include "NVPTXMCExpr.h" 21 #include "NVPTXMachineFunctionInfo.h" 22 #include "NVPTXRegisterInfo.h" 23 #include "NVPTXSubtarget.h" 24 #include "NVPTXTargetMachine.h" 25 #include "NVPTXUtilities.h" 26 #include "TargetInfo/NVPTXTargetInfo.h" 27 #include "cl_common_defines.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/DenseSet.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/StringExtras.h" 35 #include "llvm/ADT/StringRef.h" 36 #include "llvm/ADT/Triple.h" 37 #include "llvm/ADT/Twine.h" 38 #include "llvm/Analysis/ConstantFolding.h" 39 #include "llvm/CodeGen/Analysis.h" 40 #include "llvm/CodeGen/MachineBasicBlock.h" 41 #include "llvm/CodeGen/MachineFrameInfo.h" 42 #include "llvm/CodeGen/MachineFunction.h" 43 #include "llvm/CodeGen/MachineInstr.h" 44 #include "llvm/CodeGen/MachineLoopInfo.h" 45 #include "llvm/CodeGen/MachineModuleInfo.h" 46 #include "llvm/CodeGen/MachineOperand.h" 47 #include "llvm/CodeGen/MachineRegisterInfo.h" 48 #include "llvm/CodeGen/TargetRegisterInfo.h" 49 #include 
"llvm/CodeGen/ValueTypes.h" 50 #include "llvm/IR/Attributes.h" 51 #include "llvm/IR/BasicBlock.h" 52 #include "llvm/IR/Constant.h" 53 #include "llvm/IR/Constants.h" 54 #include "llvm/IR/DataLayout.h" 55 #include "llvm/IR/DebugInfo.h" 56 #include "llvm/IR/DebugInfoMetadata.h" 57 #include "llvm/IR/DebugLoc.h" 58 #include "llvm/IR/DerivedTypes.h" 59 #include "llvm/IR/Function.h" 60 #include "llvm/IR/GlobalValue.h" 61 #include "llvm/IR/GlobalVariable.h" 62 #include "llvm/IR/Instruction.h" 63 #include "llvm/IR/LLVMContext.h" 64 #include "llvm/IR/Module.h" 65 #include "llvm/IR/Operator.h" 66 #include "llvm/IR/Type.h" 67 #include "llvm/IR/User.h" 68 #include "llvm/MC/MCExpr.h" 69 #include "llvm/MC/MCInst.h" 70 #include "llvm/MC/MCInstrDesc.h" 71 #include "llvm/MC/MCStreamer.h" 72 #include "llvm/MC/MCSymbol.h" 73 #include "llvm/MC/TargetRegistry.h" 74 #include "llvm/Support/Casting.h" 75 #include "llvm/Support/CommandLine.h" 76 #include "llvm/Support/Endian.h" 77 #include "llvm/Support/ErrorHandling.h" 78 #include "llvm/Support/MachineValueType.h" 79 #include "llvm/Support/NativeFormatting.h" 80 #include "llvm/Support/Path.h" 81 #include "llvm/Support/raw_ostream.h" 82 #include "llvm/Target/TargetLoweringObjectFile.h" 83 #include "llvm/Target/TargetMachine.h" 84 #include "llvm/Transforms/Utils/UnrollLoop.h" 85 #include <cassert> 86 #include <cstdint> 87 #include <cstring> 88 #include <new> 89 #include <string> 90 #include <utility> 91 #include <vector> 92 93 using namespace llvm; 94 95 #define DEPOTNAME "__local_depot" 96 97 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V 98 /// depends. 
/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
/// depends.
static void
DiscoverDependentGlobals(const Value *V,
                         DenseSet<const GlobalVariable *> &Globals) {
  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
    Globals.insert(GV);
  else {
    // Recurse through any other User (e.g. a ConstantExpr initializer) to
    // collect the globals it transitively references.
    if (const User *U = dyn_cast<User>(V)) {
      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
        DiscoverDependentGlobals(U->getOperand(i), Globals);
      }
    }
  }
}

/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
/// instances to be emitted, but only after any dependents have been added
/// first.
static void
VisitGlobalVariableForEmission(const GlobalVariable *GV,
                               SmallVectorImpl<const GlobalVariable *> &Order,
                               DenseSet<const GlobalVariable *> &Visited,
                               DenseSet<const GlobalVariable *> &Visiting) {
  // Have we already visited this one?
  if (Visited.count(GV))
    return;

  // Do we have a circular dependency?
  if (!Visiting.insert(GV).second)
    report_fatal_error("Circular dependency found in global variable set");

  // Make sure we visit all dependents first
  DenseSet<const GlobalVariable *> Others;
  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
    DiscoverDependentGlobals(GV->getOperand(i), Others);

  for (const GlobalVariable *GV : Others)
    VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);

  // Now we can visit ourself
  Order.push_back(GV);
  Visited.insert(GV);
  Visiting.erase(GV);
}

// Lower the MachineInstr to an MCInst and emit it to the streamer.
void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
  NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
                                        getSubtargetInfo().getFeatureBits());

  MCInst Inst;
  lowerToMCInst(MI, Inst);
  EmitToStreamer(*OutStreamer, Inst);
}

// Handle symbol backtracking for targets that do not support image handles.
// Returns true (and fills in MCOp) if operand OpNo of MI is an image-handle
// index that was replaced with a symbol reference.
bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
                                              unsigned OpNo, MCOperand &MCOp) {
  const MachineOperand &MO = MI->getOperand(OpNo);
  const MCInstrDesc &MCID = MI->getDesc();

  if (MCID.TSFlags & NVPTXII::IsTexFlag) {
    // This is a texture fetch, so operand 4 is a texref and operand 5 is
    // a samplerref
    if (OpNo == 4 && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }
    // In unified texture mode there is no separate sampler operand.
    if (OpNo == 5 && MO.isImm() &&
        !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
    unsigned VecSize =
        1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
              1);

    // For a surface load of vector size N, the Nth operand will be the surfref
    if (OpNo == VecSize && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
    // This is a surface store, so operand 0 is a surfref
    if (OpNo == 0 && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
    // This is a query, so operand 1 is a surfref/texref
    if (OpNo == 1 && MO.isImm()) {
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  }

  return false;
}

// Replace an image-handle index with a reference to the symbol recorded for
// it in the NVPTXMachineFunctionInfo.
void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
  // Ewwww
  LLVMTargetMachine &TM = const_cast<LLVMTargetMachine &>(MF->getTarget());
  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
  const char *Sym = MFI->getImageHandleSymbol(Index);
  // Save the name in the target machine's string pool so the StringRef stays
  // valid after MFI is gone.
  StringRef SymName = nvTM.getStrPool().save(Sym);
  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
}

void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
  OutMI.setOpcode(MI->getOpcode());
  // Special: Do not mangle symbol operand of CALL_PROTOTYPE
  if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
    const MachineOperand &MO = MI->getOperand(0);
    OutMI.addOperand(GetSymbolRef(
        OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
    return;
  }

  const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
    const MachineOperand &MO = MI->getOperand(i);

    MCOperand MCOp;
    // Without image-handle support, image operands must be lowered back to
    // their symbols (see lowerImageHandleOperand).
    if (!STI.hasImageHandles()) {
      if (lowerImageHandleOperand(MI, i, MCOp)) {
        OutMI.addOperand(MCOp);
        continue;
      }
    }

    if (lowerOperand(MO, MCOp))
      OutMI.addOperand(MCOp);
  }
}

// Lower a single MachineOperand to an MCOperand. Returns true on success.
bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
                                   MCOperand &MCOp) {
  switch (MO.getType()) {
  default: llvm_unreachable("unknown operand type");
  case MachineOperand::MO_Register:
    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
    break;
  case MachineOperand::MO_Immediate:
    MCOp = MCOperand::createImm(MO.getImm());
    break;
  case MachineOperand::MO_MachineBasicBlock:
    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
        MO.getMBB()->getSymbol(), OutContext));
    break;
  case MachineOperand::MO_ExternalSymbol:
    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
    break;
  case MachineOperand::MO_GlobalAddress:
    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
    break;
  case MachineOperand::MO_FPImmediate: {
    const ConstantFP *Cnt = MO.getFPImm();
    const APFloat &Val = Cnt->getValueAPF();

    switch (Cnt->getType()->getTypeID()) {
    default: report_fatal_error("Unsupported FP type"); break;
    case Type::HalfTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
      break;
    case Type::FloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
      break;
    case Type::DoubleTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
      break;
    }
    break;
  }
  }
  return true;
}

// Pack a virtual register's class and per-class number into one 32-bit value:
// register class ID in the upper 4 bits, vreg number in the lower 28.
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
  if (Register::isVirtualRegister(Reg)) {
    const TargetRegisterClass *RC = MRI->getRegClass(Reg);

    DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
    unsigned RegNum = RegMap[Reg];

    // Encode the register class in the upper 4 bits
    // Must be kept in sync with NVPTXInstPrinter::printRegName
    unsigned Ret = 0;
    if (RC == &NVPTX::Int1RegsRegClass) {
      Ret = (1 << 28);
    } else if (RC == &NVPTX::Int16RegsRegClass) {
      Ret = (2 << 28);
    } else if (RC == &NVPTX::Int32RegsRegClass) {
      Ret = (3 << 28);
    } else if (RC == &NVPTX::Int64RegsRegClass) {
      Ret = (4 << 28);
    } else if (RC == &NVPTX::Float32RegsRegClass) {
      Ret = (5 << 28);
    } else if (RC == &NVPTX::Float64RegsRegClass) {
      Ret = (6 << 28);
    } else if (RC == &NVPTX::Float16RegsRegClass) {
      Ret = (7 << 28);
    } else if (RC == &NVPTX::Float16x2RegsRegClass) {
      Ret = (8 << 28);
    } else {
      report_fatal_error("Bad register class");
    }

    // Insert the vreg number
    Ret |= (RegNum & 0x0FFFFFFF);
    return Ret;
  } else {
    // Some special-use registers are actually physical registers.
    // Encode this as the register class ID of 0 and the real register ID.
    return Reg & 0x0FFFFFFF;
  }
}

// Wrap an MCSymbol in a plain (VK_None) symbol-reference MCOperand.
MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
  const MCExpr *Expr;
  Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
                                 OutContext);
  return MCOperand::createExpr(Expr);
}

// Print the PTX return-value declaration ("( .param ... func_retval0 )") for
// function F. Emits nothing for a void return type.
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Type *Ty = F->getReturnType();

  // sm_20 and newer follow the ABI return convention (.param); older targets
  // return in registers (.reg).
  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;

  O << " (";

  if (isABI) {
    if (Ty->isFloatingPointTy() ||
        (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size. fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);

      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (Ty->isAggregateType() || Ty->isVectorTy() ||
               Ty->isIntegerTy(128)) {
      // Aggregates (and i128) are returned as a byte array with explicit
      // alignment.
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      unsigned retAlignment = 0;
      if (!getAlign(*F, 0, retAlignment))
        retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
      O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
        << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    // Pre-ABI: emit one .reg per scalar element of the (possibly vector)
    // return value.
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger())
          sz = promoteScalarArgumentSize(sz);
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}

void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
                                        raw_ostream &O) {
  const Function &F = MF.getFunction();
  printReturnValStr(&F, O);
}

// Return true if MBB is the header of a loop marked with
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
    const MachineBasicBlock &MBB) const {
  MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
  // We insert .pragma "nounroll" only to the loop header.
  if (!LI.isLoopHeader(&MBB))
    return false;

  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
  // we iterate through each back edge of the loop with header MBB, and check
  // whether its metadata contains llvm.loop.unroll.disable.
  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
    if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
      // Edges from other loops to MBB are not back edges.
      continue;
    }
    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
      if (MDNode *LoopID =
              PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
        if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
          return true;
        if (MDNode *UnrollCountMD =
                GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
          if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
                  ->getZExtValue() == 1)
            return true;
        }
      }
    }
  }
  return false;
}

void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
  AsmPrinter::emitBasicBlockStart(MBB);
  // Tell ptxas not to unroll loops the frontend marked no-unroll.
  if (isLoopHeaderOfNoUnroll(MBB))
    OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
}

// Emit the PTX function signature (.entry/.func, return value, parameter
// list, kernel directives) and the opening brace of the body.
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Module-level globals must be emitted before the first function, since
  // ptxas does not support forward references.
  if (!GlobalsEmitted) {
    emitGlobals(*MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(*MF, O);
  }

  CurrentFnSym->print(O, MAI);

  emitFunctionParamList(*MF, O);

  if (isKernelFunction(*F))
    emitKernelFunctionDirectives(*F, O);

  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(O.str());

  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (MMI && MMI->hasDebugInfo())
    emitInitialRawDwarfLocDirective(*MF);
}

bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
  bool Result = AsmPrinter::runOnMachineFunction(F);
  // Emit closing brace for the body of function F.
  // The closing brace must be emitted here because we need to emit additional
  // debug labels/data after the last basic block.
  // We need to emit the closing brace here because we don't have function that
  // finished emission of the function body.
  OutStreamer->emitRawText(StringRef("}\n"));
  return Result;
}

void NVPTXAsmPrinter::emitFunctionBodyStart() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);
  // Declare any globals that were demoted into this function's scope.
  emitDemotedVars(&MF->getFunction(), O);
  OutStreamer->emitRawText(O.str());
}

void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}

// Name of the per-function local stack depot symbol.
const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
  SmallString<128> Str;
  raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
  return OutContext.getOrCreateSymbol(Str);
}

void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
  Register RegNo = MI->getOperand(0).getReg();
  if (RegNo.isVirtual()) {
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            getVirtualRegisterName(RegNo));
  } else {
    // Physical register: print its target name instead of a vreg name.
    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            STI.getRegisterInfo()->getName(RegNo));
  }
  OutStreamer->addBlankLine();
}

// Emit kernel launch-bound directives (.reqntid/.maxntid/.minnctapersm/
// .maxnreg) derived from NVVM metadata on F.
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of reqntid* is specified, don't output reqntid directive.
  unsigned reqntidx, reqntidy, reqntidz;
  bool specified = false;
  if (!getReqNTIDx(F, reqntidx))
    reqntidx = 1;
  else
    specified = true;
  if (!getReqNTIDy(F, reqntidy))
    reqntidy = 1;
  else
    specified = true;
  if (!getReqNTIDz(F, reqntidz))
    reqntidz = 1;
  else
    specified = true;

  if (specified)
    O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
      << "\n";

  // If the NVVM IR has some of maxntid* specified, then output
  // the maxntid directive, and set the unspecified ones to 1.
  // If none of maxntid* is specified, don't output maxntid directive.
  unsigned maxntidx, maxntidy, maxntidz;
  specified = false;
  if (!getMaxNTIDx(F, maxntidx))
    maxntidx = 1;
  else
    specified = true;
  if (!getMaxNTIDy(F, maxntidy))
    maxntidy = 1;
  else
    specified = true;
  if (!getMaxNTIDz(F, maxntidz))
    maxntidz = 1;
  else
    specified = true;

  if (specified)
    O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
      << "\n";

  unsigned mincta;
  if (getMinCTASm(F, mincta))
    O << ".minnctapersm " << mincta << "\n";

  unsigned maxnreg;
  if (getMaxNReg(F, maxnreg))
    O << ".maxnreg " << maxnreg << "\n";
}

// Build the printable name ("%<class><n>") for a virtual register from the
// per-class mapping established in setAndEmitFunctionVirtualRegisters.
std::string
NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
  const TargetRegisterClass *RC = MRI->getRegClass(Reg);

  std::string Name;
  raw_string_ostream NameStr(Name);

  VRegRCMap::const_iterator I = VRegMapping.find(RC);
  assert(I != VRegMapping.end() && "Bad register class");
  const DenseMap<unsigned, unsigned> &RegMap = I->second;

  VRegMap::const_iterator VI = RegMap.find(Reg);
  assert(VI != RegMap.end() && "Bad virtual register");
  unsigned MappedVR = VI->second;

  NameStr << getNVPTXRegClassStr(RC) << MappedVR;

  NameStr.flush();
  return Name;
}

void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
                                          raw_ostream &O) {
  O << getVirtualRegisterName(vr);
}

// Emit a PTX forward declaration (".func"/".entry" prototype) for F.
void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else
    O << ".func ";
  printReturnValStr(F, O);
  getSymbol(F)->print(O, MAI);
  O << "\n";
  emitFunctionParamList(F, O);
  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";
  O << ";\n";
}

// Return true if C is (transitively) used in the initializer of some global
// variable other than llvm.used.
static bool usedInGlobalVarDef(const Constant *C) {
  if (!C)
    return false;

  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
    return GV->getName() != "llvm.used";
  }

  for (const User *U : C->users())
    if (const Constant *C = dyn_cast<Constant>(U))
      if (usedInGlobalVarDef(C))
        return true;

  return false;
}

// Return true if all uses reachable from U are confined to a single function;
// that function is returned through oneFunc. Uses inside llvm.used are
// ignored.
static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
    if (othergv->getName() == "llvm.used")
      return true;
  }

  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
    if (instr->getParent() && instr->getParent()->getParent()) {
      const Function *curFunc = instr->getParent()->getParent();
      // A second distinct function disqualifies the variable.
      if (oneFunc && (curFunc != oneFunc))
        return false;
      oneFunc = curFunc;
      return true;
    } else
      return false;
  }

  for (const User *UU : U->users())
    if (!usedInOneFunc(UU, oneFunc))
      return false;

  return true;
}

/* Find out if a global variable can be demoted to local scope.
 * Currently, this is valid for CUDA shared variables, which have local
 * scope and global lifetime. So the conditions to check are :
 * 1. Is the global variable in shared address space?
 * 2. Does it have internal linkage?
 * 3. Is the global variable referenced only in one function?
 */
static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
  if (!gv->hasInternalLinkage())
    return false;
  PointerType *Pty = gv->getType();
  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
    return false;

  const Function *oneFunc = nullptr;

  bool flag = usedInOneFunc(gv, oneFunc);
  if (!flag)
    return false;
  if (!oneFunc)
    return false;
  f = oneFunc;
  return true;
}

// Return true if any (transitive) use of C is inside a function already
// recorded in seenMap.
static bool useFuncSeen(const Constant *C,
                        DenseMap<const Function *, bool> &seenMap) {
  for (const User *U : C->users()) {
    if (const Constant *cu = dyn_cast<Constant>(U)) {
      if (useFuncSeen(cu, seenMap))
        return true;
    } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
      const BasicBlock *bb = I->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;
      if (seenMap.find(caller) != seenMap.end())
        return true;
    }
  }
  return false;
}

// Emit forward declarations for functions that are referenced before their
// definitions appear in the module (ptxas has no implicit forward refs).
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  DenseMap<const Function *, bool> seenMap;
  for (const Function &F : M) {
    if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
      emitDeclaration(&F, O);
      continue;
    }

    if (F.isDeclaration()) {
      // External declarations need a prototype only if actually used;
      // intrinsics are lowered elsewhere and never declared.
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(&F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(&F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, seenMap)) {
          emitDeclaration(&F, O);
          break;
        }
      }

      if (!isa<Instruction>(U))
        continue;
      const Instruction *instr = cast<Instruction>(U);
      const BasicBlock *bb = instr->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (seenMap.find(caller) != seenMap.end()) {
        emitDeclaration(&F, O);
        break;
      }
    }
    seenMap[&F] = true;
  }
}

// Return true if GV is a trivially-empty llvm.global_ctors/dtors list.
static bool isEmptyXXStructor(GlobalVariable *GV) {
  if (!GV) return true;
  const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
  if (!InitList) return true;  // Not an array; we don't know how to parse.
  return InitList->getNumOperands() == 0;
}

void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
  // Construct a default subtarget off of the TargetMachine defaults. The
  // rest of NVPTX isn't friendly to change subtargets per function and
  // so the default TargetMachine will have all of the options.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  SmallString<128> Str1;
  raw_svector_ostream OS1(Str1);

  // Emit header before any dwarf directives are emitted below.
  emitHeader(M, OS1, *STI);
  OutStreamer->emitRawText(OS1.str());
}

// Reject module constructs PTX cannot represent before delegating to the
// base-class initialization.
bool NVPTXAsmPrinter::doInitialization(Module &M) {
  if (M.alias_size()) {
    report_fatal_error("Module has aliases, which NVPTX does not support.");
    return true; // error
  }
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
    report_fatal_error(
        "Module has a nontrivial global ctor, which NVPTX does not support.");
    return true; // error
  }
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
    report_fatal_error(
        "Module has a nontrivial global dtor, which NVPTX does not support.");
    return true; // error
  }

  // We need to call the parent's one explicitly.
  bool Result = AsmPrinter::doInitialization(M);

  GlobalsEmitted = false;

  return Result;
}

void NVPTXAsmPrinter::emitGlobals(const Module &M) {
  SmallString<128> Str2;
  raw_svector_ostream OS2(Str2);

  emitDeclarations(M, OS2);

  // As ptxas does not support forward references of globals, we need to first
  // sort the list of module-level globals in def-use order. We visit each
  // global variable in order, and ensure that we emit it *after* its dependent
  // globals. We use a little extra memory maintaining both a set and a list to
  // have fast searches while maintaining a strict ordering.
  SmallVector<const GlobalVariable *, 8> Globals;
  DenseSet<const GlobalVariable *> GVVisited;
  DenseSet<const GlobalVariable *> GVVisiting;

  // Visit each global variable, in order
  for (const GlobalVariable &I : M.globals())
    VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);

  assert(GVVisited.size() == M.getGlobalList().size() &&
         "Missed a global variable");
  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  // Print out module-level global variables in proper order
  for (unsigned i = 0, e = Globals.size(); i != e; ++i)
    printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);

  OS2 << '\n';

  OutStreamer->emitRawText(OS2.str());
}

// Emit the PTX module header: banner comment plus the .version, .target and
// .address_size directives.
void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                 const NVPTXSubtarget &STI) {
  O << "//\n";
  O << "// Generated by LLVM NVPTX Back-End\n";
  O << "//\n";
  O << "\n";

  unsigned PTXVersion = STI.getPTXVersion();
  // PTXVersion is stored as major*10+minor, e.g. 70 -> "7.0".
  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";

  O << ".target ";
  O << STI.getTargetName();

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  if (NTM.getDrvInterface() == NVPTX::NVCL)
    O << ", texmode_independent";

  // ", debug" is only emitted when at least one compile unit carries full
  // debug info (not just debug directives).
  bool HasFullDebugInfo = false;
  for (DICompileUnit *CU : M.debug_compile_units()) {
    switch (CU->getEmissionKind()) {
    case DICompileUnit::NoDebug:
    case DICompileUnit::DebugDirectivesOnly:
      break;
    case DICompileUnit::LineTablesOnly:
    case DICompileUnit::FullDebug:
      HasFullDebugInfo = true;
      break;
    }
    if (HasFullDebugInfo)
      break;
  }
  if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
    O << ", debug";

  O << "\n";

  O << ".address_size ";
  if (NTM.is64Bit())
    O << "64";
  else
    O << "32";
  O << "\n";

  O << "\n";
}

bool NVPTXAsmPrinter::doFinalization(Module &M) {
  bool HasDebugInfo = MMI && MMI->hasDebugInfo();

  // If we did not emit any functions, then the global declarations have not
  // yet been emitted.
  if (!GlobalsEmitted) {
    emitGlobals(M);
    GlobalsEmitted = true;
  }

  // call doFinalization
  bool ret = AsmPrinter::doFinalization(M);

  clearAnnotationCache(&M);

  auto *TS =
      static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
  // Close the last emitted section
  if (HasDebugInfo) {
    TS->closeLastSection();
    // Emit empty .debug_loc section for better support of the empty files.
    OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
  }

  // Output last DWARF .file directives, if any.
  TS->outputDwarfFileDirectives();

  return ret;
}

// This function emits appropriate linkage directives for
// functions and global variables.
924 // 925 // extern function declaration -> .extern 926 // extern function definition -> .visible 927 // external global variable with init -> .visible 928 // external without init -> .extern 929 // appending -> not allowed, assert. 930 // for any linkage other than 931 // internal, private, linker_private, 932 // linker_private_weak, linker_private_weak_def_auto, 933 // we emit -> .weak. 934 935 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, 936 raw_ostream &O) { 937 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) { 938 if (V->hasExternalLinkage()) { 939 if (isa<GlobalVariable>(V)) { 940 const GlobalVariable *GVar = cast<GlobalVariable>(V); 941 if (GVar) { 942 if (GVar->hasInitializer()) 943 O << ".visible "; 944 else 945 O << ".extern "; 946 } 947 } else if (V->isDeclaration()) 948 O << ".extern "; 949 else 950 O << ".visible "; 951 } else if (V->hasAppendingLinkage()) { 952 std::string msg; 953 msg.append("Error: "); 954 msg.append("Symbol "); 955 if (V->hasName()) 956 msg.append(std::string(V->getName())); 957 msg.append("has unsupported appending linkage type"); 958 llvm_unreachable(msg.c_str()); 959 } else if (!V->hasInternalLinkage() && 960 !V->hasPrivateLinkage()) { 961 O << ".weak "; 962 } 963 } 964 } 965 966 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, 967 raw_ostream &O, bool processDemoted, 968 const NVPTXSubtarget &STI) { 969 // Skip meta data 970 if (GVar->hasSection()) { 971 if (GVar->getSection() == "llvm.metadata") 972 return; 973 } 974 975 // Skip LLVM intrinsic global variables 976 if (GVar->getName().startswith("llvm.") || 977 GVar->getName().startswith("nvvm.")) 978 return; 979 980 const DataLayout &DL = getDataLayout(); 981 982 // GlobalVariables are always constant pointers themselves. 
  PointerType *PTy = GVar->getType();
  Type *ETy = GVar->getValueType();

  // Emit the linkage directive for the variable itself.
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  // Textures and surfaces are emitted as opaque references, no initializer.
  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    // A sampler's initializer, when present, is a ConstantInt whose bits
    // encode the OpenCL sampler state (see cl_common_defines.h masks).
    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      // Emit the same decoded addressing mode for all three dimensions.
      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  if (GVar->hasPrivateLinkage()) {
    // Frontend-generated helper globals that must not appear in PTX output.
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  // If this global can be demoted into the single function that uses it,
  // record it in localDecls and emit it later inside that function instead.
  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    if (localDecls.find(demotedFunc) != localDecls.end())
      localDecls[demotedFunc].push_back(GVar);
    else {
      std::vector<const GlobalVariable *> temp;
      temp.push_back(GVar);
      localDecls[demotedFunc] = temp;
    }
    return;
  }

  O << ".";
  emitPTXAddressSpace(PTy->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // Ptx allows variable initilization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(PTy->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    unsigned int ElementSize = 0;

    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // Ptx allows variable initilization only for constant and
      // global state spaces.
      if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          // The initializer contains data (and possibly symbol addresses);
          // buffer it and decide between byte-wise and word-wise emission.
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              // All symbols are pointer-aligned: emit pointer-sized words.
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            // Pure data, no symbols: emit as raw bytes.
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          // Zero/undef initializer: size-only declaration.
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        // No initializer allowed/present in this state space.
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}

/// Print the symbol at index \p nSym of the aggregate buffer to \p os,
/// wrapping it in generic() when a generic-address-space pointer is required.
void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
  const Value *v = Symbols[nSym];
  const Value *v0 = SymbolsBeforeStripping[nSym];
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
    MCSymbol *Name =
        AP.getSymbol(GVar);
    PointerType *PTy = dyn_cast<PointerType>(v0->getType());
    // Is v0 a generic pointer?
    bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
    if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
      os << "generic(";
      Name->print(os, AP.MAI);
      os << ")";
    } else {
      Name->print(os, AP.MAI);
    }
  } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
    AP.printMCExpr(*Expr, os);
  } else
    llvm_unreachable("symbol type unknown");
}

/// Emit the buffered initializer as a comma-separated list of bytes.
/// Plain data bytes are printed literally; each recorded symbol position is
/// expanded into ptrSize mask() expressions, one per byte of the address.
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Sentinel so the loop below never reads past the last real symbol.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < size;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}

/// Emit the buffered initializer as pointer-sized little-endian words.
/// Only valid when every symbol in the buffer is pointer-aligned (checked by
/// the caller via allSymbolsAligned and asserted again here).
void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Sentinel so the loop below never reads past the last real symbol.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  assert(nextSymbolPos % ptrSize == 0);
  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
    if (pos)
      os << ", ";
    if (pos == nextSymbolPos) {
      printSymbol(nSym, os);
      nextSymbolPos = symbolPosInBuffer[++nSym];
      assert(nextSymbolPos % ptrSize == 0);
      assert(nextSymbolPos >= pos + ptrSize);
    } else if (ptrSize == 4)
      os << support::endian::read32le(&buffer[pos]);
    else
      os << support::endian::read64le(&buffer[pos]);
  }
}

/// Emit declarations for all globals that were demoted into function \p f
/// (recorded in localDecls by printModuleLevelGV).
void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
  if (localDecls.find(f) == localDecls.end())
    return;

  std::vector<const GlobalVariable *> &gvars = localDecls[f];

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  for (const GlobalVariable *GV : gvars) {
    O << "\t// demoted variable\n\t";
    printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
  }
}

/// Print the PTX state-space keyword for \p AddressSpace to \p O.
void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
                                          raw_ostream &O) const {
  switch (AddressSpace) {
  case ADDRESS_SPACE_LOCAL:
    O << "local";
    break;
  case ADDRESS_SPACE_GLOBAL:
    O << "global";
    break;
  case ADDRESS_SPACE_CONST:
    O << "const";
    break;
  case ADDRESS_SPACE_SHARED:
    O << "shared";
    break;
  default:
    report_fatal_error("Bad address space found while emitting PTX: " +
                       llvm::Twine(AddressSpace));
    break;
  }
}

/// Map an LLVM scalar type onto its PTX fundamental type string
/// (e.g. i32 -> "u32", float -> "f32"). When \p useB4PTR is true,
/// pointers are rendered as untyped bits ("b32"/"b64") instead of "u32"/"u64".
std::string
NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
  switch (Ty->getTypeID()) {
  case Type::IntegerTyID: {
    unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
    if (NumBits == 1)
      return "pred";
    else if (NumBits <= 64) {
      std::string name = "u";
      return name + utostr(NumBits);
    } else {
      // Wider integers are handled by the aggregate (.b8 array) path.
      llvm_unreachable("Integer too large");
      break;
    }
    break;
  }
  case Type::HalfTyID:
    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
    return "b16";
  case Type::FloatTyID:
    return "f32";
  case Type::DoubleTyID:
    return "f64";
  case Type::PointerTyID: {
    unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
    assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");

    if (PtrSize == 64)
      if (useB4PTR)
        return "b64";
      else
        return "u64";
    else if (useB4PTR)
      return "b32";
    else
      return "u32";
  }
  default:
    break;
  }
  llvm_unreachable("unexpected type");
}

/// Emit a PTX variable declaration (state space, alignment, type and name —
/// but no initializer and no trailing ';') for \p GVar to \p O.
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Special case for i128
  if (ETy->isIntegerTy(128)) {
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[16]";
    return;
  }

  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " .";
    O << getPTXFundamentalTypeStr(ETy);
    O << " ";
    getSymbol(GVar)->print(O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct type and array type and LLVM IR
  // is very similar to PTX, the LLVM CodeGen does not support for targets that
  // support these high level field accesses. Structs and arrays are lowered
  // into arrays of bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(ETy);
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}

/// Print the canonical name of parameter \p paramIndex of the function that
/// owns argument iterator \p I: "<mangled function name>_param_<index>".
void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
                                     int paramIndex, raw_ostream &O) {
  getSymbol(I->getParent())->print(O, MAI);
  O << "_param_" << paramIndex;
}

/// Emit the parenthesized PTX parameter list for function \p F, covering
/// images/samplers, byval aggregates, kernel pointer parameters and plain
/// scalars, for both the ABI (.param) and pre-ABI (.reg) conventions.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  // .param-based ABI is used on sm_20 and newer; older targets use .reg.
  bool isABI = (STI.getSmVersion() >= 20);
  bool hasImageHandles = STI.hasImageHandles();

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()\n";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunction(*F)) {
      if (isSampler(*I) || isImage(*I)) {
        if (isImage(*I)) {
          std::string sname = std::string(I->getName());
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            // Writable images are surfaces; with image handles they are
            // passed as a .u64 pointer to the surfref.
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            CurrentFnSym->print(O, MAI);
            O <<
"_param_" << paramIndex; 1491 } 1492 else { // Default image is read_only 1493 if (hasImageHandles) 1494 O << "\t.param .u64 .ptr .texref "; 1495 else 1496 O << "\t.param .texref "; 1497 CurrentFnSym->print(O, MAI); 1498 O << "_param_" << paramIndex; 1499 } 1500 } else { 1501 if (hasImageHandles) 1502 O << "\t.param .u64 .ptr .samplerref "; 1503 else 1504 O << "\t.param .samplerref "; 1505 CurrentFnSym->print(O, MAI); 1506 O << "_param_" << paramIndex; 1507 } 1508 continue; 1509 } 1510 } 1511 1512 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F, 1513 paramIndex](Type *Ty) -> Align { 1514 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL); 1515 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex); 1516 return std::max(TypeAlign, ParamAlign.valueOrOne()); 1517 }; 1518 1519 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) { 1520 if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) { 1521 // Just print .param .align <a> .b8 .param[size]; 1522 // <a> = optimal alignment for the element type; always multiple of 1523 // PAL.getParamAlignment 1524 // size = typeallocsize of element type 1525 Align OptimalAlign = getOptimalAlignForParam(Ty); 1526 1527 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1528 printParamName(I, paramIndex, O); 1529 O << "[" << DL.getTypeAllocSize(Ty) << "]"; 1530 1531 continue; 1532 } 1533 // Just a scalar 1534 auto *PTy = dyn_cast<PointerType>(Ty); 1535 unsigned PTySizeInBits = 0; 1536 if (PTy) { 1537 PTySizeInBits = 1538 TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits(); 1539 assert(PTySizeInBits && "Invalid pointer size"); 1540 } 1541 1542 if (isKernelFunc) { 1543 if (PTy) { 1544 // Special handling for pointer arguments to kernel 1545 O << "\t.param .u" << PTySizeInBits << " "; 1546 1547 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() != 1548 NVPTX::CUDA) { 1549 int addrSpace = PTy->getAddressSpace(); 1550 switch (addrSpace) { 1551 default: 1552 O << ".ptr "; 1553 
break; 1554 case ADDRESS_SPACE_CONST: 1555 O << ".ptr .const "; 1556 break; 1557 case ADDRESS_SPACE_SHARED: 1558 O << ".ptr .shared "; 1559 break; 1560 case ADDRESS_SPACE_GLOBAL: 1561 O << ".ptr .global "; 1562 break; 1563 } 1564 Align ParamAlign = I->getParamAlign().valueOrOne(); 1565 O << ".align " << ParamAlign.value() << " "; 1566 } 1567 printParamName(I, paramIndex, O); 1568 continue; 1569 } 1570 1571 // non-pointer scalar to kernel func 1572 O << "\t.param ."; 1573 // Special case: predicate operands become .u8 types 1574 if (Ty->isIntegerTy(1)) 1575 O << "u8"; 1576 else 1577 O << getPTXFundamentalTypeStr(Ty); 1578 O << " "; 1579 printParamName(I, paramIndex, O); 1580 continue; 1581 } 1582 // Non-kernel function, just print .param .b<size> for ABI 1583 // and .reg .b<size> for non-ABI 1584 unsigned sz = 0; 1585 if (isa<IntegerType>(Ty)) { 1586 sz = cast<IntegerType>(Ty)->getBitWidth(); 1587 sz = promoteScalarArgumentSize(sz); 1588 } else if (PTy) { 1589 assert(PTySizeInBits && "Invalid pointer size"); 1590 sz = PTySizeInBits; 1591 } else if (Ty->isHalfTy()) 1592 // PTX ABI requires all scalar parameters to be at least 32 1593 // bits in size. fp16 normally uses .b16 as its storage type 1594 // in PTX, so its size must be adjusted here, too. 1595 sz = 32; 1596 else 1597 sz = Ty->getPrimitiveSizeInBits(); 1598 if (isABI) 1599 O << "\t.param .b" << sz << " "; 1600 else 1601 O << "\t.reg .b" << sz << " "; 1602 printParamName(I, paramIndex, O); 1603 continue; 1604 } 1605 1606 // param has byVal attribute. 1607 Type *ETy = PAL.getParamByValType(paramIndex); 1608 assert(ETy && "Param should have byval type"); 1609 1610 if (isABI || isKernelFunc) { 1611 // Just print .param .align <a> .b8 .param[size]; 1612 // <a> = optimal alignment for the element type; always multiple of 1613 // PAL.getParamAlignment 1614 // size = typeallocsize of element type 1615 Align OptimalAlign = 1616 isKernelFunc 1617 ? 
getOptimalAlignForParam(ETy) 1618 : TLI->getFunctionByValParamAlign( 1619 F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL); 1620 1621 unsigned sz = DL.getTypeAllocSize(ETy); 1622 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1623 printParamName(I, paramIndex, O); 1624 O << "[" << sz << "]"; 1625 continue; 1626 } else { 1627 // Split the ETy into constituent parts and 1628 // print .param .b<size> <name> for each part. 1629 // Further, if a part is vector, print the above for 1630 // each vector element. 1631 SmallVector<EVT, 16> vtparts; 1632 ComputeValueVTs(*TLI, DL, ETy, vtparts); 1633 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 1634 unsigned elems = 1; 1635 EVT elemtype = vtparts[i]; 1636 if (vtparts[i].isVector()) { 1637 elems = vtparts[i].getVectorNumElements(); 1638 elemtype = vtparts[i].getVectorElementType(); 1639 } 1640 1641 for (unsigned j = 0, je = elems; j != je; ++j) { 1642 unsigned sz = elemtype.getSizeInBits(); 1643 if (elemtype.isInteger()) 1644 sz = promoteScalarArgumentSize(sz); 1645 O << "\t.reg .b" << sz << " "; 1646 printParamName(I, paramIndex, O); 1647 if (j < je - 1) 1648 O << ",\n"; 1649 ++paramIndex; 1650 } 1651 if (i < e - 1) 1652 O << ",\n"; 1653 } 1654 --paramIndex; 1655 continue; 1656 } 1657 } 1658 1659 if (F->isVarArg()) { 1660 if (!first) 1661 O << ",\n"; 1662 O << "\t.param .align " << STI.getMaxRequiredAlignment(); 1663 O << " .b8 "; 1664 getSymbol(F)->print(O, MAI); 1665 O << "_vararg[]"; 1666 } 1667 1668 O << "\n)\n"; 1669 } 1670 1671 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF, 1672 raw_ostream &O) { 1673 const Function &F = MF.getFunction(); 1674 emitFunctionParamList(&F, O); 1675 } 1676 1677 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( 1678 const MachineFunction &MF) { 1679 SmallString<128> Str; 1680 raw_svector_ostream O(Str); 1681 1682 // Map the global virtual register number to a register class specific 1683 // virtual register number starting from 
1 with that class. 1684 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); 1685 //unsigned numRegClasses = TRI->getNumRegClasses(); 1686 1687 // Emit the Fake Stack Object 1688 const MachineFrameInfo &MFI = MF.getFrameInfo(); 1689 int NumBytes = (int) MFI.getStackSize(); 1690 if (NumBytes) { 1691 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t" 1692 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n"; 1693 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) { 1694 O << "\t.reg .b64 \t%SP;\n"; 1695 O << "\t.reg .b64 \t%SPL;\n"; 1696 } else { 1697 O << "\t.reg .b32 \t%SP;\n"; 1698 O << "\t.reg .b32 \t%SPL;\n"; 1699 } 1700 } 1701 1702 // Go through all virtual registers to establish the mapping between the 1703 // global virtual 1704 // register number and the per class virtual register number. 1705 // We use the per class virtual register number in the ptx output. 1706 unsigned int numVRs = MRI->getNumVirtRegs(); 1707 for (unsigned i = 0; i < numVRs; i++) { 1708 Register vr = Register::index2VirtReg(i); 1709 const TargetRegisterClass *RC = MRI->getRegClass(vr); 1710 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1711 int n = regmap.size(); 1712 regmap.insert(std::make_pair(vr, n + 1)); 1713 } 1714 1715 // Emit register declarations 1716 // @TODO: Extract out the real register usage 1717 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; 1718 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; 1719 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; 1720 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; 1721 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n"; 1722 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; 1723 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n"; 1724 1725 // Emit declaration of the virtual registers or 'physical' registers for 1726 // each register class 1727 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) { 1728 
    const TargetRegisterClass *RC = TRI->getRegClass(i);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    std::string rcname = getNVPTXRegClassName(RC);
    std::string rcStr = getNVPTXRegClassStr(RC);
    int n = regmap.size();

    // Only declare those registers that may be used.
    if (n) {
      O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
        << ">;\n";
    }
  }

  OutStreamer->emitRawText(O.str());
}

/// Print \p Fp in PTX's hexadecimal float syntax: "0f" + 8 hex digits for
/// float, "0d" + 16 hex digits for double (raw IEEE bit pattern).
void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
  APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
  bool ignored;
  unsigned int numHex;
  const char *lead;

  if (Fp->getType()->getTypeID() == Type::FloatTyID) {
    numHex = 8;
    lead = "0f";
    APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
  } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
    numHex = 16;
    lead = "0d";
    APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
  } else
    llvm_unreachable("unsupported fp type");

  APInt API = APF.bitcastToAPInt();
  O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
}

/// Print a scalar constant initializer (integer, FP, null pointer, global
/// address, or a lowerable constant expression) in PTX syntax.
void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    O << CI->getValue();
    return;
  }
  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
    printFPConstant(CFP, O);
    return;
  }
  if (isa<ConstantPointerNull>(CPV)) {
    O << "0";
    return;
  }
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
    bool IsNonGenericPointer = false;
    if (GVar->getType()->getAddressSpace() != 0) {
      IsNonGenericPointer = true;
    }
    // Generic-address-space globals (except functions) must be wrapped in
    // generic() so the address is converted to the generic space.
    if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
      O << "generic(";
      getSymbol(GVar)->print(O, MAI);
      O << ")";
    } else {
      getSymbol(GVar)->print(O, MAI);
    }
    return;
  }
  if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
    const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
    printMCExpr(*E, O);
    return;
  }
  llvm_unreachable("Not scalar type found in printScalarConstant()");
}

/// Serialize constant \p CPV into \p AggBuffer in little-endian byte order.
/// \p Bytes, when non-zero, is the total space this constant must occupy
/// (used for struct padding); zero means "use the type's alloc size".
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(CPV->getType());
  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
    // only the space allocated by CPV.
    AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    for (unsigned I = 0; I < NumBytes; ++I) {
      Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
    }
    AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      // Try constant folding first; otherwise a ptrtoint of a symbol is
      // recorded as a symbol reference plus zero placeholder bytes.
      if (const auto *CI =
              dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        Value *V = Cexpr->getOperand(0)->stripPointerCasts();
        AggBuffer->addSymbol(V, Cexpr->getOperand(0));
        AggBuffer->addZeros(AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::BFloatTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    // Floats are buffered via their raw IEEE bit pattern.
    AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    // Pointers become a recorded symbol plus zero placeholder bytes that
    // printBytes/printWords later replace with the address expression.
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
      AggBuffer->addSymbol(GVar, GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(v, Cexpr);
    }
    AggBuffer->addZeros(AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
      bufferAggregateConstant(CPV, AggBuffer);
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(CPV))
      AggBuffer->addZeros(Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}

/// Serialize an aggregate (or wide-integer) constant into \p aggBuffer,
/// element by element, accounting for struct layout padding.
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  const DataLayout &DL = getDataLayout();
  int Bytes;

  // Integers of arbitrary width
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    APInt Val = CI->getValue();
    // Emit one byte at a time, least-significant first.
    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
      uint8_t Byte = Val.getLoBits(8).getZExtValue();
      aggBuffer->addBytes(&Byte, 1, 1);
      Val.lshrInPlace(8);
    }
    return;
  }

  // Old constants
  if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
    if (CPV->getNumOperands())
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
    return;
  }

  if (const ConstantDataSequential *CDS =
          dyn_cast<ConstantDataSequential>(CPV)) {
    if (CDS->getNumElements())
      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
                     aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(CPV->getType());
      // Each field is padded out to the offset of the next field; the last
      // field is padded to the end of the struct's alloc size.
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
        if (i == (e - 1))
          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
                  DL.getTypeAllocSize(ST) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        else
          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in printAggregateConstant()");
}

/// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
  MCContext &Ctx = OutContext;

  if (CV->isNullValue() || isa<UndefValue>(CV))
    return MCConstantExpr::create(0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
    return MCConstantExpr::create(CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
    const MCSymbolRefExpr *Expr =
        MCSymbolRefExpr::create(getSymbol(GV), Ctx);
    // Inside an addrspacecast-to-generic, wrap the symbol so it prints as
    // generic(sym).
    if (ProcessingGeneric) {
      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
    } else {
      return Expr;
    }
  }

  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default: {
    // If the code isn't optimized, there may be outstanding folding
    // opportunities. Attempt to fold the expression using DataLayout as a
    // last resort before giving up.
    Constant *C = ConstantFoldConstant(CE, getDataLayout());
    if (C != CE)
      return lowerConstantForGV(C, ProcessingGeneric);

    // Otherwise report the problem to the user.
    std::string S;
    raw_string_ostream OS(S);
    OS << "Unsupported expression in static initializer: ";
    CE->printAsOperand(OS, /*PrintType=*/false,
                       !MF ? nullptr : MF->getFunction().getParent());
    report_fatal_error(Twine(OS.str()));
  }

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand
    PointerType *DstTy = cast<PointerType>(CE->getType());
    if (DstTy->getAddressSpace() == 0) {
      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
    }
    // Only casts INTO the generic address space are representable here.
    std::string S;
    raw_string_ostream OS(S);
    OS << "Unsupported expression in static initializer: ";
    CE->printAsOperand(OS, /*PrintType=*/ false,
                       !MF ? nullptr : MF->getFunction().getParent());
    report_fatal_error(Twine(OS.str()));
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
                                            ProcessingGeneric);
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly.  This is important for differences between
    // blockaddress labels.  Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type.  This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(0);
    Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
                                      false/*ZExt*/);
    return lowerConstantForGV(Op, ProcessingGeneric);
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }
}

// Copy of MCExpr::print customized for NVPTX
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
  switch (Expr.getKind()) {
  case MCExpr::Target:
    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
  case MCExpr::Constant:
    OS << cast<MCConstantExpr>(Expr).getValue();
    return;

  case MCExpr::SymbolRef: {
    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &Sym = SRE.getSymbol();
    Sym.print(OS, MAI);
    return;
  }

  case MCExpr::Unary: {
    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
    switch (UE.getOpcode()) {
    case MCUnaryExpr::LNot:  OS << '!'; break;
    case MCUnaryExpr::Minus: OS << '-'; break;
    case MCUnaryExpr::Not:   OS << '~'; break;
    case MCUnaryExpr::Plus:  OS << '+'; break;
    }
    printMCExpr(*UE.getSubExpr(), OS);
    return;
  }

  case MCExpr::Binary: {
    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);

    // Only print parens around the LHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
      printMCExpr(*BE.getLHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getLHS(), OS);
      OS<< ')';
    }

    switch (BE.getOpcode()) {
    case MCBinaryExpr::Add:
      // Print "X-42" instead of "X+-42".
      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
        if (RHSC->getValue() < 0) {
          OS << RHSC->getValue();
          return;
        }
      }

      OS << '+';
      break;
    default: llvm_unreachable("Unhandled binary operator");
    }

    // Only print parens around the RHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
      printMCExpr(*BE.getRHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getRHS(), OS);
      OS << ')';
    }
    return;
  }
  }

  llvm_unreachable("Invalid expression kind!");
}

/// PrintAsmOperand - Print out an operand for an inline asm expression.
///
bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                                      const char *ExtraCode, raw_ostream &O) {
  if (ExtraCode && ExtraCode[0]) {
    if (ExtraCode[1] != 0)
      return true; // Unknown modifier.
2144 2145 switch (ExtraCode[0]) { 2146 default: 2147 // See if this is a generic print operand 2148 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O); 2149 case 'r': 2150 break; 2151 } 2152 } 2153 2154 printOperand(MI, OpNo, O); 2155 2156 return false; 2157 } 2158 2159 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI, 2160 unsigned OpNo, 2161 const char *ExtraCode, 2162 raw_ostream &O) { 2163 if (ExtraCode && ExtraCode[0]) 2164 return true; // Unknown modifier 2165 2166 O << '['; 2167 printMemOperand(MI, OpNo, O); 2168 O << ']'; 2169 2170 return false; 2171 } 2172 2173 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, 2174 raw_ostream &O) { 2175 const MachineOperand &MO = MI->getOperand(opNum); 2176 switch (MO.getType()) { 2177 case MachineOperand::MO_Register: 2178 if (MO.getReg().isPhysical()) { 2179 if (MO.getReg() == NVPTX::VRDepot) 2180 O << DEPOTNAME << getFunctionNumber(); 2181 else 2182 O << NVPTXInstPrinter::getRegisterName(MO.getReg()); 2183 } else { 2184 emitVirtualRegister(MO.getReg(), O); 2185 } 2186 break; 2187 2188 case MachineOperand::MO_Immediate: 2189 O << MO.getImm(); 2190 break; 2191 2192 case MachineOperand::MO_FPImmediate: 2193 printFPConstant(MO.getFPImm(), O); 2194 break; 2195 2196 case MachineOperand::MO_GlobalAddress: 2197 PrintSymbolOperand(MO, O); 2198 break; 2199 2200 case MachineOperand::MO_MachineBasicBlock: 2201 MO.getMBB()->getSymbol()->print(O, MAI); 2202 break; 2203 2204 default: 2205 llvm_unreachable("Operand type not supported."); 2206 } 2207 } 2208 2209 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum, 2210 raw_ostream &O, const char *Modifier) { 2211 printOperand(MI, opNum, O); 2212 2213 if (Modifier && strcmp(Modifier, "add") == 0) { 2214 O << ", "; 2215 printOperand(MI, opNum + 1, O); 2216 } else { 2217 if (MI->getOperand(opNum + 1).isImm() && 2218 MI->getOperand(opNum + 1).getImm() == 0) 2219 return; // don't print ',0' or '+0' 2220 O << "+"; 2221 
printOperand(MI, opNum + 1, O); 2222 } 2223 } 2224 2225 // Force static initialization. 2226 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() { 2227 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32()); 2228 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64()); 2229 } 2230