1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "NVPTXAsmPrinter.h" 15 #include "MCTargetDesc/NVPTXBaseInfo.h" 16 #include "MCTargetDesc/NVPTXInstPrinter.h" 17 #include "MCTargetDesc/NVPTXMCAsmInfo.h" 18 #include "MCTargetDesc/NVPTXTargetStreamer.h" 19 #include "NVPTX.h" 20 #include "NVPTXMCExpr.h" 21 #include "NVPTXMachineFunctionInfo.h" 22 #include "NVPTXRegisterInfo.h" 23 #include "NVPTXSubtarget.h" 24 #include "NVPTXTargetMachine.h" 25 #include "NVPTXUtilities.h" 26 #include "TargetInfo/NVPTXTargetInfo.h" 27 #include "cl_common_defines.h" 28 #include "llvm/ADT/APFloat.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/DenseSet.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/StringExtras.h" 35 #include "llvm/ADT/StringRef.h" 36 #include "llvm/ADT/Triple.h" 37 #include "llvm/ADT/Twine.h" 38 #include "llvm/Analysis/ConstantFolding.h" 39 #include "llvm/CodeGen/Analysis.h" 40 #include "llvm/CodeGen/MachineBasicBlock.h" 41 #include "llvm/CodeGen/MachineFrameInfo.h" 42 #include "llvm/CodeGen/MachineFunction.h" 43 #include "llvm/CodeGen/MachineInstr.h" 44 #include "llvm/CodeGen/MachineLoopInfo.h" 45 #include "llvm/CodeGen/MachineModuleInfo.h" 46 #include "llvm/CodeGen/MachineOperand.h" 47 #include "llvm/CodeGen/MachineRegisterInfo.h" 48 #include "llvm/CodeGen/TargetLowering.h" 49 #include 
"llvm/CodeGen/TargetRegisterInfo.h" 50 #include "llvm/CodeGen/ValueTypes.h" 51 #include "llvm/IR/Attributes.h" 52 #include "llvm/IR/BasicBlock.h" 53 #include "llvm/IR/Constant.h" 54 #include "llvm/IR/Constants.h" 55 #include "llvm/IR/DataLayout.h" 56 #include "llvm/IR/DebugInfo.h" 57 #include "llvm/IR/DebugInfoMetadata.h" 58 #include "llvm/IR/DebugLoc.h" 59 #include "llvm/IR/DerivedTypes.h" 60 #include "llvm/IR/Function.h" 61 #include "llvm/IR/GlobalValue.h" 62 #include "llvm/IR/GlobalVariable.h" 63 #include "llvm/IR/Instruction.h" 64 #include "llvm/IR/LLVMContext.h" 65 #include "llvm/IR/Module.h" 66 #include "llvm/IR/Operator.h" 67 #include "llvm/IR/Type.h" 68 #include "llvm/IR/User.h" 69 #include "llvm/MC/MCExpr.h" 70 #include "llvm/MC/MCInst.h" 71 #include "llvm/MC/MCInstrDesc.h" 72 #include "llvm/MC/MCStreamer.h" 73 #include "llvm/MC/MCSymbol.h" 74 #include "llvm/MC/TargetRegistry.h" 75 #include "llvm/Support/Casting.h" 76 #include "llvm/Support/CommandLine.h" 77 #include "llvm/Support/ErrorHandling.h" 78 #include "llvm/Support/MachineValueType.h" 79 #include "llvm/Support/Path.h" 80 #include "llvm/Support/raw_ostream.h" 81 #include "llvm/Target/TargetLoweringObjectFile.h" 82 #include "llvm/Target/TargetMachine.h" 83 #include "llvm/Transforms/Utils/UnrollLoop.h" 84 #include <cassert> 85 #include <cstdint> 86 #include <cstring> 87 #include <new> 88 #include <string> 89 #include <utility> 90 #include <vector> 91 92 using namespace llvm; 93 94 #define DEPOTNAME "__local_depot" 95 96 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V 97 /// depends. 
98 static void 99 DiscoverDependentGlobals(const Value *V, 100 DenseSet<const GlobalVariable *> &Globals) { 101 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 102 Globals.insert(GV); 103 else { 104 if (const User *U = dyn_cast<User>(V)) { 105 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) { 106 DiscoverDependentGlobals(U->getOperand(i), Globals); 107 } 108 } 109 } 110 } 111 112 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable 113 /// instances to be emitted, but only after any dependents have been added 114 /// first.s 115 static void 116 VisitGlobalVariableForEmission(const GlobalVariable *GV, 117 SmallVectorImpl<const GlobalVariable *> &Order, 118 DenseSet<const GlobalVariable *> &Visited, 119 DenseSet<const GlobalVariable *> &Visiting) { 120 // Have we already visited this one? 121 if (Visited.count(GV)) 122 return; 123 124 // Do we have a circular dependency? 125 if (!Visiting.insert(GV).second) 126 report_fatal_error("Circular dependency found in global variable set"); 127 128 // Make sure we visit all dependents first 129 DenseSet<const GlobalVariable *> Others; 130 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i) 131 DiscoverDependentGlobals(GV->getOperand(i), Others); 132 133 for (DenseSet<const GlobalVariable *>::iterator I = Others.begin(), 134 E = Others.end(); 135 I != E; ++I) 136 VisitGlobalVariableForEmission(*I, Order, Visited, Visiting); 137 138 // Now we can visit ourself 139 Order.push_back(GV); 140 Visited.insert(GV); 141 Visiting.erase(GV); 142 } 143 144 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) { 145 MCInst Inst; 146 lowerToMCInst(MI, Inst); 147 EmitToStreamer(*OutStreamer, Inst); 148 } 149 150 // Handle symbol backtracking for targets that do not support image handles 151 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI, 152 unsigned OpNo, MCOperand &MCOp) { 153 const MachineOperand &MO = MI->getOperand(OpNo); 154 const MCInstrDesc &MCID = 
MI->getDesc(); 155 156 if (MCID.TSFlags & NVPTXII::IsTexFlag) { 157 // This is a texture fetch, so operand 4 is a texref and operand 5 is 158 // a samplerref 159 if (OpNo == 4 && MO.isImm()) { 160 lowerImageHandleSymbol(MO.getImm(), MCOp); 161 return true; 162 } 163 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) { 164 lowerImageHandleSymbol(MO.getImm(), MCOp); 165 return true; 166 } 167 168 return false; 169 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { 170 unsigned VecSize = 171 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); 172 173 // For a surface load of vector size N, the Nth operand will be the surfref 174 if (OpNo == VecSize && MO.isImm()) { 175 lowerImageHandleSymbol(MO.getImm(), MCOp); 176 return true; 177 } 178 179 return false; 180 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { 181 // This is a surface store, so operand 0 is a surfref 182 if (OpNo == 0 && MO.isImm()) { 183 lowerImageHandleSymbol(MO.getImm(), MCOp); 184 return true; 185 } 186 187 return false; 188 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { 189 // This is a query, so operand 1 is a surfref/texref 190 if (OpNo == 1 && MO.isImm()) { 191 lowerImageHandleSymbol(MO.getImm(), MCOp); 192 return true; 193 } 194 195 return false; 196 } 197 198 return false; 199 } 200 201 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) { 202 // Ewwww 203 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget()); 204 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM); 205 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>(); 206 const char *Sym = MFI->getImageHandleSymbol(Index); 207 std::string *SymNamePtr = 208 nvTM.getManagedStrPool()->getManagedString(Sym); 209 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr))); 210 } 211 212 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { 213 
OutMI.setOpcode(MI->getOpcode()); 214 // Special: Do not mangle symbol operand of CALL_PROTOTYPE 215 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) { 216 const MachineOperand &MO = MI->getOperand(0); 217 OutMI.addOperand(GetSymbolRef( 218 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName())))); 219 return; 220 } 221 222 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 223 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { 224 const MachineOperand &MO = MI->getOperand(i); 225 226 MCOperand MCOp; 227 if (!STI.hasImageHandles()) { 228 if (lowerImageHandleOperand(MI, i, MCOp)) { 229 OutMI.addOperand(MCOp); 230 continue; 231 } 232 } 233 234 if (lowerOperand(MO, MCOp)) 235 OutMI.addOperand(MCOp); 236 } 237 } 238 239 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, 240 MCOperand &MCOp) { 241 switch (MO.getType()) { 242 default: llvm_unreachable("unknown operand type"); 243 case MachineOperand::MO_Register: 244 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg())); 245 break; 246 case MachineOperand::MO_Immediate: 247 MCOp = MCOperand::createImm(MO.getImm()); 248 break; 249 case MachineOperand::MO_MachineBasicBlock: 250 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create( 251 MO.getMBB()->getSymbol(), OutContext)); 252 break; 253 case MachineOperand::MO_ExternalSymbol: 254 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName())); 255 break; 256 case MachineOperand::MO_GlobalAddress: 257 MCOp = GetSymbolRef(getSymbol(MO.getGlobal())); 258 break; 259 case MachineOperand::MO_FPImmediate: { 260 const ConstantFP *Cnt = MO.getFPImm(); 261 const APFloat &Val = Cnt->getValueAPF(); 262 263 switch (Cnt->getType()->getTypeID()) { 264 default: report_fatal_error("Unsupported FP type"); break; 265 case Type::HalfTyID: 266 MCOp = MCOperand::createExpr( 267 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); 268 break; 269 case Type::FloatTyID: 270 MCOp = MCOperand::createExpr( 271 
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); 272 break; 273 case Type::DoubleTyID: 274 MCOp = MCOperand::createExpr( 275 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext)); 276 break; 277 } 278 break; 279 } 280 } 281 return true; 282 } 283 284 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { 285 if (Register::isVirtualRegister(Reg)) { 286 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 287 288 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC]; 289 unsigned RegNum = RegMap[Reg]; 290 291 // Encode the register class in the upper 4 bits 292 // Must be kept in sync with NVPTXInstPrinter::printRegName 293 unsigned Ret = 0; 294 if (RC == &NVPTX::Int1RegsRegClass) { 295 Ret = (1 << 28); 296 } else if (RC == &NVPTX::Int16RegsRegClass) { 297 Ret = (2 << 28); 298 } else if (RC == &NVPTX::Int32RegsRegClass) { 299 Ret = (3 << 28); 300 } else if (RC == &NVPTX::Int64RegsRegClass) { 301 Ret = (4 << 28); 302 } else if (RC == &NVPTX::Float32RegsRegClass) { 303 Ret = (5 << 28); 304 } else if (RC == &NVPTX::Float64RegsRegClass) { 305 Ret = (6 << 28); 306 } else if (RC == &NVPTX::Float16RegsRegClass) { 307 Ret = (7 << 28); 308 } else if (RC == &NVPTX::Float16x2RegsRegClass) { 309 Ret = (8 << 28); 310 } else { 311 report_fatal_error("Bad register class"); 312 } 313 314 // Insert the vreg number 315 Ret |= (RegNum & 0x0FFFFFFF); 316 return Ret; 317 } else { 318 // Some special-use registers are actually physical registers. 319 // Encode this as the register class ID of 0 and the real register ID. 
320 return Reg & 0x0FFFFFFF; 321 } 322 } 323 324 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) { 325 const MCExpr *Expr; 326 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None, 327 OutContext); 328 return MCOperand::createExpr(Expr); 329 } 330 331 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { 332 const DataLayout &DL = getDataLayout(); 333 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F); 334 const TargetLowering *TLI = STI.getTargetLowering(); 335 336 Type *Ty = F->getReturnType(); 337 338 bool isABI = (STI.getSmVersion() >= 20); 339 340 if (Ty->getTypeID() == Type::VoidTyID) 341 return; 342 343 O << " ("; 344 345 if (isABI) { 346 if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) { 347 unsigned size = 0; 348 if (auto *ITy = dyn_cast<IntegerType>(Ty)) { 349 size = ITy->getBitWidth(); 350 } else { 351 assert(Ty->isFloatingPointTy() && "Floating point type expected here"); 352 size = Ty->getPrimitiveSizeInBits(); 353 } 354 // PTX ABI requires all scalar return values to be at least 32 355 // bits in size. fp16 normally uses .b16 as its storage type in 356 // PTX, so its size must be adjusted here, too. 
357 if (size < 32) 358 size = 32; 359 360 O << ".param .b" << size << " func_retval0"; 361 } else if (isa<PointerType>(Ty)) { 362 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits() 363 << " func_retval0"; 364 } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) { 365 unsigned totalsz = DL.getTypeAllocSize(Ty); 366 unsigned retAlignment = 0; 367 if (!getAlign(*F, 0, retAlignment)) 368 retAlignment = DL.getABITypeAlignment(Ty); 369 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz 370 << "]"; 371 } else 372 llvm_unreachable("Unknown return type"); 373 } else { 374 SmallVector<EVT, 16> vtparts; 375 ComputeValueVTs(*TLI, DL, Ty, vtparts); 376 unsigned idx = 0; 377 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 378 unsigned elems = 1; 379 EVT elemtype = vtparts[i]; 380 if (vtparts[i].isVector()) { 381 elems = vtparts[i].getVectorNumElements(); 382 elemtype = vtparts[i].getVectorElementType(); 383 } 384 385 for (unsigned j = 0, je = elems; j != je; ++j) { 386 unsigned sz = elemtype.getSizeInBits(); 387 if (elemtype.isInteger() && (sz < 32)) 388 sz = 32; 389 O << ".reg .b" << sz << " func_retval" << idx; 390 if (j < je - 1) 391 O << ", "; 392 ++idx; 393 } 394 if (i < e - 1) 395 O << ", "; 396 } 397 } 398 O << ") "; 399 } 400 401 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF, 402 raw_ostream &O) { 403 const Function &F = MF.getFunction(); 404 printReturnValStr(&F, O); 405 } 406 407 // Return true if MBB is the header of a loop marked with 408 // llvm.loop.unroll.disable. 409 // TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll". 410 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll( 411 const MachineBasicBlock &MBB) const { 412 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>(); 413 // We insert .pragma "nounroll" only to the loop header. 414 if (!LI.isLoopHeader(&MBB)) 415 return false; 416 417 // llvm.loop.unroll.disable is marked on the back edges of a loop. 
Therefore, 418 // we iterate through each back edge of the loop with header MBB, and check 419 // whether its metadata contains llvm.loop.unroll.disable. 420 for (const MachineBasicBlock *PMBB : MBB.predecessors()) { 421 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) { 422 // Edges from other loops to MBB are not back edges. 423 continue; 424 } 425 if (const BasicBlock *PBB = PMBB->getBasicBlock()) { 426 if (MDNode *LoopID = 427 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) { 428 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable")) 429 return true; 430 } 431 } 432 } 433 return false; 434 } 435 436 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { 437 AsmPrinter::emitBasicBlockStart(MBB); 438 if (isLoopHeaderOfNoUnroll(MBB)) 439 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n")); 440 } 441 442 void NVPTXAsmPrinter::emitFunctionEntryLabel() { 443 SmallString<128> Str; 444 raw_svector_ostream O(Str); 445 446 if (!GlobalsEmitted) { 447 emitGlobals(*MF->getFunction().getParent()); 448 GlobalsEmitted = true; 449 } 450 451 // Set up 452 MRI = &MF->getRegInfo(); 453 F = &MF->getFunction(); 454 emitLinkageDirective(F, O); 455 if (isKernelFunction(*F)) 456 O << ".entry "; 457 else { 458 O << ".func "; 459 printReturnValStr(*MF, O); 460 } 461 462 CurrentFnSym->print(O, MAI); 463 464 emitFunctionParamList(*MF, O); 465 466 if (isKernelFunction(*F)) 467 emitKernelFunctionDirectives(*F, O); 468 469 OutStreamer->emitRawText(O.str()); 470 471 VRegMapping.clear(); 472 // Emit open brace for function body. 473 OutStreamer->emitRawText(StringRef("{\n")); 474 setAndEmitFunctionVirtualRegisters(*MF); 475 // Emit initial .loc debug directive for correct relocation symbol data. 
476 if (MMI && MMI->hasDebugInfo()) 477 emitInitialRawDwarfLocDirective(*MF); 478 } 479 480 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) { 481 bool Result = AsmPrinter::runOnMachineFunction(F); 482 // Emit closing brace for the body of function F. 483 // The closing brace must be emitted here because we need to emit additional 484 // debug labels/data after the last basic block. 485 // We need to emit the closing brace here because we don't have function that 486 // finished emission of the function body. 487 OutStreamer->emitRawText(StringRef("}\n")); 488 return Result; 489 } 490 491 void NVPTXAsmPrinter::emitFunctionBodyStart() { 492 SmallString<128> Str; 493 raw_svector_ostream O(Str); 494 emitDemotedVars(&MF->getFunction(), O); 495 OutStreamer->emitRawText(O.str()); 496 } 497 498 void NVPTXAsmPrinter::emitFunctionBodyEnd() { 499 VRegMapping.clear(); 500 } 501 502 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const { 503 SmallString<128> Str; 504 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber(); 505 return OutContext.getOrCreateSymbol(Str); 506 } 507 508 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const { 509 Register RegNo = MI->getOperand(0).getReg(); 510 if (Register::isVirtualRegister(RegNo)) { 511 OutStreamer->AddComment(Twine("implicit-def: ") + 512 getVirtualRegisterName(RegNo)); 513 } else { 514 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>(); 515 OutStreamer->AddComment(Twine("implicit-def: ") + 516 STI.getRegisterInfo()->getName(RegNo)); 517 } 518 OutStreamer->AddBlankLine(); 519 } 520 521 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, 522 raw_ostream &O) const { 523 // If the NVVM IR has some of reqntid* specified, then output 524 // the reqntid directive, and set the unspecified ones to 1. 525 // If none of reqntid* is specified, don't output reqntid directive. 
526 unsigned reqntidx, reqntidy, reqntidz; 527 bool specified = false; 528 if (!getReqNTIDx(F, reqntidx)) 529 reqntidx = 1; 530 else 531 specified = true; 532 if (!getReqNTIDy(F, reqntidy)) 533 reqntidy = 1; 534 else 535 specified = true; 536 if (!getReqNTIDz(F, reqntidz)) 537 reqntidz = 1; 538 else 539 specified = true; 540 541 if (specified) 542 O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz 543 << "\n"; 544 545 // If the NVVM IR has some of maxntid* specified, then output 546 // the maxntid directive, and set the unspecified ones to 1. 547 // If none of maxntid* is specified, don't output maxntid directive. 548 unsigned maxntidx, maxntidy, maxntidz; 549 specified = false; 550 if (!getMaxNTIDx(F, maxntidx)) 551 maxntidx = 1; 552 else 553 specified = true; 554 if (!getMaxNTIDy(F, maxntidy)) 555 maxntidy = 1; 556 else 557 specified = true; 558 if (!getMaxNTIDz(F, maxntidz)) 559 maxntidz = 1; 560 else 561 specified = true; 562 563 if (specified) 564 O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz 565 << "\n"; 566 567 unsigned mincta; 568 if (getMinCTASm(F, mincta)) 569 O << ".minnctapersm " << mincta << "\n"; 570 571 unsigned maxnreg; 572 if (getMaxNReg(F, maxnreg)) 573 O << ".maxnreg " << maxnreg << "\n"; 574 } 575 576 std::string 577 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const { 578 const TargetRegisterClass *RC = MRI->getRegClass(Reg); 579 580 std::string Name; 581 raw_string_ostream NameStr(Name); 582 583 VRegRCMap::const_iterator I = VRegMapping.find(RC); 584 assert(I != VRegMapping.end() && "Bad register class"); 585 const DenseMap<unsigned, unsigned> &RegMap = I->second; 586 587 VRegMap::const_iterator VI = RegMap.find(Reg); 588 assert(VI != RegMap.end() && "Bad virtual register"); 589 unsigned MappedVR = VI->second; 590 591 NameStr << getNVPTXRegClassStr(RC) << MappedVR; 592 593 NameStr.flush(); 594 return Name; 595 } 596 597 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr, 598 raw_ostream 
&O) { 599 O << getVirtualRegisterName(vr); 600 } 601 602 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { 603 emitLinkageDirective(F, O); 604 if (isKernelFunction(*F)) 605 O << ".entry "; 606 else 607 O << ".func "; 608 printReturnValStr(F, O); 609 getSymbol(F)->print(O, MAI); 610 O << "\n"; 611 emitFunctionParamList(F, O); 612 O << ";\n"; 613 } 614 615 static bool usedInGlobalVarDef(const Constant *C) { 616 if (!C) 617 return false; 618 619 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) { 620 return GV->getName() != "llvm.used"; 621 } 622 623 for (const User *U : C->users()) 624 if (const Constant *C = dyn_cast<Constant>(U)) 625 if (usedInGlobalVarDef(C)) 626 return true; 627 628 return false; 629 } 630 631 static bool usedInOneFunc(const User *U, Function const *&oneFunc) { 632 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) { 633 if (othergv->getName() == "llvm.used") 634 return true; 635 } 636 637 if (const Instruction *instr = dyn_cast<Instruction>(U)) { 638 if (instr->getParent() && instr->getParent()->getParent()) { 639 const Function *curFunc = instr->getParent()->getParent(); 640 if (oneFunc && (curFunc != oneFunc)) 641 return false; 642 oneFunc = curFunc; 643 return true; 644 } else 645 return false; 646 } 647 648 for (const User *UU : U->users()) 649 if (!usedInOneFunc(UU, oneFunc)) 650 return false; 651 652 return true; 653 } 654 655 /* Find out if a global variable can be demoted to local scope. 656 * Currently, this is valid for CUDA shared variables, which have local 657 * scope and global lifetime. So the conditions to check are : 658 * 1. Is the global variable in shared address space? 659 * 2. Does it have internal linkage? 660 * 3. Is the global variable referenced only in one function? 
661 */ 662 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) { 663 if (!gv->hasInternalLinkage()) 664 return false; 665 PointerType *Pty = gv->getType(); 666 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED) 667 return false; 668 669 const Function *oneFunc = nullptr; 670 671 bool flag = usedInOneFunc(gv, oneFunc); 672 if (!flag) 673 return false; 674 if (!oneFunc) 675 return false; 676 f = oneFunc; 677 return true; 678 } 679 680 static bool useFuncSeen(const Constant *C, 681 DenseMap<const Function *, bool> &seenMap) { 682 for (const User *U : C->users()) { 683 if (const Constant *cu = dyn_cast<Constant>(U)) { 684 if (useFuncSeen(cu, seenMap)) 685 return true; 686 } else if (const Instruction *I = dyn_cast<Instruction>(U)) { 687 const BasicBlock *bb = I->getParent(); 688 if (!bb) 689 continue; 690 const Function *caller = bb->getParent(); 691 if (!caller) 692 continue; 693 if (seenMap.find(caller) != seenMap.end()) 694 return true; 695 } 696 } 697 return false; 698 } 699 700 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { 701 DenseMap<const Function *, bool> seenMap; 702 for (Module::const_iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { 703 const Function *F = &*FI; 704 705 if (F->getAttributes().hasFnAttr("nvptx-libcall-callee")) { 706 emitDeclaration(F, O); 707 continue; 708 } 709 710 if (F->isDeclaration()) { 711 if (F->use_empty()) 712 continue; 713 if (F->getIntrinsicID()) 714 continue; 715 emitDeclaration(F, O); 716 continue; 717 } 718 for (const User *U : F->users()) { 719 if (const Constant *C = dyn_cast<Constant>(U)) { 720 if (usedInGlobalVarDef(C)) { 721 // The use is in the initialization of a global variable 722 // that is a function pointer, so print a declaration 723 // for the original function 724 emitDeclaration(F, O); 725 break; 726 } 727 // Emit a declaration of this function if the function that 728 // uses this constant expr has already been seen. 
729 if (useFuncSeen(C, seenMap)) { 730 emitDeclaration(F, O); 731 break; 732 } 733 } 734 735 if (!isa<Instruction>(U)) 736 continue; 737 const Instruction *instr = cast<Instruction>(U); 738 const BasicBlock *bb = instr->getParent(); 739 if (!bb) 740 continue; 741 const Function *caller = bb->getParent(); 742 if (!caller) 743 continue; 744 745 // If a caller has already been seen, then the caller is 746 // appearing in the module before the callee. so print out 747 // a declaration for the callee. 748 if (seenMap.find(caller) != seenMap.end()) { 749 emitDeclaration(F, O); 750 break; 751 } 752 } 753 seenMap[F] = true; 754 } 755 } 756 757 static bool isEmptyXXStructor(GlobalVariable *GV) { 758 if (!GV) return true; 759 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer()); 760 if (!InitList) return true; // Not an array; we don't know how to parse. 761 return InitList->getNumOperands() == 0; 762 } 763 764 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) { 765 // Construct a default subtarget off of the TargetMachine defaults. The 766 // rest of NVPTX isn't friendly to change subtargets per function and 767 // so the default TargetMachine will have all of the options. 768 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 769 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl()); 770 SmallString<128> Str1; 771 raw_svector_ostream OS1(Str1); 772 773 // Emit header before any dwarf directives are emitted below. 
774 emitHeader(M, OS1, *STI); 775 OutStreamer->emitRawText(OS1.str()); 776 } 777 778 bool NVPTXAsmPrinter::doInitialization(Module &M) { 779 if (M.alias_size()) { 780 report_fatal_error("Module has aliases, which NVPTX does not support."); 781 return true; // error 782 } 783 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) { 784 report_fatal_error( 785 "Module has a nontrivial global ctor, which NVPTX does not support."); 786 return true; // error 787 } 788 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) { 789 report_fatal_error( 790 "Module has a nontrivial global dtor, which NVPTX does not support."); 791 return true; // error 792 } 793 794 // We need to call the parent's one explicitly. 795 bool Result = AsmPrinter::doInitialization(M); 796 797 GlobalsEmitted = false; 798 799 return Result; 800 } 801 802 void NVPTXAsmPrinter::emitGlobals(const Module &M) { 803 SmallString<128> Str2; 804 raw_svector_ostream OS2(Str2); 805 806 emitDeclarations(M, OS2); 807 808 // As ptxas does not support forward references of globals, we need to first 809 // sort the list of module-level globals in def-use order. We visit each 810 // global variable in order, and ensure that we emit it *after* its dependent 811 // globals. We use a little extra memory maintaining both a set and a list to 812 // have fast searches while maintaining a strict ordering. 
813 SmallVector<const GlobalVariable *, 8> Globals; 814 DenseSet<const GlobalVariable *> GVVisited; 815 DenseSet<const GlobalVariable *> GVVisiting; 816 817 // Visit each global variable, in order 818 for (const GlobalVariable &I : M.globals()) 819 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting); 820 821 assert(GVVisited.size() == M.getGlobalList().size() && 822 "Missed a global variable"); 823 assert(GVVisiting.size() == 0 && "Did not fully process a global variable"); 824 825 // Print out module-level global variables in proper order 826 for (unsigned i = 0, e = Globals.size(); i != e; ++i) 827 printModuleLevelGV(Globals[i], OS2); 828 829 OS2 << '\n'; 830 831 OutStreamer->emitRawText(OS2.str()); 832 } 833 834 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O, 835 const NVPTXSubtarget &STI) { 836 O << "//\n"; 837 O << "// Generated by LLVM NVPTX Back-End\n"; 838 O << "//\n"; 839 O << "\n"; 840 841 unsigned PTXVersion = STI.getPTXVersion(); 842 O << ".version " << (PTXVersion / 10) << "." 
<< (PTXVersion % 10) << "\n"; 843 844 O << ".target "; 845 O << STI.getTargetName(); 846 847 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM); 848 if (NTM.getDrvInterface() == NVPTX::NVCL) 849 O << ", texmode_independent"; 850 851 bool HasFullDebugInfo = false; 852 for (DICompileUnit *CU : M.debug_compile_units()) { 853 switch(CU->getEmissionKind()) { 854 case DICompileUnit::NoDebug: 855 case DICompileUnit::DebugDirectivesOnly: 856 break; 857 case DICompileUnit::LineTablesOnly: 858 case DICompileUnit::FullDebug: 859 HasFullDebugInfo = true; 860 break; 861 } 862 if (HasFullDebugInfo) 863 break; 864 } 865 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo) 866 O << ", debug"; 867 868 O << "\n"; 869 870 O << ".address_size "; 871 if (NTM.is64Bit()) 872 O << "64"; 873 else 874 O << "32"; 875 O << "\n"; 876 877 O << "\n"; 878 } 879 880 bool NVPTXAsmPrinter::doFinalization(Module &M) { 881 bool HasDebugInfo = MMI && MMI->hasDebugInfo(); 882 883 // If we did not emit any functions, then the global declarations have not 884 // yet been emitted. 885 if (!GlobalsEmitted) { 886 emitGlobals(M); 887 GlobalsEmitted = true; 888 } 889 890 // XXX Temproarily remove global variables so that doFinalization() will not 891 // emit them again (global variables are emitted at beginning). 
892 893 Module::GlobalListType &global_list = M.getGlobalList(); 894 int i, n = global_list.size(); 895 GlobalVariable **gv_array = new GlobalVariable *[n]; 896 897 // first, back-up GlobalVariable in gv_array 898 i = 0; 899 for (Module::global_iterator I = global_list.begin(), E = global_list.end(); 900 I != E; ++I) 901 gv_array[i++] = &*I; 902 903 // second, empty global_list 904 while (!global_list.empty()) 905 global_list.remove(global_list.begin()); 906 907 // call doFinalization 908 bool ret = AsmPrinter::doFinalization(M); 909 910 // now we restore global variables 911 for (i = 0; i < n; i++) 912 global_list.insert(global_list.end(), gv_array[i]); 913 914 clearAnnotationCache(&M); 915 916 delete[] gv_array; 917 // Close the last emitted section 918 if (HasDebugInfo) { 919 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer()) 920 ->closeLastSection(); 921 // Emit empty .debug_loc section for better support of the empty files. 922 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}"); 923 } 924 925 // Output last DWARF .file directives, if any. 926 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer()) 927 ->outputDwarfFileDirectives(); 928 929 return ret; 930 931 //bool Result = AsmPrinter::doFinalization(M); 932 // Instead of calling the parents doFinalization, we may 933 // clone parents doFinalization and customize here. 934 // Currently, we if NVISA out the EmitGlobals() in 935 // parent's doFinalization, which is too intrusive. 936 // 937 // Same for the doInitialization. 938 //return Result; 939 } 940 941 // This function emits appropriate linkage directives for 942 // functions and global variables. 943 // 944 // extern function declaration -> .extern 945 // extern function definition -> .visible 946 // external global variable with init -> .visible 947 // external without init -> .extern 948 // appending -> not allowed, assert. 
949 // for any linkage other than 950 // internal, private, linker_private, 951 // linker_private_weak, linker_private_weak_def_auto, 952 // we emit -> .weak. 953 954 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, 955 raw_ostream &O) { 956 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) { 957 if (V->hasExternalLinkage()) { 958 if (isa<GlobalVariable>(V)) { 959 const GlobalVariable *GVar = cast<GlobalVariable>(V); 960 if (GVar) { 961 if (GVar->hasInitializer()) 962 O << ".visible "; 963 else 964 O << ".extern "; 965 } 966 } else if (V->isDeclaration()) 967 O << ".extern "; 968 else 969 O << ".visible "; 970 } else if (V->hasAppendingLinkage()) { 971 std::string msg; 972 msg.append("Error: "); 973 msg.append("Symbol "); 974 if (V->hasName()) 975 msg.append(std::string(V->getName())); 976 msg.append("has unsupported appending linkage type"); 977 llvm_unreachable(msg.c_str()); 978 } else if (!V->hasInternalLinkage() && 979 !V->hasPrivateLinkage()) { 980 O << ".weak "; 981 } 982 } 983 } 984 985 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, 986 raw_ostream &O, 987 bool processDemoted) { 988 // Skip meta data 989 if (GVar->hasSection()) { 990 if (GVar->getSection() == "llvm.metadata") 991 return; 992 } 993 994 // Skip LLVM intrinsic global variables 995 if (GVar->getName().startswith("llvm.") || 996 GVar->getName().startswith("nvvm.")) 997 return; 998 999 const DataLayout &DL = getDataLayout(); 1000 1001 // GlobalVariables are always constant pointers themselves. 
  PointerType *PTy = GVar->getType();
  Type *ETy = GVar->getValueType();

  // Linkage directive for the variable itself (CUDA-style; cf.
  // emitLinkageDirective for the function/variable declaration path).
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  // Textures and surfaces are emitted as opaque references, no storage.
  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      // Decode the packed OpenCL sampler initializer (see
      // cl_common_defines.h for the __CLK_* masks used here).
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      // Note: addr is decoded once and the same value is printed for all
      // three addr_mode_<i> fields.
      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  if (GVar->hasPrivateLinkage()) {
    // Drop frontend-generated helper globals that must not reach PTX.
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  // If the global can be demoted into a single function, defer it: record it
  // in localDecls and let emitDemotedVars print it inside that function.
  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    if (localDecls.find(demotedFunc) != localDecls.end())
      localDecls[demotedFunc].push_back(GVar);
    else {
      std::vector<const GlobalVariable *> temp;
      temp.push_back(GVar);
      localDecls[demotedFunc] = temp;
    }
    return;
  }

  O << ".";
  emitPTXAddressSpace(PTy->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    O << " .attribute(.managed)";
  }

  // Explicit alignment if present, otherwise the preferred alignment of the
  // value type.
  if (GVar->getAlignment() == 0)
    O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
  else
    O << " .align " << GVar->getAlignment();

  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    // Scalar case: emit a fundamental PTX type plus optional initializer.
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // Ptx allows variable initilization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(PTy->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    unsigned int ElementSize = 0;

    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // Ptx allows variable initilization only for constant and
      // global state spaces.
      if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          AggBuffer aggBuffer(ElementSize, O, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols) {
            // Initializer contains symbol addresses: emit as an array of
            // pointer-sized words instead of bytes.
            if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit()) {
              O << " .u64 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[";
              O << ElementSize / 8;
            } else {
              O << " .u32 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[";
              O << ElementSize / 4;
            }
            O << "]";
          } else {
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[";
            O << ElementSize;
            O << "]";
          }
          O << " = {";
          aggBuffer.print();
          O << "}";
        } else {
          // Zero/undef initializer: size-only byte array.
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}

// Emit, inside function \p f, the globals previously demoted to it by
// printModuleLevelGV (recorded in localDecls).
void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
  if (localDecls.find(f) == localDecls.end())
    return;

  std::vector<const GlobalVariable *> &gvars = localDecls[f];

  for (unsigned i = 0, e = gvars.size(); i != e; ++i) {
    O << "\t// demoted variable\n\t";
    // processDemoted=true so the variable is emitted rather than re-deferred.
    printModuleLevelGV(gvars[i], O, true);
  }
}

// Print the PTX state-space keyword for an LLVM address space; aborts on
// address spaces with no PTX equivalent.
void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
                                          raw_ostream &O) const {
  switch (AddressSpace) {
  case ADDRESS_SPACE_LOCAL:
    O << "local";
    break;
  case ADDRESS_SPACE_GLOBAL:
    O << "global";
    break;
  case
  ADDRESS_SPACE_CONST:
    O << "const";
    break;
  case ADDRESS_SPACE_SHARED:
    O << "shared";
    break;
  default:
    report_fatal_error("Bad address space found while emitting PTX: " +
                       llvm::Twine(AddressSpace));
    break;
  }
}

// Map a scalar LLVM type to its PTX fundamental type name (without the
// leading '.'). \p useB4PTR selects .bNN instead of .uNN for pointers.
std::string
NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
  switch (Ty->getTypeID()) {
  case Type::IntegerTyID: {
    unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
    if (NumBits == 1)
      return "pred";
    else if (NumBits <= 64) {
      std::string name = "u";
      return name + utostr(NumBits);
    } else {
      llvm_unreachable("Integer too large");
      break;
    }
    break;
  }
  case Type::HalfTyID:
    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
    return "b16";
  case Type::FloatTyID:
    return "f32";
  case Type::DoubleTyID:
    return "f64";
  case Type::PointerTyID:
    // Pointer width follows the target's 32/64-bit mode.
    if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit())
      if (useB4PTR)
        return "b64";
      else
        return "u64";
    else if (useB4PTR)
      return "b32";
    else
      return "u32";
  default:
    break;
  }
  llvm_unreachable("unexpected type");
}

