1 //===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to the SPIR-V assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "MCTargetDesc/SPIRVInstPrinter.h" 15 #include "SPIRV.h" 16 #include "SPIRVInstrInfo.h" 17 #include "SPIRVMCInstLower.h" 18 #include "SPIRVModuleAnalysis.h" 19 #include "SPIRVSubtarget.h" 20 #include "SPIRVTargetMachine.h" 21 #include "SPIRVUtils.h" 22 #include "TargetInfo/SPIRVTargetInfo.h" 23 #include "llvm/ADT/DenseMap.h" 24 #include "llvm/Analysis/ValueTracking.h" 25 #include "llvm/CodeGen/AsmPrinter.h" 26 #include "llvm/CodeGen/MachineConstantPool.h" 27 #include "llvm/CodeGen/MachineFunctionPass.h" 28 #include "llvm/CodeGen/MachineInstr.h" 29 #include "llvm/CodeGen/MachineModuleInfo.h" 30 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" 31 #include "llvm/MC/MCAsmInfo.h" 32 #include "llvm/MC/MCInst.h" 33 #include "llvm/MC/MCStreamer.h" 34 #include "llvm/MC/MCSymbol.h" 35 #include "llvm/MC/TargetRegistry.h" 36 #include "llvm/Support/raw_ostream.h" 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "asm-printer" 41 42 namespace { 43 class SPIRVAsmPrinter : public AsmPrinter { 44 public: 45 explicit SPIRVAsmPrinter(TargetMachine &TM, 46 std::unique_ptr<MCStreamer> Streamer) 47 : AsmPrinter(TM, std::move(Streamer)), ST(nullptr), TII(nullptr) {} 48 bool ModuleSectionsEmitted; 49 const SPIRVSubtarget *ST; 50 const SPIRVInstrInfo *TII; 51 52 StringRef getPassName() const override { return "SPIRV Assembly Printer"; } 53 void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O); 54 bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 55 const char *ExtraCode, raw_ostream &O) override; 56 57 void outputMCInst(MCInst &Inst); 58 void outputInstruction(const MachineInstr *MI); 59 void outputModuleSection(SPIRV::ModuleSectionType MSType); 60 void outputEntryPoints(); 61 void outputDebugSourceAndStrings(const Module &M); 62 void outputOpExtInstImports(const Module &M); 63 void outputOpMemoryModel(); 64 void outputOpFunctionEnd(); 65 void outputExtFuncDecls(); 66 void outputExecutionModeFromMDNode(Register Reg, MDNode *Node, 67 SPIRV::ExecutionMode EM); 68 void outputExecutionMode(const Module &M); 69 void outputAnnotations(const Module &M); 70 void outputModuleSections(); 71 72 void emitInstruction(const MachineInstr *MI) override; 73 void emitFunctionEntryLabel() override {} 74 void emitFunctionHeader() override; 75 void emitFunctionBodyStart() override {} 76 void emitFunctionBodyEnd() override; 77 void emitBasicBlockStart(const MachineBasicBlock &MBB) override; 78 void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {} 79 void emitGlobalVariable(const GlobalVariable *GV) override {} 80 void emitOpLabel(const MachineBasicBlock &MBB); 81 void emitEndOfAsmFile(Module &M) override; 82 bool doInitialization(Module &M) override; 83 84 void getAnalysisUsage(AnalysisUsage &AU) const override; 85 SPIRV::ModuleAnalysisInfo *MAI; 86 }; 87 } // namespace 88 89 void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const { 90 AU.addRequired<SPIRVModuleAnalysis>(); 91 AU.addPreserved<SPIRVModuleAnalysis>(); 92 AsmPrinter::getAnalysisUsage(AU); 93 } 94 95 // If the module has no functions, we need output global info anyway. 96 void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) { 97 if (ModuleSectionsEmitted == false) { 98 outputModuleSections(); 99 ModuleSectionsEmitted = true; 100 } 101 } 102 103 void SPIRVAsmPrinter::emitFunctionHeader() { 104 if (ModuleSectionsEmitted == false) { 105 outputModuleSections(); 106 ModuleSectionsEmitted = true; 107 } 108 // Get the subtarget from the current MachineFunction. 109 ST = &MF->getSubtarget<SPIRVSubtarget>(); 110 TII = ST->getInstrInfo(); 111 const Function &F = MF->getFunction(); 112 113 if (isVerbose()) { 114 OutStreamer->getCommentOS() 115 << "-- Begin function " 116 << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n'; 117 } 118 119 auto Section = getObjFileLowering().SectionForGlobal(&F, TM); 120 MF->setSection(Section); 121 } 122 123 void SPIRVAsmPrinter::outputOpFunctionEnd() { 124 MCInst FunctionEndInst; 125 FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd); 126 outputMCInst(FunctionEndInst); 127 } 128 129 // Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap. 130 void SPIRVAsmPrinter::emitFunctionBodyEnd() { 131 outputOpFunctionEnd(); 132 MAI->BBNumToRegMap.clear(); 133 } 134 135 void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { 136 if (MAI->MBBsToSkip.contains(&MBB)) 137 return; 138 MCInst LabelInst; 139 LabelInst.setOpcode(SPIRV::OpLabel); 140 LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB))); 141 outputMCInst(LabelInst); 142 } 143 144 void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { 145 // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so 146 // OpLabel should be output after them. 147 if (MBB.getNumber() == MF->front().getNumber()) { 148 for (const MachineInstr &MI : MBB) 149 if (MI.getOpcode() == SPIRV::OpFunction) 150 return; 151 // TODO: this case should be checked by the verifier. 152 report_fatal_error("OpFunction is expected in the front MBB of MF"); 153 } 154 emitOpLabel(MBB); 155 } 156 157 void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum, 158 raw_ostream &O) { 159 const MachineOperand &MO = MI->getOperand(OpNum); 160 161 switch (MO.getType()) { 162 case MachineOperand::MO_Register: 163 O << SPIRVInstPrinter::getRegisterName(MO.getReg()); 164 break; 165 166 case MachineOperand::MO_Immediate: 167 O << MO.getImm(); 168 break; 169 170 case MachineOperand::MO_FPImmediate: 171 O << MO.getFPImm(); 172 break; 173 174 case MachineOperand::MO_MachineBasicBlock: 175 O << *MO.getMBB()->getSymbol(); 176 break; 177 178 case MachineOperand::MO_GlobalAddress: 179 O << *getSymbol(MO.getGlobal()); 180 break; 181 182 case MachineOperand::MO_BlockAddress: { 183 MCSymbol *BA = GetBlockAddressSymbol(MO.getBlockAddress()); 184 O << BA->getName(); 185 break; 186 } 187 188 case MachineOperand::MO_ExternalSymbol: 189 O << *GetExternalSymbolSymbol(MO.getSymbolName()); 190 break; 191 192 case MachineOperand::MO_JumpTableIndex: 193 case MachineOperand::MO_ConstantPoolIndex: 194 default: 195 llvm_unreachable("<unknown operand type>"); 196 } 197 } 198 199 bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 200 const char *ExtraCode, raw_ostream &O) { 201 if (ExtraCode && ExtraCode[0]) 202 return true; // Invalid instruction - SPIR-V does not have special modifiers 203 204 printOperand(MI, OpNo, O); 205 return false; 206 } 207 208 static bool isFuncOrHeaderInstr(const MachineInstr *MI, 209 const SPIRVInstrInfo *TII) { 210 return TII->isHeaderInstr(*MI) || MI->getOpcode() == SPIRV::OpFunction || 211 MI->getOpcode() == SPIRV::OpFunctionParameter; 212 } 213 214 void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) { 215 OutStreamer->emitInstruction(Inst, *OutContext.getSubtargetInfo()); 216 } 217 218 void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) { 219 SPIRVMCInstLower MCInstLowering; 220 MCInst TmpInst; 221 MCInstLowering.lower(MI, TmpInst, MAI); 222 outputMCInst(TmpInst); 223 } 224 225 void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) { 226 SPIRV_MC::verifyInstructionPredicates(MI->getOpcode(), 227 getSubtargetInfo().getFeatureBits()); 228 229 if (!MAI->getSkipEmission(MI)) 230 outputInstruction(MI); 231 232 // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB. 233 const MachineInstr *NextMI = MI->getNextNode(); 234 if (!MAI->hasMBBRegister(*MI->getParent()) && isFuncOrHeaderInstr(MI, TII) && 235 (!NextMI || !isFuncOrHeaderInstr(NextMI, TII))) { 236 assert(MI->getParent()->getNumber() == MF->front().getNumber() && 237 "OpFunction is not in the front MBB of MF"); 238 emitOpLabel(*MI->getParent()); 239 } 240 } 241 242 void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) { 243 for (MachineInstr *MI : MAI->getMSInstrs(MSType)) 244 outputInstruction(MI); 245 } 246 247 void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) { 248 // Output OpSourceExtensions. 249 for (auto &Str : MAI->SrcExt) { 250 MCInst Inst; 251 Inst.setOpcode(SPIRV::OpSourceExtension); 252 addStringImm(Str.first(), Inst); 253 outputMCInst(Inst); 254 } 255 // Output OpSource. 256 MCInst Inst; 257 Inst.setOpcode(SPIRV::OpSource); 258 Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->SrcLang))); 259 Inst.addOperand( 260 MCOperand::createImm(static_cast<unsigned>(MAI->SrcLangVersion))); 261 outputMCInst(Inst); 262 } 263 264 void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) { 265 for (auto &CU : MAI->ExtInstSetMap) { 266 unsigned Set = CU.first; 267 Register Reg = CU.second; 268 MCInst Inst; 269 Inst.setOpcode(SPIRV::OpExtInstImport); 270 Inst.addOperand(MCOperand::createReg(Reg)); 271 addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)), 272 Inst); 273 outputMCInst(Inst); 274 } 275 } 276 277 void SPIRVAsmPrinter::outputOpMemoryModel() { 278 MCInst Inst; 279 Inst.setOpcode(SPIRV::OpMemoryModel); 280 Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Addr))); 281 Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Mem))); 282 outputMCInst(Inst); 283 } 284 285 // Before the OpEntryPoints' output, we need to add the entry point's 286 // interfaces. The interface is a list of IDs of global OpVariable instructions. 287 // These declare the set of global variables from a module that form 288 // the interface of this entry point. 289 void SPIRVAsmPrinter::outputEntryPoints() { 290 // Find all OpVariable IDs with required StorageClass. 291 DenseSet<Register> InterfaceIDs; 292 for (MachineInstr *MI : MAI->GlobalVarList) { 293 assert(MI->getOpcode() == SPIRV::OpVariable); 294 auto SC = static_cast<SPIRV::StorageClass>(MI->getOperand(2).getImm()); 295 // Before version 1.4, the interface's storage classes are limited to 296 // the Input and Output storage classes. Starting with version 1.4, 297 // the interface's storage classes are all storage classes used in 298 // declaring all global variables referenced by the entry point call tree. 299 if (ST->getSPIRVVersion() >= 14 || SC == SPIRV::StorageClass::Input || 300 SC == SPIRV::StorageClass::Output) { 301 MachineFunction *MF = MI->getMF(); 302 Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); 303 InterfaceIDs.insert(Reg); 304 } 305 } 306 307 // Output OpEntryPoints adding interface args to all of them. 308 for (MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { 309 SPIRVMCInstLower MCInstLowering; 310 MCInst TmpInst; 311 MCInstLowering.lower(MI, TmpInst, MAI); 312 for (Register Reg : InterfaceIDs) { 313 assert(Reg.isValid()); 314 TmpInst.addOperand(MCOperand::createReg(Reg)); 315 } 316 outputMCInst(TmpInst); 317 } 318 } 319 320 void SPIRVAsmPrinter::outputExtFuncDecls() { 321 // Insert OpFunctionEnd after each declaration. 322 SmallVectorImpl<MachineInstr *>::iterator 323 I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), 324 E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); 325 for (; I != E; ++I) { 326 outputInstruction(*I); 327 if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction) 328 outputOpFunctionEnd(); 329 } 330 } 331 332 // Encode LLVM type by SPIR-V execution mode VecTypeHint. 333 static unsigned encodeVecTypeHint(Type *Ty) { 334 if (Ty->isHalfTy()) 335 return 4; 336 if (Ty->isFloatTy()) 337 return 5; 338 if (Ty->isDoubleTy()) 339 return 6; 340 if (IntegerType *IntTy = dyn_cast<IntegerType>(Ty)) { 341 switch (IntTy->getIntegerBitWidth()) { 342 case 8: 343 return 0; 344 case 16: 345 return 1; 346 case 32: 347 return 2; 348 case 64: 349 return 3; 350 default: 351 llvm_unreachable("invalid integer type"); 352 } 353 } 354 if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty)) { 355 Type *EleTy = VecTy->getElementType(); 356 unsigned Size = VecTy->getNumElements(); 357 return Size << 16 | encodeVecTypeHint(EleTy); 358 } 359 llvm_unreachable("invalid type"); 360 } 361 362 static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst, 363 SPIRV::ModuleAnalysisInfo *MAI) { 364 for (const MDOperand &MDOp : MDN->operands()) { 365 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) { 366 Constant *C = CMeta->getValue(); 367 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) { 368 Inst.addOperand(MCOperand::createImm(Const->getZExtValue())); 369 } else if (auto *CE = dyn_cast<Function>(C)) { 370 Register FuncReg = MAI->getFuncReg(CE->getName().str()); 371 assert(FuncReg.isValid()); 372 Inst.addOperand(MCOperand::createReg(FuncReg)); 373 } 374 } 375 } 376 } 377 378 void SPIRVAsmPrinter::outputExecutionModeFromMDNode(Register Reg, MDNode *Node, 379 SPIRV::ExecutionMode EM) { 380 MCInst Inst; 381 Inst.setOpcode(SPIRV::OpExecutionMode); 382 Inst.addOperand(MCOperand::createReg(Reg)); 383 Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM))); 384 addOpsFromMDNode(Node, Inst, MAI); 385 outputMCInst(Inst); 386 } 387 388 void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { 389 NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); 390 if (Node) { 391 for (unsigned i = 0; i < Node->getNumOperands(); i++) { 392 MCInst Inst; 393 Inst.setOpcode(SPIRV::OpExecutionMode); 394 addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); 395 outputMCInst(Inst); 396 } 397 } 398 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { 399 const Function &F = *FI; 400 if (F.isDeclaration()) 401 continue; 402 Register FReg = MAI->getFuncReg(F.getGlobalIdentifier()); 403 assert(FReg.isValid()); 404 if (MDNode *Node = F.getMetadata("reqd_work_group_size")) 405 outputExecutionModeFromMDNode(FReg, Node, 406 SPIRV::ExecutionMode::LocalSize); 407 if (MDNode *Node = F.getMetadata("work_group_size_hint")) 408 outputExecutionModeFromMDNode(FReg, Node, 409 SPIRV::ExecutionMode::LocalSizeHint); 410 if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size")) 411 outputExecutionModeFromMDNode(FReg, Node, 412 SPIRV::ExecutionMode::SubgroupSize); 413 if (MDNode *Node = F.getMetadata("vec_type_hint")) { 414 MCInst Inst; 415 Inst.setOpcode(SPIRV::OpExecutionMode); 416 Inst.addOperand(MCOperand::createReg(FReg)); 417 unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint); 418 Inst.addOperand(MCOperand::createImm(EM)); 419 unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0)); 420 Inst.addOperand(MCOperand::createImm(TypeCode)); 421 outputMCInst(Inst); 422 } 423 } 424 } 425 426 void SPIRVAsmPrinter::outputAnnotations(const Module &M) { 427 outputModuleSection(SPIRV::MB_Annotations); 428 // Process llvm.global.annotations special global variable. 429 for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) { 430 if ((*F).getName() != "llvm.global.annotations") 431 continue; 432 const GlobalVariable *V = &(*F); 433 const ConstantArray *CA = cast<ConstantArray>(V->getOperand(0)); 434 for (Value *Op : CA->operands()) { 435 ConstantStruct *CS = cast<ConstantStruct>(Op); 436 // The first field of the struct contains a pointer to 437 // the annotated variable. 438 Value *AnnotatedVar = CS->getOperand(0)->stripPointerCasts(); 439 if (!isa<Function>(AnnotatedVar)) 440 llvm_unreachable("Unsupported value in llvm.global.annotations"); 441 Function *Func = cast<Function>(AnnotatedVar); 442 Register Reg = MAI->getFuncReg(Func->getGlobalIdentifier()); 443 444 // The second field contains a pointer to a global annotation string. 445 GlobalVariable *GV = 446 cast<GlobalVariable>(CS->getOperand(1)->stripPointerCasts()); 447 448 StringRef AnnotationString; 449 getConstantStringInfo(GV, AnnotationString); 450 MCInst Inst; 451 Inst.setOpcode(SPIRV::OpDecorate); 452 Inst.addOperand(MCOperand::createReg(Reg)); 453 unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic); 454 Inst.addOperand(MCOperand::createImm(Dec)); 455 addStringImm(AnnotationString, Inst); 456 outputMCInst(Inst); 457 } 458 } 459 } 460 461 void SPIRVAsmPrinter::outputModuleSections() { 462 const Module *M = MMI->getModule(); 463 // Get the global subtarget to output module-level info. 464 ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl(); 465 TII = ST->getInstrInfo(); 466 MAI = &SPIRVModuleAnalysis::MAI; 467 assert(ST && TII && MAI && M && "Module analysis is required"); 468 // Output instructions according to the Logical Layout of a Module: 469 // TODO: 1,2. All OpCapability instructions, then optional OpExtension 470 // instructions. 471 // 3. Optional OpExtInstImport instructions. 472 outputOpExtInstImports(*M); 473 // 4. The single required OpMemoryModel instruction. 474 outputOpMemoryModel(); 475 // 5. All entry point declarations, using OpEntryPoint. 476 outputEntryPoints(); 477 // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. 478 outputExecutionMode(*M); 479 // 7a. Debug: all OpString, OpSourceExtension, OpSource, and 480 // OpSourceContinued, without forward references. 481 outputDebugSourceAndStrings(*M); 482 // 7b. Debug: all OpName and all OpMemberName. 483 outputModuleSection(SPIRV::MB_DebugNames); 484 // 7c. Debug: all OpModuleProcessed instructions. 485 outputModuleSection(SPIRV::MB_DebugModuleProcessed); 486 // 8. All annotation instructions (all decorations). 487 outputAnnotations(*M); 488 // 9. All type declarations (OpTypeXXX instructions), all constant 489 // instructions, and all global variable declarations. This section is 490 // the first section to allow use of: OpLine and OpNoLine debug information; 491 // non-semantic instructions with OpExtInst. 492 outputModuleSection(SPIRV::MB_TypeConstVars); 493 // 10. All function declarations (functions without a body). 494 outputExtFuncDecls(); 495 // 11. All function definitions (functions with a body). 496 // This is done in regular function output. 497 } 498 499 bool SPIRVAsmPrinter::doInitialization(Module &M) { 500 ModuleSectionsEmitted = false; 501 // We need to call the parent's one explicitly. 502 return AsmPrinter::doInitialization(M); 503 } 504 505 // Force static initialization. 506 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() { 507 RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target()); 508 RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target()); 509 } 510