1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "NVPTXAsmPrinter.h" 15 #include "MCTargetDesc/NVPTXBaseInfo.h" 16 #include "MCTargetDesc/NVPTXInstPrinter.h" 17 #include "MCTargetDesc/NVPTXMCAsmInfo.h" 18 #include "MCTargetDesc/NVPTXTargetStreamer.h" 19 #include "NVPTX.h" 20 #include "NVPTXMCExpr.h" 21 #include "NVPTXMachineFunctionInfo.h" 22 #include "NVPTXRegisterInfo.h" 23 #include "NVPTXSubtarget.h" 24 #include "NVPTXTargetMachine.h" 25 #include "NVPTXUtilities.h" 26 #include "TargetInfo/NVPTXTargetInfo.h" 27 #include "cl_common_defines.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/DenseSet.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/StringExtras.h" 35 #include "llvm/ADT/StringRef.h" 36 #include "llvm/ADT/Triple.h" 37 #include "llvm/ADT/Twine.h" 38 #include "llvm/Analysis/ConstantFolding.h" 39 #include "llvm/CodeGen/Analysis.h" 40 #include "llvm/CodeGen/MachineBasicBlock.h" 41 #include "llvm/CodeGen/MachineFrameInfo.h" 42 #include "llvm/CodeGen/MachineFunction.h" 43 #include "llvm/CodeGen/MachineInstr.h" 44 #include "llvm/CodeGen/MachineLoopInfo.h" 45 #include "llvm/CodeGen/MachineModuleInfo.h" 46 #include "llvm/CodeGen/MachineOperand.h" 47 #include "llvm/CodeGen/MachineRegisterInfo.h" 48 #include "llvm/CodeGen/TargetRegisterInfo.h" 49 #include "llvm/CodeGen/ValueTypes.h" 50 #include "llvm/IR/Attributes.h" 51 #include "llvm/IR/BasicBlock.h" 52 #include "llvm/IR/Constant.h" 53 #include "llvm/IR/Constants.h" 54 #include "llvm/IR/DataLayout.h" 55 #include "llvm/IR/DebugInfo.h" 56 #include "llvm/IR/DebugInfoMetadata.h" 57 #include "llvm/IR/DebugLoc.h" 58 #include "llvm/IR/DerivedTypes.h" 59 #include "llvm/IR/Function.h" 60 #include "llvm/IR/GlobalValue.h" 61 #include "llvm/IR/GlobalVariable.h" 62 #include "llvm/IR/Instruction.h" 63 #include "llvm/IR/LLVMContext.h" 64 #include "llvm/IR/Module.h" 65 #include "llvm/IR/Operator.h" 66 #include "llvm/IR/Type.h" 67 #include "llvm/IR/User.h" 68 #include "llvm/MC/MCExpr.h" 69 #include "llvm/MC/MCInst.h" 70 #include "llvm/MC/MCInstrDesc.h" 71 #include "llvm/MC/MCStreamer.h" 72 #include "llvm/MC/MCSymbol.h" 73 #include "llvm/MC/TargetRegistry.h" 74 #include "llvm/Support/Casting.h" 75 #include "llvm/Support/CommandLine.h" 76 #include "llvm/Support/Endian.h" 77 #include "llvm/Support/ErrorHandling.h" 78 #include "llvm/Support/MachineValueType.h" 79 #include "llvm/Support/NativeFormatting.h" 80 #include "llvm/Support/Path.h" 81 #include "llvm/Support/raw_ostream.h" 82 #include "llvm/Target/TargetLoweringObjectFile.h" 83 #include "llvm/Target/TargetMachine.h" 84 #include "llvm/Transforms/Utils/UnrollLoop.h" 85 #include <cassert> 86 #include <cstdint> 87 #include <cstring> 88 #include <new> 89 #include <string> 90 #include <utility> 91 #include <vector> 92 93 using namespace llvm; 94 95 #define DEPOTNAME "__local_depot" 96 97 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V 98 /// depends. 99 static void 100 DiscoverDependentGlobals(const Value *V, 101 DenseSet<const GlobalVariable *> &Globals) { 102 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 103 Globals.insert(GV); 104 else { 105 if (const User *U = dyn_cast<User>(V)) { 106 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) { 107 DiscoverDependentGlobals(U->getOperand(i), Globals); 108 } 109 } 110 } 111 } 112 113 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable 114 /// instances to be emitted, but only after any dependents have been added 115 /// first.s 116 static void 117 VisitGlobalVariableForEmission(const GlobalVariable *GV, 118 SmallVectorImpl<const GlobalVariable *> &Order, 119 DenseSet<const GlobalVariable *> &Visited, 120 DenseSet<const GlobalVariable *> &Visiting) { 121 // Have we already visited this one? 122 if (Visited.count(GV)) 123 return; 124 125 // Do we have a circular dependency? 126 if (!Visiting.insert(GV).second) 127 report_fatal_error("Circular dependency found in global variable set"); 128 129 // Make sure we visit all dependents first 130 DenseSet<const GlobalVariable *> Others; 131 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i) 132 DiscoverDependentGlobals(GV->getOperand(i), Others); 133 134 for (const GlobalVariable *GV : Others) 135 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting); 136 137 // Now we can visit ourself 138 Order.push_back(GV); 139 Visited.insert(GV); 140 Visiting.erase(GV); 141 } 142 143 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) { 144 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(), 145 getSubtargetInfo().getFeatureBits()); 146 147 MCInst Inst; 148 lowerToMCInst(MI, Inst); 149 EmitToStreamer(*OutStreamer, Inst); 150 } 151 152 // Handle symbol backtracking for targets that do not support image handles 153 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI, 154 unsigned OpNo, MCOperand &MCOp) { 155 const MachineOperand &MO = MI->getOperand(OpNo); 156 const MCInstrDesc &MCID = MI->getDesc(); 157 158 if (MCID.TSFlags & NVPTXII::IsTexFlag) { 159 // This is a texture fetch, so operand 4 is a texref and operand 5 is 160 // a samplerref 161 if (OpNo == 4 && MO.isImm()) { 162 lowerImageHandleSymbol(MO.getImm(), MCOp); 163 return true; 164 } 165 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) { 166 lowerImageHandleSymbol(MO.getImm(), MCOp); 167 return true; 168 } 169 170 return false; 171 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { 172 unsigned VecSize = 173 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); 174 175 // For a surface load of vector size N, the Nth operand will be the surfref 176 if (OpNo == VecSize && MO.isImm()) { 177 lowerImageHandleSymbol(MO.getImm(), MCOp); 178 return true; 179 } 180 181 return false; 182 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { 183 // This is a surface store, so operand 0 is a surfref 184 if (OpNo == 0 && MO.isImm()) { 185 lowerImageHandleSymbol(MO.getImm(), MCOp); 186 return true; 187 } 188 189 return false; 190 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { 191 // This is a query, so operand 1 is a surfref/texref 192 if (OpNo == 1 && MO.isImm()) { 193 lowerImageHandleSymbol(MO.getImm(), MCOp); 194 return true; 195 } 196 197 return false; 198 } 199 200 return false; 201 } 202 203 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) { 204 // Ewwww 205 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget()); 206 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM); 207 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>(); 208 const char *Sym = MFI->getImageHandleSymbol(Index); 209 std::string *SymNamePtr = 210 nvTM.getManagedStrPool()->getManagedString(Sym); 211 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr))); 212 } 213 214 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { 215 OutMI.setOpcode(MI->getOpcode()); 216 // Special: Do not mangle symbol operand of CALL_PROTOTYPE 217 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) { 218 const MachineOperand &MO = MI->getOperand(0); 219 OutMI.addOperand(GetSymbolRef( 220 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName())))); 221 return; 222 } 223 224 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 225 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { 226 const MachineOperand &MO = MI->getOperand(i); 227 228 MCOperand MCOp; 229 if (!STI.hasImageHandles()) { 230 if (lowerImageHandleOperand(MI, i, MCOp)) { 231 OutMI.addOperand(MCOp); 232 continue; 233 } 234 } 235 236 if (lowerOperand(MO, MCOp)) 237 OutMI.addOperand(MCOp); 238 } 239 } 240 241 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, 242 MCOperand &MCOp) { 243 switch (MO.getType()) { 244 default: llvm_unreachable("unknown operand type"); 245 case MachineOperand::MO_Register: 246 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg())); 247 break; 248 case MachineOperand::MO_Immediate: 249 MCOp = MCOperand::createImm(MO.getImm()); 250 break; 251 case MachineOperand::MO_MachineBasicBlock: 252 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create( 253 MO.getMBB()->getSymbol(), OutContext)); 254 break; 255 case MachineOperand::MO_ExternalSymbol: 256 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName())); 257 break; 258 case MachineOperand::MO_GlobalAddress: 259 MCOp = GetSymbolRef(getSymbol(MO.getGlobal())); 260 break; 261 case MachineOperand::MO_FPImmediate: { 262 const ConstantFP *Cnt = MO.getFPImm(); 263 const APFloat &Val = Cnt->getValueAPF(); 264 265 switch (Cnt->getType()->getTypeID()) { 266 default: report_fatal_error("Unsupported FP type"); break; 267 case Type::HalfTyID: 268 MCOp = MCOperand::createExpr( 269 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); 270 break; 271 case Type::FloatTyID: 272 MCOp = MCOperand::createExpr( 273 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); 274 break; 275 case Type::DoubleTyID: 276 MCOp = MCOperand::createExpr( 277 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext)); 278 break; 279 } 280 break; 281 } 282 } 283 return true; 284 } 285 286 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { 287 if (Register::isVirtualRegister(Reg)) { 288 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 289 290 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC]; 291 unsigned RegNum = RegMap[Reg]; 292 293 // Encode the register class in the upper 4 bits 294 // Must be kept in sync with NVPTXInstPrinter::printRegName 295 unsigned Ret = 0; 296 if (RC == &NVPTX::Int1RegsRegClass) { 297 Ret = (1 << 28); 298 } else if (RC == &NVPTX::Int16RegsRegClass) { 299 Ret = (2 << 28); 300 } else if (RC == &NVPTX::Int32RegsRegClass) { 301 Ret = (3 << 28); 302 } else if (RC == &NVPTX::Int64RegsRegClass) { 303 Ret = (4 << 28); 304 } else if (RC == &NVPTX::Float32RegsRegClass) { 305 Ret = (5 << 28); 306 } else if (RC == &NVPTX::Float64RegsRegClass) { 307 Ret = (6 << 28); 308 } else if (RC == &NVPTX::Float16RegsRegClass) { 309 Ret = (7 << 28); 310 } else if (RC == &NVPTX::Float16x2RegsRegClass) { 311 Ret = (8 << 28); 312 } else { 313 report_fatal_error("Bad register class"); 314 } 315 316 // Insert the vreg number 317 Ret |= (RegNum & 0x0FFFFFFF); 318 return Ret; 319 } else { 320 // Some special-use registers are actually physical registers. 321 // Encode this as the register class ID of 0 and the real register ID. 322 return Reg & 0x0FFFFFFF; 323 } 324 } 325 326 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) { 327 const MCExpr *Expr; 328 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None, 329 OutContext); 330 return MCOperand::createExpr(Expr); 331 } 332 333 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { 334 const DataLayout &DL = getDataLayout(); 335 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 336 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering()); 337 338 Type *Ty = F->getReturnType(); 339 340 bool isABI = (STI.getSmVersion() >= 20); 341 342 if (Ty->getTypeID() == Type::VoidTyID) 343 return; 344 345 O << " ("; 346 347 if (isABI) { 348 if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) { 349 unsigned size = 0; 350 if (auto *ITy = dyn_cast<IntegerType>(Ty)) { 351 size = ITy->getBitWidth(); 352 } else { 353 assert(Ty->isFloatingPointTy() && "Floating point type expected here"); 354 size = Ty->getPrimitiveSizeInBits(); 355 } 356 // PTX ABI requires all scalar return values to be at least 32 357 // bits in size. fp16 normally uses .b16 as its storage type in 358 // PTX, so its size must be adjusted here, too. 359 size = promoteScalarArgumentSize(size); 360 361 O << ".param .b" << size << " func_retval0"; 362 } else if (isa<PointerType>(Ty)) { 363 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits() 364 << " func_retval0"; 365 } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) { 366 unsigned totalsz = DL.getTypeAllocSize(Ty); 367 unsigned retAlignment = 0; 368 if (!getAlign(*F, 0, retAlignment)) 369 retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value(); 370 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz 371 << "]"; 372 } else 373 llvm_unreachable("Unknown return type"); 374 } else { 375 SmallVector<EVT, 16> vtparts; 376 ComputeValueVTs(*TLI, DL, Ty, vtparts); 377 unsigned idx = 0; 378 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 379 unsigned elems = 1; 380 EVT elemtype = vtparts[i]; 381 if (vtparts[i].isVector()) { 382 elems = vtparts[i].getVectorNumElements(); 383 elemtype = vtparts[i].getVectorElementType(); 384 } 385 386 for (unsigned j = 0, je = elems; j != je; ++j) { 387 unsigned sz = elemtype.getSizeInBits(); 388 if (elemtype.isInteger()) 389 sz = promoteScalarArgumentSize(sz); 390 O << ".reg .b" << sz << " func_retval" << idx; 391 if (j < je - 1) 392 O << ", "; 393 ++idx; 394 } 395 if (i < e - 1) 396 O << ", "; 397 } 398 } 399 O << ") "; 400 } 401 402 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF, 403 raw_ostream &O) { 404 const Function &F = MF.getFunction(); 405 printReturnValStr(&F, O); 406 } 407 408 // Return true if MBB is the header of a loop marked with 409 // llvm.loop.unroll.disable. 410 // TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll". 411 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll( 412 const MachineBasicBlock &MBB) const { 413 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>(); 414 // We insert .pragma "nounroll" only to the loop header. 415 if (!LI.isLoopHeader(&MBB)) 416 return false; 417 418 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore, 419 // we iterate through each back edge of the loop with header MBB, and check 420 // whether its metadata contains llvm.loop.unroll.disable. 421 for (const MachineBasicBlock *PMBB : MBB.predecessors()) { 422 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) { 423 // Edges from other loops to MBB are not back edges. 424 continue; 425 } 426 if (const BasicBlock *PBB = PMBB->getBasicBlock()) { 427 if (MDNode *LoopID = 428 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) { 429 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable")) 430 return true; 431 } 432 } 433 } 434 return false; 435 } 436 437 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { 438 AsmPrinter::emitBasicBlockStart(MBB); 439 if (isLoopHeaderOfNoUnroll(MBB)) 440 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n")); 441 } 442 443 void NVPTXAsmPrinter::emitFunctionEntryLabel() { 444 SmallString<128> Str; 445 raw_svector_ostream O(Str); 446 447 if (!GlobalsEmitted) { 448 emitGlobals(*MF->getFunction().getParent()); 449 GlobalsEmitted = true; 450 } 451 452 // Set up 453 MRI = &MF->getRegInfo(); 454 F = &MF->getFunction(); 455 emitLinkageDirective(F, O); 456 if (isKernelFunction(*F)) 457 O << ".entry "; 458 else { 459 O << ".func "; 460 printReturnValStr(*MF, O); 461 } 462 463 CurrentFnSym->print(O, MAI); 464 465 emitFunctionParamList(*MF, O); 466 467 if (isKernelFunction(*F)) 468 emitKernelFunctionDirectives(*F, O); 469 470 OutStreamer->emitRawText(O.str()); 471 472 VRegMapping.clear(); 473 // Emit open brace for function body. 474 OutStreamer->emitRawText(StringRef("{\n")); 475 setAndEmitFunctionVirtualRegisters(*MF); 476 // Emit initial .loc debug directive for correct relocation symbol data. 477 if (MMI && MMI->hasDebugInfo()) 478 emitInitialRawDwarfLocDirective(*MF); 479 } 480 481 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) { 482 bool Result = AsmPrinter::runOnMachineFunction(F); 483 // Emit closing brace for the body of function F. 484 // The closing brace must be emitted here because we need to emit additional 485 // debug labels/data after the last basic block. 486 // We need to emit the closing brace here because we don't have function that 487 // finished emission of the function body. 488 OutStreamer->emitRawText(StringRef("}\n")); 489 return Result; 490 } 491 492 void NVPTXAsmPrinter::emitFunctionBodyStart() { 493 SmallString<128> Str; 494 raw_svector_ostream O(Str); 495 emitDemotedVars(&MF->getFunction(), O); 496 OutStreamer->emitRawText(O.str()); 497 } 498 499 void NVPTXAsmPrinter::emitFunctionBodyEnd() { 500 VRegMapping.clear(); 501 } 502 503 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const { 504 SmallString<128> Str; 505 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber(); 506 return OutContext.getOrCreateSymbol(Str); 507 } 508 509 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const { 510 Register RegNo = MI->getOperand(0).getReg(); 511 if (Register::isVirtualRegister(RegNo)) { 512 OutStreamer->AddComment(Twine("implicit-def: ") + 513 getVirtualRegisterName(RegNo)); 514 } else { 515 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 516 OutStreamer->AddComment(Twine("implicit-def: ") + 517 STI.getRegisterInfo()->getName(RegNo)); 518 } 519 OutStreamer->addBlankLine(); 520 } 521 522 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, 523 raw_ostream &O) const { 524 // If the NVVM IR has some of reqntid* specified, then output 525 // the reqntid directive, and set the unspecified ones to 1. 526 // If none of reqntid* is specified, don't output reqntid directive. 527 unsigned reqntidx, reqntidy, reqntidz; 528 bool specified = false; 529 if (!getReqNTIDx(F, reqntidx)) 530 reqntidx = 1; 531 else 532 specified = true; 533 if (!getReqNTIDy(F, reqntidy)) 534 reqntidy = 1; 535 else 536 specified = true; 537 if (!getReqNTIDz(F, reqntidz)) 538 reqntidz = 1; 539 else 540 specified = true; 541 542 if (specified) 543 O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz 544 << "\n"; 545 546 // If the NVVM IR has some of maxntid* specified, then output 547 // the maxntid directive, and set the unspecified ones to 1. 548 // If none of maxntid* is specified, don't output maxntid directive. 549 unsigned maxntidx, maxntidy, maxntidz; 550 specified = false; 551 if (!getMaxNTIDx(F, maxntidx)) 552 maxntidx = 1; 553 else 554 specified = true; 555 if (!getMaxNTIDy(F, maxntidy)) 556 maxntidy = 1; 557 else 558 specified = true; 559 if (!getMaxNTIDz(F, maxntidz)) 560 maxntidz = 1; 561 else 562 specified = true; 563 564 if (specified) 565 O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz 566 << "\n"; 567 568 unsigned mincta; 569 if (getMinCTASm(F, mincta)) 570 O << ".minnctapersm " << mincta << "\n"; 571 572 unsigned maxnreg; 573 if (getMaxNReg(F, maxnreg)) 574 O << ".maxnreg " << maxnreg << "\n"; 575 } 576 577 std::string 578 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const { 579 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 580 581 std::string Name; 582 raw_string_ostream NameStr(Name); 583 584 VRegRCMap::const_iterator I = VRegMapping.find(RC); 585 assert(I != VRegMapping.end() && "Bad register class"); 586 const DenseMap<unsigned, unsigned> &RegMap = I->second; 587 588 VRegMap::const_iterator VI = RegMap.find(Reg); 589 assert(VI != RegMap.end() && "Bad virtual register"); 590 unsigned MappedVR = VI->second; 591 592 NameStr << getNVPTXRegClassStr(RC) << MappedVR; 593 594 NameStr.flush(); 595 return Name; 596 } 597 598 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr, 599 raw_ostream &O) { 600 O << getVirtualRegisterName(vr); 601 } 602 603 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { 604 emitLinkageDirective(F, O); 605 if (isKernelFunction(*F)) 606 O << ".entry "; 607 else 608 O << ".func "; 609 printReturnValStr(F, O); 610 getSymbol(F)->print(O, MAI); 611 O << "\n"; 612 emitFunctionParamList(F, O); 613 O << ";\n"; 614 } 615 616 static bool usedInGlobalVarDef(const Constant *C) { 617 if (!C) 618 return false; 619 620 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) { 621 return GV->getName() != "llvm.used"; 622 } 623 624 for (const User *U : C->users()) 625 if (const Constant *C = dyn_cast<Constant>(U)) 626 if (usedInGlobalVarDef(C)) 627 return true; 628 629 return false; 630 } 631 632 static bool usedInOneFunc(const User *U, Function const *&oneFunc) { 633 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) { 634 if (othergv->getName() == "llvm.used") 635 return true; 636 } 637 638 if (const Instruction *instr = dyn_cast<Instruction>(U)) { 639 if (instr->getParent() && instr->getParent()->getParent()) { 640 const Function *curFunc = instr->getParent()->getParent(); 641 if (oneFunc && (curFunc != oneFunc)) 642 return false; 643 oneFunc = curFunc; 644 return true; 645 } else 646 return false; 647 } 648 649 for (const User *UU : U->users()) 650 if (!usedInOneFunc(UU, oneFunc)) 651 return false; 652 653 return true; 654 } 655 656 /* Find out if a global variable can be demoted to local scope. 657 * Currently, this is valid for CUDA shared variables, which have local 658 * scope and global lifetime. So the conditions to check are : 659 * 1. Is the global variable in shared address space? 660 * 2. Does it have internal linkage? 661 * 3. Is the global variable referenced only in one function? 662 */ 663 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) { 664 if (!gv->hasInternalLinkage()) 665 return false; 666 PointerType *Pty = gv->getType(); 667 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED) 668 return false; 669 670 const Function *oneFunc = nullptr; 671 672 bool flag = usedInOneFunc(gv, oneFunc); 673 if (!flag) 674 return false; 675 if (!oneFunc) 676 return false; 677 f = oneFunc; 678 return true; 679 } 680 681 static bool useFuncSeen(const Constant *C, 682 DenseMap<const Function *, bool> &seenMap) { 683 for (const User *U : C->users()) { 684 if (const Constant *cu = dyn_cast<Constant>(U)) { 685 if (useFuncSeen(cu, seenMap)) 686 return true; 687 } else if (const Instruction *I = dyn_cast<Instruction>(U)) { 688 const BasicBlock *bb = I->getParent(); 689 if (!bb) 690 continue; 691 const Function *caller = bb->getParent(); 692 if (!caller) 693 continue; 694 if (seenMap.find(caller) != seenMap.end()) 695 return true; 696 } 697 } 698 return false; 699 } 700 701 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { 702 DenseMap<const Function *, bool> seenMap; 703 for (const Function &F : M) { 704 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) { 705 emitDeclaration(&F, O); 706 continue; 707 } 708 709 if (F.isDeclaration()) { 710 if (F.use_empty()) 711 continue; 712 if (F.getIntrinsicID()) 713 continue; 714 emitDeclaration(&F, O); 715 continue; 716 } 717 for (const User *U : F.users()) { 718 if (const Constant *C = dyn_cast<Constant>(U)) { 719 if (usedInGlobalVarDef(C)) { 720 // The use is in the initialization of a global variable 721 // that is a function pointer, so print a declaration 722 // for the original function 723 emitDeclaration(&F, O); 724 break; 725 } 726 // Emit a declaration of this function if the function that 727 // uses this constant expr has already been seen. 728 if (useFuncSeen(C, seenMap)) { 729 emitDeclaration(&F, O); 730 break; 731 } 732 } 733 734 if (!isa<Instruction>(U)) 735 continue; 736 const Instruction *instr = cast<Instruction>(U); 737 const BasicBlock *bb = instr->getParent(); 738 if (!bb) 739 continue; 740 const Function *caller = bb->getParent(); 741 if (!caller) 742 continue; 743 744 // If a caller has already been seen, then the caller is 745 // appearing in the module before the callee. so print out 746 // a declaration for the callee. 747 if (seenMap.find(caller) != seenMap.end()) { 748 emitDeclaration(&F, O); 749 break; 750 } 751 } 752 seenMap[&F] = true; 753 } 754 } 755 756 static bool isEmptyXXStructor(GlobalVariable *GV) { 757 if (!GV) return true; 758 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer()); 759 if (!InitList) return true; // Not an array; we don't know how to parse. 760 return InitList->getNumOperands() == 0; 761 } 762 763 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) { 764 // Construct a default subtarget off of the TargetMachine defaults. The 765 // rest of NVPTX isn't friendly to change subtargets per function and 766 // so the default TargetMachine will have all of the options. 767 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 768 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl()); 769 SmallString<128> Str1; 770 raw_svector_ostream OS1(Str1); 771 772 // Emit header before any dwarf directives are emitted below. 773 emitHeader(M, OS1, *STI); 774 OutStreamer->emitRawText(OS1.str()); 775 } 776 777 bool NVPTXAsmPrinter::doInitialization(Module &M) { 778 if (M.alias_size()) { 779 report_fatal_error("Module has aliases, which NVPTX does not support."); 780 return true; // error 781 } 782 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) { 783 report_fatal_error( 784 "Module has a nontrivial global ctor, which NVPTX does not support."); 785 return true; // error 786 } 787 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) { 788 report_fatal_error( 789 "Module has a nontrivial global dtor, which NVPTX does not support."); 790 return true; // error 791 } 792 793 // We need to call the parent's one explicitly. 794 bool Result = AsmPrinter::doInitialization(M); 795 796 GlobalsEmitted = false; 797 798 return Result; 799 } 800 801 void NVPTXAsmPrinter::emitGlobals(const Module &M) { 802 SmallString<128> Str2; 803 raw_svector_ostream OS2(Str2); 804 805 emitDeclarations(M, OS2); 806 807 // As ptxas does not support forward references of globals, we need to first 808 // sort the list of module-level globals in def-use order. We visit each 809 // global variable in order, and ensure that we emit it *after* its dependent 810 // globals. We use a little extra memory maintaining both a set and a list to 811 // have fast searches while maintaining a strict ordering. 812 SmallVector<const GlobalVariable *, 8> Globals; 813 DenseSet<const GlobalVariable *> GVVisited; 814 DenseSet<const GlobalVariable *> GVVisiting; 815 816 // Visit each global variable, in order 817 for (const GlobalVariable &I : M.globals()) 818 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting); 819 820 assert(GVVisited.size() == M.getGlobalList().size() && 821 "Missed a global variable"); 822 assert(GVVisiting.size() == 0 && "Did not fully process a global variable"); 823 824 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 825 const NVPTXSubtarget &STI = 826 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 827 828 // Print out module-level global variables in proper order 829 for (unsigned i = 0, e = Globals.size(); i != e; ++i) 830 printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI); 831 832 OS2 << '\n'; 833 834 OutStreamer->emitRawText(OS2.str()); 835 } 836 837 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O, 838 const NVPTXSubtarget &STI) { 839 O << "//\n"; 840 O << "// Generated by LLVM NVPTX Back-End\n"; 841 O << "//\n"; 842 O << "\n"; 843 844 unsigned PTXVersion = STI.getPTXVersion(); 845 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"; 846 847 O << ".target "; 848 O << STI.getTargetName(); 849 850 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 851 if (NTM.getDrvInterface() == NVPTX::NVCL) 852 O << ", texmode_independent"; 853 854 bool HasFullDebugInfo = false; 855 for (DICompileUnit *CU : M.debug_compile_units()) { 856 switch(CU->getEmissionKind()) { 857 case DICompileUnit::NoDebug: 858 case DICompileUnit::DebugDirectivesOnly: 859 break; 860 case DICompileUnit::LineTablesOnly: 861 case DICompileUnit::FullDebug: 862 HasFullDebugInfo = true; 863 break; 864 } 865 if (HasFullDebugInfo) 866 break; 867 } 868 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo) 869 O << ", debug"; 870 871 O << "\n"; 872 873 O << ".address_size "; 874 if (NTM.is64Bit()) 875 O << "64"; 876 else 877 O << "32"; 878 O << "\n"; 879 880 O << "\n"; 881 } 882 883 bool NVPTXAsmPrinter::doFinalization(Module &M) { 884 bool HasDebugInfo = MMI && MMI->hasDebugInfo(); 885 886 // If we did not emit any functions, then the global declarations have not 887 // yet been emitted. 888 if (!GlobalsEmitted) { 889 emitGlobals(M); 890 GlobalsEmitted = true; 891 } 892 893 // call doFinalization 894 bool ret = AsmPrinter::doFinalization(M); 895 896 clearAnnotationCache(&M); 897 898 if (auto *TS = static_cast<NVPTXTargetStreamer *>( 899 OutStreamer->getTargetStreamer())) { 900 // Close the last emitted section 901 if (HasDebugInfo) { 902 TS->closeLastSection(); 903 // Emit empty .debug_loc section for better support of the empty files. 904 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}"); 905 } 906 907 // Output last DWARF .file directives, if any. 908 TS->outputDwarfFileDirectives(); 909 } 910 911 return ret; 912 913 //bool Result = AsmPrinter::doFinalization(M); 914 // Instead of calling the parents doFinalization, we may 915 // clone parents doFinalization and customize here. 916 // Currently, we if NVISA out the EmitGlobals() in 917 // parent's doFinalization, which is too intrusive. 918 // 919 // Same for the doInitialization. 920 //return Result; 921 } 922 923 // This function emits appropriate linkage directives for 924 // functions and global variables. 925 // 926 // extern function declaration -> .extern 927 // extern function definition -> .visible 928 // external global variable with init -> .visible 929 // external without init -> .extern 930 // appending -> not allowed, assert. 931 // for any linkage other than 932 // internal, private, linker_private, 933 // linker_private_weak, linker_private_weak_def_auto, 934 // we emit -> .weak. 935 936 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, 937 raw_ostream &O) { 938 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) { 939 if (V->hasExternalLinkage()) { 940 if (isa<GlobalVariable>(V)) { 941 const GlobalVariable *GVar = cast<GlobalVariable>(V); 942 if (GVar) { 943 if (GVar->hasInitializer()) 944 O << ".visible "; 945 else 946 O << ".extern "; 947 } 948 } else if (V->isDeclaration()) 949 O << ".extern "; 950 else 951 O << ".visible "; 952 } else if (V->hasAppendingLinkage()) { 953 std::string msg; 954 msg.append("Error: "); 955 msg.append("Symbol "); 956 if (V->hasName()) 957 msg.append(std::string(V->getName())); 958 msg.append("has unsupported appending linkage type"); 959 llvm_unreachable(msg.c_str()); 960 } else if (!V->hasInternalLinkage() && 961 !V->hasPrivateLinkage()) { 962 O << ".weak "; 963 } 964 } 965 } 966 967 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, 968 raw_ostream &O, bool processDemoted, 969 const NVPTXSubtarget &STI) { 970 // Skip meta data 971 if (GVar->hasSection()) { 972 if (GVar->getSection() == "llvm.metadata") 973 return; 974 } 975 976 // Skip LLVM intrinsic global variables 977 if (GVar->getName().startswith("llvm.") || 978 GVar->getName().startswith("nvvm.")) 979 return; 980 981 const DataLayout &DL = getDataLayout(); 982 983 // GlobalVariables are always constant pointers themselves. 984 PointerType *PTy = GVar->getType(); 985 Type *ETy = GVar->getValueType(); 986 987 if (GVar->hasExternalLinkage()) { 988 if (GVar->hasInitializer()) 989 O << ".visible "; 990 else 991 O << ".extern "; 992 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() || 993 GVar->hasAvailableExternallyLinkage() || 994 GVar->hasCommonLinkage()) { 995 O << ".weak "; 996 } 997 998 if (isTexture(*GVar)) { 999 O << ".global .texref " << getTextureName(*GVar) << ";\n"; 1000 return; 1001 } 1002 1003 if (isSurface(*GVar)) { 1004 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n"; 1005 return; 1006 } 1007 1008 if (GVar->isDeclaration()) { 1009 // (extern) declarations, no definition or initializer 1010 // Currently the only known declaration is for an automatic __local 1011 // (.shared) promoted to global. 1012 emitPTXGlobalVariable(GVar, O, STI); 1013 O << ";\n"; 1014 return; 1015 } 1016 1017 if (isSampler(*GVar)) { 1018 O << ".global .samplerref " << getSamplerName(*GVar); 1019 1020 const Constant *Initializer = nullptr; 1021 if (GVar->hasInitializer()) 1022 Initializer = GVar->getInitializer(); 1023 const ConstantInt *CI = nullptr; 1024 if (Initializer) 1025 CI = dyn_cast<ConstantInt>(Initializer); 1026 if (CI) { 1027 unsigned sample = CI->getZExtValue(); 1028 1029 O << " = { "; 1030 1031 for (int i = 0, 1032 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE); 1033 i < 3; i++) { 1034 O << "addr_mode_" << i << " = "; 1035 switch (addr) { 1036 case 0: 1037 O << "wrap"; 1038 break; 1039 case 1: 1040 O << "clamp_to_border"; 1041 break; 1042 case 2: 1043 O << "clamp_to_edge"; 1044 break; 1045 case 3: 1046 O << "wrap"; 1047 break; 1048 case 4: 1049 O << "mirror"; 1050 break; 1051 } 1052 O << ", "; 1053 } 1054 O << "filter_mode = "; 1055 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) { 1056 case 0: 1057 O << "nearest"; 1058 break; 1059 case 1: 1060 O << "linear"; 1061 break; 1062 case 2: 1063 llvm_unreachable("Anisotropic filtering is not supported"); 1064 default: 1065 O << "nearest"; 1066 break; 1067 } 1068 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) { 1069 O << ", force_unnormalized_coords = 1"; 1070 } 1071 O << " }"; 1072 } 1073 1074 O << ";\n"; 1075 return; 1076 } 1077 1078 if (GVar->hasPrivateLinkage()) { 1079 if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0) 1080 return; 1081 1082 // FIXME - need better way (e.g. Metadata) to avoid generating this global 1083 if (strncmp(GVar->getName().data(), "filename", 8) == 0) 1084 return; 1085 if (GVar->use_empty()) 1086 return; 1087 } 1088 1089 const Function *demotedFunc = nullptr; 1090 if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) { 1091 O << "// " << GVar->getName() << " has been demoted\n"; 1092 if (localDecls.find(demotedFunc) != localDecls.end()) 1093 localDecls[demotedFunc].push_back(GVar); 1094 else { 1095 std::vector<const GlobalVariable *> temp; 1096 temp.push_back(GVar); 1097 localDecls[demotedFunc] = temp; 1098 } 1099 return; 1100 } 1101 1102 O << "."; 1103 emitPTXAddressSpace(PTy->getAddressSpace(), O); 1104 1105 if (isManaged(*GVar)) { 1106 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { 1107 report_fatal_error( 1108 ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); 1109 } 1110 O << " .attribute(.managed)"; 1111 } 1112 1113 if (MaybeAlign A = GVar->getAlign()) 1114 O << " .align " << A->value(); 1115 else 1116 O << " .align " << (int)DL.getPrefTypeAlignment(ETy); 1117 1118 if (ETy->isFloatingPointTy() || ETy->isPointerTy() || 1119 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) { 1120 O << " ."; 1121 // Special case: ABI requires that we use .u8 for predicates 1122 if (ETy->isIntegerTy(1)) 1123 O << "u8"; 1124 else 1125 O << getPTXFundamentalTypeStr(ETy, false); 1126 O << " "; 1127 getSymbol(GVar)->print(O, MAI); 1128 1129 // Ptx allows variable initilization only for constant and global state 1130 // spaces. 1131 if (GVar->hasInitializer()) { 1132 if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || 1133 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) { 1134 const Constant *Initializer = GVar->getInitializer(); 1135 // 'undef' is treated as there is no value specified. 1136 if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) { 1137 O << " = "; 1138 printScalarConstant(Initializer, O); 1139 } 1140 } else { 1141 // The frontend adds zero-initializer to device and constant variables 1142 // that don't have an initial value, and UndefValue to shared 1143 // variables, so skip warning for this case. 1144 if (!GVar->getInitializer()->isNullValue() && 1145 !isa<UndefValue>(GVar->getInitializer())) { 1146 report_fatal_error("initial value of '" + GVar->getName() + 1147 "' is not allowed in addrspace(" + 1148 Twine(PTy->getAddressSpace()) + ")"); 1149 } 1150 } 1151 } 1152 } else { 1153 unsigned int ElementSize = 0; 1154 1155 // Although PTX has direct support for struct type and array type and 1156 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for 1157 // targets that support these high level field accesses. Structs, arrays 1158 // and vectors are lowered into arrays of bytes. 1159 switch (ETy->getTypeID()) { 1160 case Type::IntegerTyID: // Integers larger than 64 bits 1161 case Type::StructTyID: 1162 case Type::ArrayTyID: 1163 case Type::FixedVectorTyID: 1164 ElementSize = DL.getTypeStoreSize(ETy); 1165 // Ptx allows variable initilization only for constant and 1166 // global state spaces. 1167 if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || 1168 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) && 1169 GVar->hasInitializer()) { 1170 const Constant *Initializer = GVar->getInitializer(); 1171 if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) { 1172 AggBuffer aggBuffer(ElementSize, *this); 1173 bufferAggregateConstant(Initializer, &aggBuffer); 1174 if (aggBuffer.numSymbols()) { 1175 unsigned int ptrSize = MAI->getCodePointerSize(); 1176 if (ElementSize % ptrSize || 1177 !aggBuffer.allSymbolsAligned(ptrSize)) { 1178 // Print in bytes and use the mask() operator for pointers. 1179 if (!STI.hasMaskOperator()) 1180 report_fatal_error( 1181 "initialized packed aggregate with pointers '" + 1182 GVar->getName() + 1183 "' requires at least PTX ISA version 7.1"); 1184 O << " .u8 "; 1185 getSymbol(GVar)->print(O, MAI); 1186 O << "[" << ElementSize << "] = {"; 1187 aggBuffer.printBytes(O); 1188 O << "}"; 1189 } else { 1190 O << " .u" << ptrSize * 8 << " "; 1191 getSymbol(GVar)->print(O, MAI); 1192 O << "[" << ElementSize / ptrSize << "] = {"; 1193 aggBuffer.printWords(O); 1194 O << "}"; 1195 } 1196 } else { 1197 O << " .b8 "; 1198 getSymbol(GVar)->print(O, MAI); 1199 O << "[" << ElementSize << "] = {"; 1200 aggBuffer.printBytes(O); 1201 O << "}"; 1202 } 1203 } else { 1204 O << " .b8 "; 1205 getSymbol(GVar)->print(O, MAI); 1206 if (ElementSize) { 1207 O << "["; 1208 O << ElementSize; 1209 O << "]"; 1210 } 1211 } 1212 } else { 1213 O << " .b8 "; 1214 getSymbol(GVar)->print(O, MAI); 1215 if (ElementSize) { 1216 O << "["; 1217 O << ElementSize; 1218 O << "]"; 1219 } 1220 } 1221 break; 1222 default: 1223 llvm_unreachable("type not supported yet"); 1224 } 1225 } 1226 O << ";\n"; 1227 } 1228 1229 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) { 1230 const Value *v = Symbols[nSym]; 1231 const Value *v0 = SymbolsBeforeStripping[nSym]; 1232 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { 1233 MCSymbol *Name = AP.getSymbol(GVar); 1234 PointerType *PTy = dyn_cast<PointerType>(v0->getType()); 1235 // Is v0 a generic pointer? 1236 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0; 1237 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) { 1238 os << "generic("; 1239 Name->print(os, AP.MAI); 1240 os << ")"; 1241 } else { 1242 Name->print(os, AP.MAI); 1243 } 1244 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) { 1245 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false); 1246 AP.printMCExpr(*Expr, os); 1247 } else 1248 llvm_unreachable("symbol type unknown"); 1249 } 1250 1251 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) { 1252 unsigned int ptrSize = AP.MAI->getCodePointerSize(); 1253 symbolPosInBuffer.push_back(size); 1254 unsigned int nSym = 0; 1255 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 1256 for (unsigned int pos = 0; pos < size;) { 1257 if (pos) 1258 os << ", "; 1259 if (pos != nextSymbolPos) { 1260 os << (unsigned int)buffer[pos]; 1261 ++pos; 1262 continue; 1263 } 1264 // Generate a per-byte mask() operator for the symbol, which looks like: 1265 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...}; 1266 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers 1267 std::string symText; 1268 llvm::raw_string_ostream oss(symText); 1269 printSymbol(nSym, oss); 1270 for (unsigned i = 0; i < ptrSize; ++i) { 1271 if (i) 1272 os << ", "; 1273 llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper); 1274 os << "(" << symText << ")"; 1275 } 1276 pos += ptrSize; 1277 nextSymbolPos = symbolPosInBuffer[++nSym]; 1278 assert(nextSymbolPos >= pos); 1279 } 1280 } 1281 1282 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) { 1283 unsigned int ptrSize = AP.MAI->getCodePointerSize(); 1284 symbolPosInBuffer.push_back(size); 1285 unsigned int nSym = 0; 1286 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 1287 assert(nextSymbolPos % ptrSize == 0); 1288 for (unsigned int pos = 0; pos < size; pos += ptrSize) { 1289 if (pos) 1290 os << ", "; 1291 if (pos == nextSymbolPos) { 1292 printSymbol(nSym, os); 1293 nextSymbolPos = symbolPosInBuffer[++nSym]; 1294 assert(nextSymbolPos % ptrSize == 0); 1295 assert(nextSymbolPos >= pos + ptrSize); 1296 } else if (ptrSize == 4) 1297 os << support::endian::read32le(&buffer[pos]); 1298 else 1299 os << support::endian::read64le(&buffer[pos]); 1300 } 1301 } 1302 1303 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) { 1304 if (localDecls.find(f) == localDecls.end()) 1305 return; 1306 1307 std::vector<const GlobalVariable *> &gvars = localDecls[f]; 1308 1309 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 1310 const NVPTXSubtarget &STI = 1311 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl()); 1312 1313 for (const GlobalVariable *GV : gvars) { 1314 O << "\t// demoted variable\n\t"; 1315 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI); 1316 } 1317 } 1318 1319 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace, 1320 raw_ostream &O) const { 1321 switch (AddressSpace) { 1322 case ADDRESS_SPACE_LOCAL: 1323 O << "local"; 1324 break; 1325 case ADDRESS_SPACE_GLOBAL: 1326 O << "global"; 1327 break; 1328 case ADDRESS_SPACE_CONST: 1329 O << "const"; 1330 break; 1331 case ADDRESS_SPACE_SHARED: 1332 O << "shared"; 1333 break; 1334 default: 1335 report_fatal_error("Bad address space found while emitting PTX: " + 1336 llvm::Twine(AddressSpace)); 1337 break; 1338 } 1339 } 1340 1341 std::string 1342 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const { 1343 switch (Ty->getTypeID()) { 1344 case Type::IntegerTyID: { 1345 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth(); 1346 if (NumBits == 1) 1347 return "pred"; 1348 else if (NumBits <= 64) { 1349 std::string name = "u"; 1350 return name + utostr(NumBits); 1351 } else { 1352 llvm_unreachable("Integer too large"); 1353 break; 1354 } 1355 break; 1356 } 1357 case Type::HalfTyID: 1358 // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly. 1359 return "b16"; 1360 case Type::FloatTyID: 1361 return "f32"; 1362 case Type::DoubleTyID: 1363 return "f64"; 1364 case Type::PointerTyID: 1365 if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit()) 1366 if (useB4PTR) 1367 return "b64"; 1368 else 1369 return "u64"; 1370 else if (useB4PTR) 1371 return "b32"; 1372 else 1373 return "u32"; 1374 default: 1375 break; 1376 } 1377 llvm_unreachable("unexpected type"); 1378 } 1379 1380 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, 1381 raw_ostream &O, 1382 const NVPTXSubtarget &STI) { 1383 const DataLayout &DL = getDataLayout(); 1384 1385 // GlobalVariables are always constant pointers themselves. 1386 Type *ETy = GVar->getValueType(); 1387 1388 O << "."; 1389 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O); 1390 if (isManaged(*GVar)) { 1391 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { 1392 report_fatal_error( 1393 ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); 1394 } 1395 O << " .attribute(.managed)"; 1396 } 1397 if (MaybeAlign A = GVar->getAlign()) 1398 O << " .align " << A->value(); 1399 else 1400 O << " .align " << (int)DL.getPrefTypeAlignment(ETy); 1401 1402 // Special case for i128 1403 if (ETy->isIntegerTy(128)) { 1404 O << " .b8 "; 1405 getSymbol(GVar)->print(O, MAI); 1406 O << "[16]"; 1407 return; 1408 } 1409 1410 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) { 1411 O << " ."; 1412 O << getPTXFundamentalTypeStr(ETy); 1413 O << " "; 1414 getSymbol(GVar)->print(O, MAI); 1415 return; 1416 } 1417 1418 int64_t ElementSize = 0; 1419 1420 // Although PTX has direct support for struct type and array type and LLVM IR 1421 // is very similar to PTX, the LLVM CodeGen does not support for targets that 1422 // support these high level field accesses. Structs and arrays are lowered 1423 // into arrays of bytes. 1424 switch (ETy->getTypeID()) { 1425 case Type::StructTyID: 1426 case Type::ArrayTyID: 1427 case Type::FixedVectorTyID: 1428 ElementSize = DL.getTypeStoreSize(ETy); 1429 O << " .b8 "; 1430 getSymbol(GVar)->print(O, MAI); 1431 O << "["; 1432 if (ElementSize) { 1433 O << ElementSize; 1434 } 1435 O << "]"; 1436 break; 1437 default: 1438 llvm_unreachable("type not supported yet"); 1439 } 1440 } 1441 1442 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I, 1443 int paramIndex, raw_ostream &O) { 1444 getSymbol(I->getParent())->print(O, MAI); 1445 O << "_param_" << paramIndex; 1446 } 1447 1448 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { 1449 const DataLayout &DL = getDataLayout(); 1450 const AttributeList &PAL = F->getAttributes(); 1451 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 1452 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering()); 1453 1454 Function::const_arg_iterator I, E; 1455 unsigned paramIndex = 0; 1456 bool first = true; 1457 bool isKernelFunc = isKernelFunction(*F); 1458 bool isABI = (STI.getSmVersion() >= 20); 1459 bool hasImageHandles = STI.hasImageHandles(); 1460 MVT thePointerTy = TLI->getPointerTy(DL); 1461 1462 if (F->arg_empty()) { 1463 O << "()\n"; 1464 return; 1465 } 1466 1467 O << "(\n"; 1468 1469 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) { 1470 Type *Ty = I->getType(); 1471 1472 if (!first) 1473 O << ",\n"; 1474 1475 first = false; 1476 1477 // Handle image/sampler parameters 1478 if (isKernelFunction(*F)) { 1479 if (isSampler(*I) || isImage(*I)) { 1480 if (isImage(*I)) { 1481 std::string sname = std::string(I->getName()); 1482 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) { 1483 if (hasImageHandles) 1484 O << "\t.param .u64 .ptr .surfref "; 1485 else 1486 O << "\t.param .surfref "; 1487 CurrentFnSym->print(O, MAI); 1488 O << "_param_" << paramIndex; 1489 } 1490 else { // Default image is read_only 1491 if (hasImageHandles) 1492 O << "\t.param .u64 .ptr .texref "; 1493 else 1494 O << "\t.param .texref "; 1495 CurrentFnSym->print(O, MAI); 1496 O << "_param_" << paramIndex; 1497 } 1498 } else { 1499 if (hasImageHandles) 1500 O << "\t.param .u64 .ptr .samplerref "; 1501 else 1502 O << "\t.param .samplerref "; 1503 CurrentFnSym->print(O, MAI); 1504 O << "_param_" << paramIndex; 1505 } 1506 continue; 1507 } 1508 } 1509 1510 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F, 1511 paramIndex](Type *Ty) -> Align { 1512 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL); 1513 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex); 1514 return std::max(TypeAlign, ParamAlign.valueOrOne()); 1515 }; 1516 1517 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) { 1518 if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) { 1519 // Just print .param .align <a> .b8 .param[size]; 1520 // <a> = optimal alignment for the element type; always multiple of 1521 // PAL.getParamAlignment 1522 // size = typeallocsize of element type 1523 Align OptimalAlign = getOptimalAlignForParam(Ty); 1524 1525 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1526 printParamName(I, paramIndex, O); 1527 O << "[" << DL.getTypeAllocSize(Ty) << "]"; 1528 1529 continue; 1530 } 1531 // Just a scalar 1532 auto *PTy = dyn_cast<PointerType>(Ty); 1533 if (isKernelFunc) { 1534 if (PTy) { 1535 // Special handling for pointer arguments to kernel 1536 O << "\t.param .u" << thePointerTy.getSizeInBits() << " "; 1537 1538 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() != 1539 NVPTX::CUDA) { 1540 int addrSpace = PTy->getAddressSpace(); 1541 switch (addrSpace) { 1542 default: 1543 O << ".ptr "; 1544 break; 1545 case ADDRESS_SPACE_CONST: 1546 O << ".ptr .const "; 1547 break; 1548 case ADDRESS_SPACE_SHARED: 1549 O << ".ptr .shared "; 1550 break; 1551 case ADDRESS_SPACE_GLOBAL: 1552 O << ".ptr .global "; 1553 break; 1554 } 1555 Align ParamAlign = I->getParamAlign().valueOrOne(); 1556 O << ".align " << ParamAlign.value() << " "; 1557 } 1558 printParamName(I, paramIndex, O); 1559 continue; 1560 } 1561 1562 // non-pointer scalar to kernel func 1563 O << "\t.param ."; 1564 // Special case: predicate operands become .u8 types 1565 if (Ty->isIntegerTy(1)) 1566 O << "u8"; 1567 else 1568 O << getPTXFundamentalTypeStr(Ty); 1569 O << " "; 1570 printParamName(I, paramIndex, O); 1571 continue; 1572 } 1573 // Non-kernel function, just print .param .b<size> for ABI 1574 // and .reg .b<size> for non-ABI 1575 unsigned sz = 0; 1576 if (isa<IntegerType>(Ty)) { 1577 sz = cast<IntegerType>(Ty)->getBitWidth(); 1578 sz = promoteScalarArgumentSize(sz); 1579 } else if (isa<PointerType>(Ty)) 1580 sz = thePointerTy.getSizeInBits(); 1581 else if (Ty->isHalfTy()) 1582 // PTX ABI requires all scalar parameters to be at least 32 1583 // bits in size. fp16 normally uses .b16 as its storage type 1584 // in PTX, so its size must be adjusted here, too. 1585 sz = 32; 1586 else 1587 sz = Ty->getPrimitiveSizeInBits(); 1588 if (isABI) 1589 O << "\t.param .b" << sz << " "; 1590 else 1591 O << "\t.reg .b" << sz << " "; 1592 printParamName(I, paramIndex, O); 1593 continue; 1594 } 1595 1596 // param has byVal attribute. 1597 Type *ETy = PAL.getParamByValType(paramIndex); 1598 assert(ETy && "Param should have byval type"); 1599 1600 if (isABI || isKernelFunc) { 1601 // Just print .param .align <a> .b8 .param[size]; 1602 // <a> = optimal alignment for the element type; always multiple of 1603 // PAL.getParamAlignment 1604 // size = typeallocsize of element type 1605 Align OptimalAlign = getOptimalAlignForParam(ETy); 1606 1607 // Work around a bug in ptxas. When PTX code takes address of 1608 // byval parameter with alignment < 4, ptxas generates code to 1609 // spill argument into memory. Alas on sm_50+ ptxas generates 1610 // SASS code that fails with misaligned access. To work around 1611 // the problem, make sure that we align byval parameters by at 1612 // least 4. Matching change must be made in LowerCall() where we 1613 // prepare parameters for the call. 1614 // 1615 // TODO: this will need to be undone when we get to support multi-TU 1616 // device-side compilation as it breaks ABI compatibility with nvcc. 1617 // Hopefully ptxas bug is fixed by then. 1618 if (!isKernelFunc && OptimalAlign < Align(4)) 1619 OptimalAlign = Align(4); 1620 unsigned sz = DL.getTypeAllocSize(ETy); 1621 O << "\t.param .align " << OptimalAlign.value() << " .b8 "; 1622 printParamName(I, paramIndex, O); 1623 O << "[" << sz << "]"; 1624 continue; 1625 } else { 1626 // Split the ETy into constituent parts and 1627 // print .param .b<size> <name> for each part. 1628 // Further, if a part is vector, print the above for 1629 // each vector element. 1630 SmallVector<EVT, 16> vtparts; 1631 ComputeValueVTs(*TLI, DL, ETy, vtparts); 1632 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 1633 unsigned elems = 1; 1634 EVT elemtype = vtparts[i]; 1635 if (vtparts[i].isVector()) { 1636 elems = vtparts[i].getVectorNumElements(); 1637 elemtype = vtparts[i].getVectorElementType(); 1638 } 1639 1640 for (unsigned j = 0, je = elems; j != je; ++j) { 1641 unsigned sz = elemtype.getSizeInBits(); 1642 if (elemtype.isInteger()) 1643 sz = promoteScalarArgumentSize(sz); 1644 O << "\t.reg .b" << sz << " "; 1645 printParamName(I, paramIndex, O); 1646 if (j < je - 1) 1647 O << ",\n"; 1648 ++paramIndex; 1649 } 1650 if (i < e - 1) 1651 O << ",\n"; 1652 } 1653 --paramIndex; 1654 continue; 1655 } 1656 } 1657 1658 O << "\n)\n"; 1659 } 1660 1661 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF, 1662 raw_ostream &O) { 1663 const Function &F = MF.getFunction(); 1664 emitFunctionParamList(&F, O); 1665 } 1666 1667 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( 1668 const MachineFunction &MF) { 1669 SmallString<128> Str; 1670 raw_svector_ostream O(Str); 1671 1672 // Map the global virtual register number to a register class specific 1673 // virtual register number starting from 1 with that class. 1674 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); 1675 //unsigned numRegClasses = TRI->getNumRegClasses(); 1676 1677 // Emit the Fake Stack Object 1678 const MachineFrameInfo &MFI = MF.getFrameInfo(); 1679 int NumBytes = (int) MFI.getStackSize(); 1680 if (NumBytes) { 1681 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t" 1682 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n"; 1683 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) { 1684 O << "\t.reg .b64 \t%SP;\n"; 1685 O << "\t.reg .b64 \t%SPL;\n"; 1686 } else { 1687 O << "\t.reg .b32 \t%SP;\n"; 1688 O << "\t.reg .b32 \t%SPL;\n"; 1689 } 1690 } 1691 1692 // Go through all virtual registers to establish the mapping between the 1693 // global virtual 1694 // register number and the per class virtual register number. 1695 // We use the per class virtual register number in the ptx output. 1696 unsigned int numVRs = MRI->getNumVirtRegs(); 1697 for (unsigned i = 0; i < numVRs; i++) { 1698 Register vr = Register::index2VirtReg(i); 1699 const TargetRegisterClass *RC = MRI->getRegClass(vr); 1700 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1701 int n = regmap.size(); 1702 regmap.insert(std::make_pair(vr, n + 1)); 1703 } 1704 1705 // Emit register declarations 1706 // @TODO: Extract out the real register usage 1707 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; 1708 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; 1709 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; 1710 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; 1711 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n"; 1712 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; 1713 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n"; 1714 1715 // Emit declaration of the virtual registers or 'physical' registers for 1716 // each register class 1717 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) { 1718 const TargetRegisterClass *RC = TRI->getRegClass(i); 1719 DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; 1720 std::string rcname = getNVPTXRegClassName(RC); 1721 std::string rcStr = getNVPTXRegClassStr(RC); 1722 int n = regmap.size(); 1723 1724 // Only declare those registers that may be used. 1725 if (n) { 1726 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1) 1727 << ">;\n"; 1728 } 1729 } 1730 1731 OutStreamer->emitRawText(O.str()); 1732 } 1733 1734 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) { 1735 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy 1736 bool ignored; 1737 unsigned int numHex; 1738 const char *lead; 1739 1740 if (Fp->getType()->getTypeID() == Type::FloatTyID) { 1741 numHex = 8; 1742 lead = "0f"; 1743 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored); 1744 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) { 1745 numHex = 16; 1746 lead = "0d"; 1747 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored); 1748 } else 1749 llvm_unreachable("unsupported fp type"); 1750 1751 APInt API = APF.bitcastToAPInt(); 1752 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true); 1753 } 1754 1755 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { 1756 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1757 O << CI->getValue(); 1758 return; 1759 } 1760 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) { 1761 printFPConstant(CFP, O); 1762 return; 1763 } 1764 if (isa<ConstantPointerNull>(CPV)) { 1765 O << "0"; 1766 return; 1767 } 1768 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1769 bool IsNonGenericPointer = false; 1770 if (GVar->getType()->getAddressSpace() != 0) { 1771 IsNonGenericPointer = true; 1772 } 1773 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) { 1774 O << "generic("; 1775 getSymbol(GVar)->print(O, MAI); 1776 O << ")"; 1777 } else { 1778 getSymbol(GVar)->print(O, MAI); 1779 } 1780 return; 1781 } 1782 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1783 const Value *v = Cexpr->stripPointerCasts(); 1784 PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType()); 1785 bool IsNonGenericPointer = false; 1786 if (PTy && PTy->getAddressSpace() != 0) { 1787 IsNonGenericPointer = true; 1788 } 1789 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { 1790 if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) { 1791 O << "generic("; 1792 getSymbol(GVar)->print(O, MAI); 1793 O << ")"; 1794 } else { 1795 getSymbol(GVar)->print(O, MAI); 1796 } 1797 return; 1798 } else { 1799 lowerConstant(CPV)->print(O, MAI); 1800 return; 1801 } 1802 } 1803 llvm_unreachable("Not scalar type found in printScalarConstant()"); 1804 } 1805 1806 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes, 1807 AggBuffer *AggBuffer) { 1808 const DataLayout &DL = getDataLayout(); 1809 int AllocSize = DL.getTypeAllocSize(CPV->getType()); 1810 if (isa<UndefValue>(CPV) || CPV->isNullValue()) { 1811 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise, 1812 // only the space allocated by CPV. 1813 AggBuffer->addZeros(Bytes ? Bytes : AllocSize); 1814 return; 1815 } 1816 1817 // Helper for filling AggBuffer with APInts. 1818 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) { 1819 size_t NumBytes = (Val.getBitWidth() + 7) / 8; 1820 SmallVector<unsigned char, 16> Buf(NumBytes); 1821 for (unsigned I = 0; I < NumBytes; ++I) { 1822 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8); 1823 } 1824 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes); 1825 }; 1826 1827 switch (CPV->getType()->getTypeID()) { 1828 case Type::IntegerTyID: 1829 if (const auto CI = dyn_cast<ConstantInt>(CPV)) { 1830 AddIntToBuffer(CI->getValue()); 1831 break; 1832 } 1833 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1834 if (const auto *CI = 1835 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) { 1836 AddIntToBuffer(CI->getValue()); 1837 break; 1838 } 1839 if (Cexpr->getOpcode() == Instruction::PtrToInt) { 1840 Value *V = Cexpr->getOperand(0)->stripPointerCasts(); 1841 AggBuffer->addSymbol(V, Cexpr->getOperand(0)); 1842 AggBuffer->addZeros(AllocSize); 1843 break; 1844 } 1845 } 1846 llvm_unreachable("unsupported integer const type"); 1847 break; 1848 1849 case Type::HalfTyID: 1850 case Type::FloatTyID: 1851 case Type::DoubleTyID: 1852 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt()); 1853 break; 1854 1855 case Type::PointerTyID: { 1856 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1857 AggBuffer->addSymbol(GVar, GVar); 1858 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1859 const Value *v = Cexpr->stripPointerCasts(); 1860 AggBuffer->addSymbol(v, Cexpr); 1861 } 1862 AggBuffer->addZeros(AllocSize); 1863 break; 1864 } 1865 1866 case Type::ArrayTyID: 1867 case Type::FixedVectorTyID: 1868 case Type::StructTyID: { 1869 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) { 1870 bufferAggregateConstant(CPV, AggBuffer); 1871 if (Bytes > AllocSize) 1872 AggBuffer->addZeros(Bytes - AllocSize); 1873 } else if (isa<ConstantAggregateZero>(CPV)) 1874 AggBuffer->addZeros(Bytes); 1875 else 1876 llvm_unreachable("Unexpected Constant type"); 1877 break; 1878 } 1879 1880 default: 1881 llvm_unreachable("unsupported type"); 1882 } 1883 } 1884 1885 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV, 1886 AggBuffer *aggBuffer) { 1887 const DataLayout &DL = getDataLayout(); 1888 int Bytes; 1889 1890 // Integers of arbitrary width 1891 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1892 APInt Val = CI->getValue(); 1893 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) { 1894 uint8_t Byte = Val.getLoBits(8).getZExtValue(); 1895 aggBuffer->addBytes(&Byte, 1, 1); 1896 Val.lshrInPlace(8); 1897 } 1898 return; 1899 } 1900 1901 // Old constants 1902 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) { 1903 if (CPV->getNumOperands()) 1904 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) 1905 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer); 1906 return; 1907 } 1908 1909 if (const ConstantDataSequential *CDS = 1910 dyn_cast<ConstantDataSequential>(CPV)) { 1911 if (CDS->getNumElements()) 1912 for (unsigned i = 0; i < CDS->getNumElements(); ++i) 1913 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0, 1914 aggBuffer); 1915 return; 1916 } 1917 1918 if (isa<ConstantStruct>(CPV)) { 1919 if (CPV->getNumOperands()) { 1920 StructType *ST = cast<StructType>(CPV->getType()); 1921 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) { 1922 if (i == (e - 1)) 1923 Bytes = DL.getStructLayout(ST)->getElementOffset(0) + 1924 DL.getTypeAllocSize(ST) - 1925 DL.getStructLayout(ST)->getElementOffset(i); 1926 else 1927 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) - 1928 DL.getStructLayout(ST)->getElementOffset(i); 1929 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer); 1930 } 1931 } 1932 return; 1933 } 1934 llvm_unreachable("unsupported constant type in printAggregateConstant()"); 1935 } 1936 1937 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly 1938 /// a copy from AsmPrinter::lowerConstant, except customized to only handle 1939 /// expressions that are representable in PTX and create 1940 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions. 1941 const MCExpr * 1942 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) { 1943 MCContext &Ctx = OutContext; 1944 1945 if (CV->isNullValue() || isa<UndefValue>(CV)) 1946 return MCConstantExpr::create(0, Ctx); 1947 1948 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) 1949 return MCConstantExpr::create(CI->getZExtValue(), Ctx); 1950 1951 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) { 1952 const MCSymbolRefExpr *Expr = 1953 MCSymbolRefExpr::create(getSymbol(GV), Ctx); 1954 if (ProcessingGeneric) { 1955 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx); 1956 } else { 1957 return Expr; 1958 } 1959 } 1960 1961 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV); 1962 if (!CE) { 1963 llvm_unreachable("Unknown constant value to lower!"); 1964 } 1965 1966 switch (CE->getOpcode()) { 1967 default: { 1968 // If the code isn't optimized, there may be outstanding folding 1969 // opportunities. Attempt to fold the expression using DataLayout as a 1970 // last resort before giving up. 1971 Constant *C = ConstantFoldConstant(CE, getDataLayout()); 1972 if (C != CE) 1973 return lowerConstantForGV(C, ProcessingGeneric); 1974 1975 // Otherwise report the problem to the user. 1976 std::string S; 1977 raw_string_ostream OS(S); 1978 OS << "Unsupported expression in static initializer: "; 1979 CE->printAsOperand(OS, /*PrintType=*/false, 1980 !MF ? nullptr : MF->getFunction().getParent()); 1981 report_fatal_error(Twine(OS.str())); 1982 } 1983 1984 case Instruction::AddrSpaceCast: { 1985 // Strip the addrspacecast and pass along the operand 1986 PointerType *DstTy = cast<PointerType>(CE->getType()); 1987 if (DstTy->getAddressSpace() == 0) { 1988 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true); 1989 } 1990 std::string S; 1991 raw_string_ostream OS(S); 1992 OS << "Unsupported expression in static initializer: "; 1993 CE->printAsOperand(OS, /*PrintType=*/ false, 1994 !MF ? nullptr : MF->getFunction().getParent()); 1995 report_fatal_error(Twine(OS.str())); 1996 } 1997 1998 case Instruction::GetElementPtr: { 1999 const DataLayout &DL = getDataLayout(); 2000 2001 // Generate a symbolic expression for the byte address 2002 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0); 2003 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI); 2004 2005 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0), 2006 ProcessingGeneric); 2007 if (!OffsetAI) 2008 return Base; 2009 2010 int64_t Offset = OffsetAI.getSExtValue(); 2011 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx), 2012 Ctx); 2013 } 2014 2015 case Instruction::Trunc: 2016 // We emit the value and depend on the assembler to truncate the generated 2017 // expression properly. This is important for differences between 2018 // blockaddress labels. Since the two labels are in the same function, it 2019 // is reasonable to treat their delta as a 32-bit value. 2020 LLVM_FALLTHROUGH; 2021 case Instruction::BitCast: 2022 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2023 2024 case Instruction::IntToPtr: { 2025 const DataLayout &DL = getDataLayout(); 2026 2027 // Handle casts to pointers by changing them into casts to the appropriate 2028 // integer type. This promotes constant folding and simplifies this code. 2029 Constant *Op = CE->getOperand(0); 2030 Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()), 2031 false/*ZExt*/); 2032 return lowerConstantForGV(Op, ProcessingGeneric); 2033 } 2034 2035 case Instruction::PtrToInt: { 2036 const DataLayout &DL = getDataLayout(); 2037 2038 // Support only foldable casts to/from pointers that can be eliminated by 2039 // changing the pointer to the appropriately sized integer type. 2040 Constant *Op = CE->getOperand(0); 2041 Type *Ty = CE->getType(); 2042 2043 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric); 2044 2045 // We can emit the pointer value into this slot if the slot is an 2046 // integer slot equal to the size of the pointer. 2047 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType())) 2048 return OpExpr; 2049 2050 // Otherwise the pointer is smaller than the resultant integer, mask off 2051 // the high bits so we are sure to get a proper truncation if the input is 2052 // a constant expr. 2053 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType()); 2054 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx); 2055 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx); 2056 } 2057 2058 // The MC library also has a right-shift operator, but it isn't consistently 2059 // signed or unsigned between different targets. 2060 case Instruction::Add: { 2061 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2062 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric); 2063 switch (CE->getOpcode()) { 2064 default: llvm_unreachable("Unknown binary operator constant cast expr"); 2065 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx); 2066 } 2067 } 2068 } 2069 } 2070 2071 // Copy of MCExpr::print customized for NVPTX 2072 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) { 2073 switch (Expr.getKind()) { 2074 case MCExpr::Target: 2075 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI); 2076 case MCExpr::Constant: 2077 OS << cast<MCConstantExpr>(Expr).getValue(); 2078 return; 2079 2080 case MCExpr::SymbolRef: { 2081 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr); 2082 const MCSymbol &Sym = SRE.getSymbol(); 2083 Sym.print(OS, MAI); 2084 return; 2085 } 2086 2087 case MCExpr::Unary: { 2088 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr); 2089 switch (UE.getOpcode()) { 2090 case MCUnaryExpr::LNot: OS << '!'; break; 2091 case MCUnaryExpr::Minus: OS << '-'; break; 2092 case MCUnaryExpr::Not: OS << '~'; break; 2093 case MCUnaryExpr::Plus: OS << '+'; break; 2094 } 2095 printMCExpr(*UE.getSubExpr(), OS); 2096 return; 2097 } 2098 2099 case MCExpr::Binary: { 2100 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr); 2101 2102 // Only print parens around the LHS if it is non-trivial. 2103 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) || 2104 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) { 2105 printMCExpr(*BE.getLHS(), OS); 2106 } else { 2107 OS << '('; 2108 printMCExpr(*BE.getLHS(), OS); 2109 OS<< ')'; 2110 } 2111 2112 switch (BE.getOpcode()) { 2113 case MCBinaryExpr::Add: 2114 // Print "X-42" instead of "X+-42". 2115 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) { 2116 if (RHSC->getValue() < 0) { 2117 OS << RHSC->getValue(); 2118 return; 2119 } 2120 } 2121 2122 OS << '+'; 2123 break; 2124 default: llvm_unreachable("Unhandled binary operator"); 2125 } 2126 2127 // Only print parens around the LHS if it is non-trivial. 2128 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) { 2129 printMCExpr(*BE.getRHS(), OS); 2130 } else { 2131 OS << '('; 2132 printMCExpr(*BE.getRHS(), OS); 2133 OS << ')'; 2134 } 2135 return; 2136 } 2137 } 2138 2139 llvm_unreachable("Invalid expression kind!"); 2140 } 2141 2142 /// PrintAsmOperand - Print out an operand for an inline asm expression. 2143 /// 2144 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 2145 const char *ExtraCode, raw_ostream &O) { 2146 if (ExtraCode && ExtraCode[0]) { 2147 if (ExtraCode[1] != 0) 2148 return true; // Unknown modifier. 2149 2150 switch (ExtraCode[0]) { 2151 default: 2152 // See if this is a generic print operand 2153 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O); 2154 case 'r': 2155 break; 2156 } 2157 } 2158 2159 printOperand(MI, OpNo, O); 2160 2161 return false; 2162 } 2163 2164 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI, 2165 unsigned OpNo, 2166 const char *ExtraCode, 2167 raw_ostream &O) { 2168 if (ExtraCode && ExtraCode[0]) 2169 return true; // Unknown modifier 2170 2171 O << '['; 2172 printMemOperand(MI, OpNo, O); 2173 O << ']'; 2174 2175 return false; 2176 } 2177 2178 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, 2179 raw_ostream &O) { 2180 const MachineOperand &MO = MI->getOperand(opNum); 2181 switch (MO.getType()) { 2182 case MachineOperand::MO_Register: 2183 if (Register::isPhysicalRegister(MO.getReg())) { 2184 if (MO.getReg() == NVPTX::VRDepot) 2185 O << DEPOTNAME << getFunctionNumber(); 2186 else 2187 O << NVPTXInstPrinter::getRegisterName(MO.getReg()); 2188 } else { 2189 emitVirtualRegister(MO.getReg(), O); 2190 } 2191 break; 2192 2193 case MachineOperand::MO_Immediate: 2194 O << MO.getImm(); 2195 break; 2196 2197 case MachineOperand::MO_FPImmediate: 2198 printFPConstant(MO.getFPImm(), O); 2199 break; 2200 2201 case MachineOperand::MO_GlobalAddress: 2202 PrintSymbolOperand(MO, O); 2203 break; 2204 2205 case MachineOperand::MO_MachineBasicBlock: 2206 MO.getMBB()->getSymbol()->print(O, MAI); 2207 break; 2208 2209 default: 2210 llvm_unreachable("Operand type not supported."); 2211 } 2212 } 2213 2214 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum, 2215 raw_ostream &O, const char *Modifier) { 2216 printOperand(MI, opNum, O); 2217 2218 if (Modifier && strcmp(Modifier, "add") == 0) { 2219 O << ", "; 2220 printOperand(MI, opNum + 1, O); 2221 } else { 2222 if (MI->getOperand(opNum + 1).isImm() && 2223 MI->getOperand(opNum + 1).getImm() == 0) 2224 return; // don't print ',0' or '+0' 2225 O << "+"; 2226 printOperand(MI, opNum + 1, O); 2227 } 2228 } 2229 2230 // Force static initialization. 2231 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() { 2232 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32()); 2233 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64()); 2234 } 2235