// Emit the declaration part of a global variable (state space, alignment,
// type and name) without any initializer and without the trailing ';'.
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
  if (GVar->getAlignment() == 0)
    O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
  else
    O << " .align " << GVar->getAlignment();

  // Special case for i128
  if (ETy->isIntegerTy(128)) {
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[16]";
    return;
  }

  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " .";
    O << getPTXFundamentalTypeStr(ETy);
    O << " ";
    getSymbol(GVar)->print(O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct type and array type and LLVM IR
  // is very similar to PTX, the LLVM CodeGen does not support for targets that
  // support these high level field accesses. Structs and arrays are lowered
  // into arrays of bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(ETy);
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}

// Compute the alignment used for OpenCL kernel pointer parameters: scalar
// types use the preferred alignment, arrays recurse into their element type,
// structs take the max over their members, functions use pointer alignment.
static unsigned int getOpenCLAlignment(const DataLayout &DL, Type *Ty) {
  if (Ty->isSingleValueType())
    return DL.getPrefTypeAlignment(Ty);

  auto *ATy = dyn_cast<ArrayType>(Ty);
  if (ATy)
    return getOpenCLAlignment(DL, ATy->getElementType());

  auto *STy = dyn_cast<StructType>(Ty);
  if (STy) {
    unsigned int alignStruct = 1;
    // Go through each element of the struct and find the
    // largest alignment.
    for (unsigned i = 0, e = STy->getNumElements(); i != e; i++) {
      Type *ETy = STy->getElementType(i);
      unsigned int align = getOpenCLAlignment(DL, ETy);
      if (align > alignStruct)
        alignStruct = align;
    }
    return alignStruct;
  }

  auto *FTy = dyn_cast<FunctionType>(Ty);
  if (FTy)
    return DL.getPointerPrefAlignment().value();
  return DL.getPrefTypeAlignment(Ty);
}

// Print the PTX name of formal parameter \p paramIndex of the function that
// owns \p I: "<mangled function symbol>_param_<index>".
void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
                                     int paramIndex, raw_ostream &O) {
  getSymbol(I->getParent())->print(O, MAI);
  O << "_param_" << paramIndex;
}

// Emit the parenthesized PTX parameter list for function \p F, handling
// images/samplers, aggregates, kernel pointers, scalars and byval pointers.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const TargetLowering *TLI = STI.getTargetLowering();
  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  // sm_20+ follows the PTX parameter-passing ABI (.param); older targets
  // pass in registers (.reg).
  bool isABI = (STI.getSmVersion() >= 20);
  bool hasImageHandles = STI.hasImageHandles();
  MVT thePointerTy = TLI->getPointerTy(DL);

  if (F->arg_empty()) {
    O << "()\n";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunction(*F)) {
      if (isSampler(*I) || isImage(*I)) {
        if (isImage(*I)) {
          std::string sname = std::string(I->getName());
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            // With image handles the argument is a .u64 handle pointing at
            // the ref; without, it is the opaque ref itself.
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            CurrentFnSym->print(O, MAI);
            O << "_param_" << paramIndex;
          }
          else { // Default image is read_only
            if (hasImageHandles)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            CurrentFnSym->print(O, MAI);
            O << "_param_" << paramIndex;
          }
        } else {
          if (hasImageHandles)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          CurrentFnSym->print(O, MAI);
          O << "_param_" << paramIndex;
        }
        continue;
      }
    }

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a> = PAL.getparamalignment
        // size = typeallocsize of element type
        const Align align = DL.getValueOrABITypeAlignment(
            PAL.getParamAlignment(paramIndex), Ty);

        unsigned sz = DL.getTypeAllocSize(Ty);
        O << "\t.param .align " << align.value() << " .b8 ";
        printParamName(I, paramIndex, O);
        O << "[" << sz << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      if (isKernelFunc) {
        if (PTy) {
          // Special handling for pointer arguments to kernel
          O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";

          // Non-CUDA (OpenCL-style) interfaces annotate the pointee state
          // space and alignment.
          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
              NVPTX::CUDA) {
            Type *ETy = PTy->getElementType();
            int addrSpace = PTy->getAddressSpace();
            switch (addrSpace) {
            default:
              O << ".ptr ";
              break;
            case ADDRESS_SPACE_CONST:
              O << ".ptr .const ";
              break;
            case ADDRESS_SPACE_SHARED:
              O << ".ptr .shared ";
              break;
            case ADDRESS_SPACE_GLOBAL:
              O << ".ptr .global ";
              break;
            }
            O << ".align " << (int)getOpenCLAlignment(DL, ETy) << " ";
          }
          printParamName(I, paramIndex, O);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        printParamName(I, paramIndex, O);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        if (sz < 32)
          sz = 32;
      } else if (isa<PointerType>(Ty))
        sz = thePointerTy.getSizeInBits();
      else if (Ty->isHalfTy())
        // PTX ABI requires all scalar parameters to be at least 32
        // bits in size. fp16 normally uses .b16 as its storage type
        // in PTX, so its size must be adjusted here, too.
        sz = 32;
      else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      printParamName(I, paramIndex, O);
      continue;
    }

    // param has byVal attribute. So should be a pointer
    auto *PTy = dyn_cast<PointerType>(Ty);
    assert(PTy && "Param with byval attribute should be a pointer type");
    Type *ETy = PTy->getElementType();

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a> = PAL.getparamalignment
      // size = typeallocsize of element type
      Align align =
          DL.getValueOrABITypeAlignment(PAL.getParamAlignment(paramIndex), ETy);
      // Work around a bug in ptxas. When PTX code takes address of
      // byval parameter with alignment < 4, ptxas generates code to
      // spill argument into memory. Alas on sm_50+ ptxas generates
      // SASS code that fails with misaligned access. To work around
      // the problem, make sure that we align byval parameters by at
      // least 4. Matching change must be made in LowerCall() where we
      // prepare parameters for the call.
      //
      // TODO: this will need to be undone when we get to support multi-TU
      // device-side compilation as it breaks ABI compatibility with nvcc.
      // Hopefully ptxas bug is fixed by then.
      if (!isKernelFunc && align < Align(4))
        align = Align(4);
      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << align.value() << " .b8 ";
      printParamName(I, paramIndex, O);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger() && (sz < 32))
            sz = 32;
          O << "\t.reg .b" << sz << " ";
          printParamName(I, paramIndex, O);
          if (j < je - 1)
            O << ",\n";
          // Each expanded part consumes its own param index.
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      // Compensate for the loop header's paramIndex++ after the expansion.
      --paramIndex;
      continue;
    }
  }

  O << "\n)\n";
}

// Convenience overload: emit the parameter list of the Function underlying
// a MachineFunction.
void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
                                            raw_ostream &O) {
  const Function &F = MF.getFunction();
  emitFunctionParamList(&F, O);
}

// Build the per-register-class virtual register numbering (VRegMapping) for
// this function and emit the depot (.local stack) plus .reg declarations.
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
    const MachineFunction &MF) {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Map the global virtual register number to a register class specific
  // virtual register number starting from 1 with that class.
  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
  //unsigned numRegClasses = TRI->getNumRegClasses();

  // Emit the Fake Stack Object
  const MachineFrameInfo &MFI = MF.getFrameInfo();
  int NumBytes = (int) MFI.getStackSize();
  if (NumBytes) {
    // __local_depot<N> backs the function's stack; %SP/%SPL are the generic
    // and local-space stack pointers into it.
    O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
      << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
      O << "\t.reg .b64 \t%SP;\n";
      O << "\t.reg .b64 \t%SPL;\n";
    } else {
      O << "\t.reg .b32 \t%SP;\n";
      O << "\t.reg .b32 \t%SPL;\n";
    }
  }

  // Go through all virtual registers to establish the mapping between the
  // global virtual
  // register number and the per class virtual register number.
  // We use the per class virtual register number in the ptx output.
  unsigned int numVRs = MRI->getNumVirtRegs();
  for (unsigned i = 0; i < numVRs; i++) {
    unsigned int vr = Register::index2VirtReg(i);
    const TargetRegisterClass *RC = MRI->getRegClass(vr);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    int n = regmap.size();
    regmap.insert(std::make_pair(vr, n + 1));
  }

  // Emit register declarations
  // @TODO: Extract out the real register usage
  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";

  // Emit declaration of the virtual registers or 'physical' registers for
  // each register class
  for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
    const TargetRegisterClass *RC = TRI->getRegClass(i);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    std::string rcname = getNVPTXRegClassName(RC);
    std::string rcStr = getNVPTXRegClassStr(RC);
    int n = regmap.size();

    // Only declare those registers that may be used.
    if (n) {
      O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
        << ">;\n";
    }
  }

  OutStreamer->emitRawText(O.str());
}

// Print an fp32/fp64 constant in PTX hex-literal form ("0f..."/"0d...").
void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
  APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
  bool ignored;
  unsigned int numHex;
  const char *lead;

  if (Fp->getType()->getTypeID() == Type::FloatTyID) {
    numHex = 8;
    lead = "0f";
    APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
  } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
    numHex = 16;
    lead = "0d";
    APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
  } else
    llvm_unreachable("unsupported fp type");

  APInt API = APF.bitcastToAPInt();
  O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
}

// Print a scalar constant initializer: integer, fp, null pointer, global
// address (optionally wrapped in generic(...)), or a foldable ConstantExpr.
void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    O << CI->getValue();
    return;
  }
  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
    printFPConstant(CFP, O);
    return;
  }
  if (isa<ConstantPointerNull>(CPV)) {
    O << "0";
    return;
  }
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
    bool IsNonGenericPointer = false;
    if (GVar->getType()->getAddressSpace() != 0) {
      IsNonGenericPointer = true;
    }
    if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
      O << "generic(";
      getSymbol(GVar)->print(O, MAI);
      O << ")";
    } else {
      getSymbol(GVar)->print(O, MAI);
    }
    return;
  }
  if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
    const Value *v = Cexpr->stripPointerCasts();
    PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType());
    bool IsNonGenericPointer = false;
    if (PTy && PTy->getAddressSpace() != 0) {
      IsNonGenericPointer = true;
    }
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
      if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
        O << "generic(";
        getSymbol(GVar)->print(O, MAI);
        O << ")";
      } else {
        getSymbol(GVar)->print(O, MAI);
      }
      return;
    } else {
      // Not a plain global under the casts: fall back to the generic
      // AsmPrinter constant lowering and print the resulting MCExpr.
      lowerConstant(CPV)->print(O, MAI);
      return;
    }
  }
  llvm_unreachable("Not scalar type found in printScalarConstant()");
}

// Serialize one constant into \p AggBuffer in little-endian byte order.
// \p Bytes, when non-zero, is the total field width to fill (used for struct
// padding); otherwise the type's alloc size is used.
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(CPV->getType());
  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
    // only the space allocated by CPV.
    AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    for (unsigned I = 0; I < NumBytes; ++I) {
      Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
    }
    AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      if (const auto *CI =
              dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        // ptrtoint of a symbol: record the symbol and reserve zeroed space;
        // the actual address is patched in when the buffer is printed.
        Value *V = Cexpr->getOperand(0)->stripPointerCasts();
        AggBuffer->addSymbol(V, Cexpr->getOperand(0));
        AggBuffer->addZeros(AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    // Floats are buffered via their IEEE bit pattern.
    AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
      AggBuffer->addSymbol(GVar, GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(v, Cexpr);
    }
    AggBuffer->addZeros(AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
      bufferAggregateConstant(CPV, AggBuffer);
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(CPV))
      AggBuffer->addZeros(Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}

// Serialize an aggregate (or wide-integer) constant into \p aggBuffer by
// dispatching each element through bufferLEByte.
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  const DataLayout &DL = getDataLayout();
  int Bytes;

  // Integers of arbitrary width
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
    APInt Val = CI->getValue();
    // Emit one byte at a time, low byte first (little-endian).
    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
      uint8_t Byte = Val.getLoBits(8).getZExtValue();
      aggBuffer->addBytes(&Byte, 1, 1);
      Val.lshrInPlace(8);
    }
    return;
  }

  // Old constants
  if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
    if (CPV->getNumOperands())
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
    return;
  }

  if (const ConstantDataSequential *CDS =
          dyn_cast<ConstantDataSequential>(CPV)) {
    if (CDS->getNumElements())
      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
                     aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(CPV->getType());
      // Field width = distance to the next field's offset, so that struct
      // padding is zero-filled; the last field extends to the struct's end.
      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
        if (i == (e - 1))
          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
                  DL.getTypeAllocSize(ST) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        else
          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
                  DL.getStructLayout(ST)->getElementOffset(i);
        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in printAggregateConstant()");
}

/// lowerConstantForGV - Return an MCExpr for the given Constant.
This is mostly 1882 /// a copy from AsmPrinter::lowerConstant, except customized to only handle 1883 /// expressions that are representable in PTX and create 1884 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions. 1885 const MCExpr * 1886 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) { 1887 MCContext &Ctx = OutContext; 1888 1889 if (CV->isNullValue() || isa<UndefValue>(CV)) 1890 return MCConstantExpr::create(0, Ctx); 1891 1892 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) 1893 return MCConstantExpr::create(CI->getZExtValue(), Ctx); 1894 1895 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) { 1896 const MCSymbolRefExpr *Expr = 1897 MCSymbolRefExpr::create(getSymbol(GV), Ctx); 1898 if (ProcessingGeneric) { 1899 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx); 1900 } else { 1901 return Expr; 1902 } 1903 } 1904 1905 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV); 1906 if (!CE) { 1907 llvm_unreachable("Unknown constant value to lower!"); 1908 } 1909 1910 switch (CE->getOpcode()) { 1911 default: { 1912 // If the code isn't optimized, there may be outstanding folding 1913 // opportunities. Attempt to fold the expression using DataLayout as a 1914 // last resort before giving up. 1915 Constant *C = ConstantFoldConstant(CE, getDataLayout()); 1916 if (C != CE) 1917 return lowerConstantForGV(C, ProcessingGeneric); 1918 1919 // Otherwise report the problem to the user. 1920 std::string S; 1921 raw_string_ostream OS(S); 1922 OS << "Unsupported expression in static initializer: "; 1923 CE->printAsOperand(OS, /*PrintType=*/false, 1924 !MF ? 
nullptr : MF->getFunction().getParent()); 1925 report_fatal_error(Twine(OS.str())); 1926 } 1927 1928 case Instruction::AddrSpaceCast: { 1929 // Strip the addrspacecast and pass along the operand 1930 PointerType *DstTy = cast<PointerType>(CE->getType()); 1931 if (DstTy->getAddressSpace() == 0) { 1932 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true); 1933 } 1934 std::string S; 1935 raw_string_ostream OS(S); 1936 OS << "Unsupported expression in static initializer: "; 1937 CE->printAsOperand(OS, /*PrintType=*/ false, 1938 !MF ? nullptr : MF->getFunction().getParent()); 1939 report_fatal_error(Twine(OS.str())); 1940 } 1941 1942 case Instruction::GetElementPtr: { 1943 const DataLayout &DL = getDataLayout(); 1944 1945 // Generate a symbolic expression for the byte address 1946 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0); 1947 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI); 1948 1949 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0), 1950 ProcessingGeneric); 1951 if (!OffsetAI) 1952 return Base; 1953 1954 int64_t Offset = OffsetAI.getSExtValue(); 1955 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx), 1956 Ctx); 1957 } 1958 1959 case Instruction::Trunc: 1960 // We emit the value and depend on the assembler to truncate the generated 1961 // expression properly. This is important for differences between 1962 // blockaddress labels. Since the two labels are in the same function, it 1963 // is reasonable to treat their delta as a 32-bit value. 1964 LLVM_FALLTHROUGH; 1965 case Instruction::BitCast: 1966 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 1967 1968 case Instruction::IntToPtr: { 1969 const DataLayout &DL = getDataLayout(); 1970 1971 // Handle casts to pointers by changing them into casts to the appropriate 1972 // integer type. This promotes constant folding and simplifies this code. 
1973 Constant *Op = CE->getOperand(0); 1974 Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()), 1975 false/*ZExt*/); 1976 return lowerConstantForGV(Op, ProcessingGeneric); 1977 } 1978 1979 case Instruction::PtrToInt: { 1980 const DataLayout &DL = getDataLayout(); 1981 1982 // Support only foldable casts to/from pointers that can be eliminated by 1983 // changing the pointer to the appropriately sized integer type. 1984 Constant *Op = CE->getOperand(0); 1985 Type *Ty = CE->getType(); 1986 1987 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric); 1988 1989 // We can emit the pointer value into this slot if the slot is an 1990 // integer slot equal to the size of the pointer. 1991 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType())) 1992 return OpExpr; 1993 1994 // Otherwise the pointer is smaller than the resultant integer, mask off 1995 // the high bits so we are sure to get a proper truncation if the input is 1996 // a constant expr. 1997 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType()); 1998 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx); 1999 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx); 2000 } 2001 2002 // The MC library also has a right-shift operator, but it isn't consistently 2003 // signed or unsigned between different targets. 
case Instruction::Add: {
    // Lower both operands, then dispatch on the opcode.  Only Add is
    // supported; the inner switch exists so any other binary opcode that
    // reaches here trips the unreachable with a descriptive message.
    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }
}

// Copy of MCExpr::print customized for NVPTX.  Recursively renders Expr as
// PTX-compatible assembly text onto OS, parenthesizing only non-trivial
// sub-expressions.
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
  switch (Expr.getKind()) {
  case MCExpr::Target:
    // Target-specific expressions print themselves.
    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
  case MCExpr::Constant:
    OS << cast<MCConstantExpr>(Expr).getValue();
    return;

  case MCExpr::SymbolRef: {
    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &Sym = SRE.getSymbol();
    Sym.print(OS, MAI);
    return;
  }

  case MCExpr::Unary: {
    // Print the operator character, then recurse into the sub-expression.
    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
    switch (UE.getOpcode()) {
    case MCUnaryExpr::LNot:  OS << '!'; break;
    case MCUnaryExpr::Minus: OS << '-'; break;
    case MCUnaryExpr::Not:   OS << '~'; break;
    case MCUnaryExpr::Plus:  OS << '+'; break;
    }
    printMCExpr(*UE.getSubExpr(), OS);
    return;
  }

  case MCExpr::Binary: {
    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);

    // Only print parens around the LHS if it is non-trivial (i.e. not a bare
    // constant or symbol reference).
    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
      printMCExpr(*BE.getLHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getLHS(), OS);
      OS << ')';
    }

    switch (BE.getOpcode()) {
    case MCBinaryExpr::Add:
      // Print "X-42" instead of "X+-42".
if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
        if (RHSC->getValue() < 0) {
          // Negative constant RHS: the '-' of the value itself serves as the
          // operator, so we are done with the whole binary expression.
          OS << RHSC->getValue();
          return;
        }
      }

      OS << '+';
      break;
    default: llvm_unreachable("Unhandled binary operator");
    }

    // Only print parens around the RHS if it is non-trivial.
    // NOTE(review): unlike the LHS check above, this does not treat
    // NVPTXGenericMCSymbolRefExpr as trivial, so such an RHS gets
    // (redundant but harmless) parens — confirm this asymmetry is intended.
    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
      printMCExpr(*BE.getRHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getRHS(), OS);
      OS << ')';
    }
    return;
  }
  }

  llvm_unreachable("Invalid expression kind!");
}

/// PrintAsmOperand - Print out an operand for an inline asm expression.
/// Returns true on failure (unknown modifier), false on success.  Only
/// single-character modifiers are accepted; 'r' is handled as a plain
/// operand print, anything else is delegated to the generic AsmPrinter.
bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                                      const char *ExtraCode, raw_ostream &O) {
  if (ExtraCode && ExtraCode[0]) {
    if (ExtraCode[1] != 0)
      return true; // Unknown modifier.

    switch (ExtraCode[0]) {
    default:
      // See if this is a generic print operand
      return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
    case 'r':
      break;
    }
  }

  printOperand(MI, OpNo, O);

  return false;
}

// Print an inline-asm memory operand as "[operand]".  No modifiers are
// supported; returns true (failure) if one is supplied.
bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
                                            unsigned OpNo,
                                            const char *ExtraCode,
                                            raw_ostream &O) {
  if (ExtraCode && ExtraCode[0])
    return true; // Unknown modifier

  O << '[';
  printMemOperand(MI, OpNo, O);
  O << ']';

  return false;
}

// Print a single machine operand (register, immediate, FP immediate, global
// address, or basic-block label) onto O.
void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
                                   raw_ostream &O) {
  const MachineOperand &MO = MI->getOperand(opNum);
  switch (MO.getType()) {
  case MachineOperand::MO_Register:
    if (Register::isPhysicalRegister(MO.getReg())) {
      // The virtual depot register prints as the per-function local depot
      // name (DEPOTNAME + function number) rather than a register name.
      if (MO.getReg() == NVPTX::VRDepot)
        O << DEPOTNAME << getFunctionNumber();
      else
        O << NVPTXInstPrinter::getRegisterName(MO.getReg());
    } else {
emitVirtualRegister(MO.getReg(), O);
    }
    break;

  case MachineOperand::MO_Immediate:
    O << MO.getImm();
    break;

  case MachineOperand::MO_FPImmediate:
    printFPConstant(MO.getFPImm(), O);
    break;

  case MachineOperand::MO_GlobalAddress:
    PrintSymbolOperand(MO, O);
    break;

  case MachineOperand::MO_MachineBasicBlock:
    MO.getMBB()->getSymbol()->print(O, MAI);
    break;

  default:
    llvm_unreachable("Operand type not supported.");
  }
}

// Print a two-part memory operand (base at opNum, offset at opNum+1).
// With Modifier=="add" the parts print as "base, offset"; otherwise as
// "base+offset", with a zero immediate offset suppressed entirely.
void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
                                      raw_ostream &O, const char *Modifier) {
  printOperand(MI, opNum, O);

  if (Modifier && strcmp(Modifier, "add") == 0) {
    O << ", ";
    printOperand(MI, opNum + 1, O);
  } else {
    if (MI->getOperand(opNum + 1).isImm() &&
        MI->getOperand(opNum + 1).getImm() == 0)
      return; // don't print ',0' or '+0'
    O << "+";
    printOperand(MI, opNum + 1, O);
  }
}

// Force static initialization.  Registers this AsmPrinter for both the
// 32-bit and 64-bit NVPTX targets with the TargetRegistry.
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
  RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
  RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
}