xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp (revision 0c428864495af9dc7d2af4d0a5ae21732af9c739)
1 //===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to the SPIR-V assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "MCTargetDesc/SPIRVInstPrinter.h"
15 #include "SPIRV.h"
16 #include "SPIRVInstrInfo.h"
17 #include "SPIRVMCInstLower.h"
18 #include "SPIRVModuleAnalysis.h"
19 #include "SPIRVSubtarget.h"
20 #include "SPIRVTargetMachine.h"
21 #include "SPIRVUtils.h"
22 #include "TargetInfo/SPIRVTargetInfo.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/Analysis/ValueTracking.h"
25 #include "llvm/CodeGen/AsmPrinter.h"
26 #include "llvm/CodeGen/MachineConstantPool.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineInstr.h"
29 #include "llvm/CodeGen/MachineModuleInfo.h"
30 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
31 #include "llvm/MC/MCAsmInfo.h"
32 #include "llvm/MC/MCInst.h"
33 #include "llvm/MC/MCStreamer.h"
34 #include "llvm/MC/MCSymbol.h"
35 #include "llvm/MC/TargetRegistry.h"
36 #include "llvm/Support/raw_ostream.h"
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "asm-printer"
41 
42 namespace {
43 class SPIRVAsmPrinter : public AsmPrinter {
44 public:
45   explicit SPIRVAsmPrinter(TargetMachine &TM,
46                            std::unique_ptr<MCStreamer> Streamer)
47       : AsmPrinter(TM, std::move(Streamer)), ST(nullptr), TII(nullptr) {}
48   bool ModuleSectionsEmitted;
49   const SPIRVSubtarget *ST;
50   const SPIRVInstrInfo *TII;
51 
52   StringRef getPassName() const override { return "SPIRV Assembly Printer"; }
53   void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
54   bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
55                        const char *ExtraCode, raw_ostream &O) override;
56 
57   void outputMCInst(MCInst &Inst);
58   void outputInstruction(const MachineInstr *MI);
59   void outputModuleSection(SPIRV::ModuleSectionType MSType);
60   void outputEntryPoints();
61   void outputDebugSourceAndStrings(const Module &M);
62   void outputOpExtInstImports(const Module &M);
63   void outputOpMemoryModel();
64   void outputOpFunctionEnd();
65   void outputExtFuncDecls();
66   void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
67                                      SPIRV::ExecutionMode EM);
68   void outputExecutionMode(const Module &M);
69   void outputAnnotations(const Module &M);
70   void outputModuleSections();
71 
72   void emitInstruction(const MachineInstr *MI) override;
73   void emitFunctionEntryLabel() override {}
74   void emitFunctionHeader() override;
75   void emitFunctionBodyStart() override {}
76   void emitFunctionBodyEnd() override;
77   void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
78   void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {}
79   void emitGlobalVariable(const GlobalVariable *GV) override {}
80   void emitOpLabel(const MachineBasicBlock &MBB);
81   void emitEndOfAsmFile(Module &M) override;
82   bool doInitialization(Module &M) override;
83 
84   void getAnalysisUsage(AnalysisUsage &AU) const override;
85   SPIRV::ModuleAnalysisInfo *MAI;
86 };
87 } // namespace
88 
89 void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
90   AU.addRequired<SPIRVModuleAnalysis>();
91   AU.addPreserved<SPIRVModuleAnalysis>();
92   AsmPrinter::getAnalysisUsage(AU);
93 }
94 
95 // If the module has no functions, we need output global info anyway.
96 void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
97   if (ModuleSectionsEmitted == false) {
98     outputModuleSections();
99     ModuleSectionsEmitted = true;
100   }
101 }
102 
103 void SPIRVAsmPrinter::emitFunctionHeader() {
104   if (ModuleSectionsEmitted == false) {
105     outputModuleSections();
106     ModuleSectionsEmitted = true;
107   }
108   // Get the subtarget from the current MachineFunction.
109   ST = &MF->getSubtarget<SPIRVSubtarget>();
110   TII = ST->getInstrInfo();
111   const Function &F = MF->getFunction();
112 
113   if (isVerbose()) {
114     OutStreamer->getCommentOS()
115         << "-- Begin function "
116         << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n';
117   }
118 
119   auto Section = getObjFileLowering().SectionForGlobal(&F, TM);
120   MF->setSection(Section);
121 }
122 
123 void SPIRVAsmPrinter::outputOpFunctionEnd() {
124   MCInst FunctionEndInst;
125   FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd);
126   outputMCInst(FunctionEndInst);
127 }
128 
129 // Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap.
130 void SPIRVAsmPrinter::emitFunctionBodyEnd() {
131   outputOpFunctionEnd();
132   MAI->BBNumToRegMap.clear();
133 }
134 
135 void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
136   if (MAI->MBBsToSkip.contains(&MBB))
137     return;
138   MCInst LabelInst;
139   LabelInst.setOpcode(SPIRV::OpLabel);
140   LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
141   outputMCInst(LabelInst);
142 }
143 
144 void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
145   // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so
146   // OpLabel should be output after them.
147   if (MBB.getNumber() == MF->front().getNumber()) {
148     for (const MachineInstr &MI : MBB)
149       if (MI.getOpcode() == SPIRV::OpFunction)
150         return;
151     // TODO: this case should be checked by the verifier.
152     report_fatal_error("OpFunction is expected in the front MBB of MF");
153   }
154   emitOpLabel(MBB);
155 }
156 
157 void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
158                                    raw_ostream &O) {
159   const MachineOperand &MO = MI->getOperand(OpNum);
160 
161   switch (MO.getType()) {
162   case MachineOperand::MO_Register:
163     O << SPIRVInstPrinter::getRegisterName(MO.getReg());
164     break;
165 
166   case MachineOperand::MO_Immediate:
167     O << MO.getImm();
168     break;
169 
170   case MachineOperand::MO_FPImmediate:
171     O << MO.getFPImm();
172     break;
173 
174   case MachineOperand::MO_MachineBasicBlock:
175     O << *MO.getMBB()->getSymbol();
176     break;
177 
178   case MachineOperand::MO_GlobalAddress:
179     O << *getSymbol(MO.getGlobal());
180     break;
181 
182   case MachineOperand::MO_BlockAddress: {
183     MCSymbol *BA = GetBlockAddressSymbol(MO.getBlockAddress());
184     O << BA->getName();
185     break;
186   }
187 
188   case MachineOperand::MO_ExternalSymbol:
189     O << *GetExternalSymbolSymbol(MO.getSymbolName());
190     break;
191 
192   case MachineOperand::MO_JumpTableIndex:
193   case MachineOperand::MO_ConstantPoolIndex:
194   default:
195     llvm_unreachable("<unknown operand type>");
196   }
197 }
198 
199 bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
200                                       const char *ExtraCode, raw_ostream &O) {
201   if (ExtraCode && ExtraCode[0])
202     return true; // Invalid instruction - SPIR-V does not have special modifiers
203 
204   printOperand(MI, OpNo, O);
205   return false;
206 }
207 
208 static bool isFuncOrHeaderInstr(const MachineInstr *MI,
209                                 const SPIRVInstrInfo *TII) {
210   return TII->isHeaderInstr(*MI) || MI->getOpcode() == SPIRV::OpFunction ||
211          MI->getOpcode() == SPIRV::OpFunctionParameter;
212 }
213 
214 void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) {
215   OutStreamer->emitInstruction(Inst, *OutContext.getSubtargetInfo());
216 }
217 
218 void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) {
219   SPIRVMCInstLower MCInstLowering;
220   MCInst TmpInst;
221   MCInstLowering.lower(MI, TmpInst, MAI);
222   outputMCInst(TmpInst);
223 }
224 
225 void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) {
226   SPIRV_MC::verifyInstructionPredicates(MI->getOpcode(),
227                                         getSubtargetInfo().getFeatureBits());
228 
229   if (!MAI->getSkipEmission(MI))
230     outputInstruction(MI);
231 
232   // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB.
233   const MachineInstr *NextMI = MI->getNextNode();
234   if (!MAI->hasMBBRegister(*MI->getParent()) && isFuncOrHeaderInstr(MI, TII) &&
235       (!NextMI || !isFuncOrHeaderInstr(NextMI, TII))) {
236     assert(MI->getParent()->getNumber() == MF->front().getNumber() &&
237            "OpFunction is not in the front MBB of MF");
238     emitOpLabel(*MI->getParent());
239   }
240 }
241 
242 void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
243   for (MachineInstr *MI : MAI->getMSInstrs(MSType))
244     outputInstruction(MI);
245 }
246 
247 void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
248   // Output OpSourceExtensions.
249   for (auto &Str : MAI->SrcExt) {
250     MCInst Inst;
251     Inst.setOpcode(SPIRV::OpSourceExtension);
252     addStringImm(Str.first(), Inst);
253     outputMCInst(Inst);
254   }
255   // Output OpSource.
256   MCInst Inst;
257   Inst.setOpcode(SPIRV::OpSource);
258   Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->SrcLang)));
259   Inst.addOperand(
260       MCOperand::createImm(static_cast<unsigned>(MAI->SrcLangVersion)));
261   outputMCInst(Inst);
262 }
263 
264 void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
265   for (auto &CU : MAI->ExtInstSetMap) {
266     unsigned Set = CU.first;
267     Register Reg = CU.second;
268     MCInst Inst;
269     Inst.setOpcode(SPIRV::OpExtInstImport);
270     Inst.addOperand(MCOperand::createReg(Reg));
271     addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)),
272                  Inst);
273     outputMCInst(Inst);
274   }
275 }
276 
277 void SPIRVAsmPrinter::outputOpMemoryModel() {
278   MCInst Inst;
279   Inst.setOpcode(SPIRV::OpMemoryModel);
280   Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Addr)));
281   Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(MAI->Mem)));
282   outputMCInst(Inst);
283 }
284 
285 // Before the OpEntryPoints' output, we need to add the entry point's
286 // interfaces. The interface is a list of IDs of global OpVariable instructions.
287 // These declare the set of global variables from a module that form
288 // the interface of this entry point.
289 void SPIRVAsmPrinter::outputEntryPoints() {
290   // Find all OpVariable IDs with required StorageClass.
291   DenseSet<Register> InterfaceIDs;
292   for (MachineInstr *MI : MAI->GlobalVarList) {
293     assert(MI->getOpcode() == SPIRV::OpVariable);
294     auto SC = static_cast<SPIRV::StorageClass>(MI->getOperand(2).getImm());
295     // Before version 1.4, the interface's storage classes are limited to
296     // the Input and Output storage classes. Starting with version 1.4,
297     // the interface's storage classes are all storage classes used in
298     // declaring all global variables referenced by the entry point call tree.
299     if (ST->getSPIRVVersion() >= 14 || SC == SPIRV::StorageClass::Input ||
300         SC == SPIRV::StorageClass::Output) {
301       MachineFunction *MF = MI->getMF();
302       Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
303       InterfaceIDs.insert(Reg);
304     }
305   }
306 
307   // Output OpEntryPoints adding interface args to all of them.
308   for (MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) {
309     SPIRVMCInstLower MCInstLowering;
310     MCInst TmpInst;
311     MCInstLowering.lower(MI, TmpInst, MAI);
312     for (Register Reg : InterfaceIDs) {
313       assert(Reg.isValid());
314       TmpInst.addOperand(MCOperand::createReg(Reg));
315     }
316     outputMCInst(TmpInst);
317   }
318 }
319 
320 void SPIRVAsmPrinter::outputExtFuncDecls() {
321   // Insert OpFunctionEnd after each declaration.
322   SmallVectorImpl<MachineInstr *>::iterator
323       I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(),
324       E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end();
325   for (; I != E; ++I) {
326     outputInstruction(*I);
327     if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction)
328       outputOpFunctionEnd();
329   }
330 }
331 
332 // Encode LLVM type by SPIR-V execution mode VecTypeHint.
333 static unsigned encodeVecTypeHint(Type *Ty) {
334   if (Ty->isHalfTy())
335     return 4;
336   if (Ty->isFloatTy())
337     return 5;
338   if (Ty->isDoubleTy())
339     return 6;
340   if (IntegerType *IntTy = dyn_cast<IntegerType>(Ty)) {
341     switch (IntTy->getIntegerBitWidth()) {
342     case 8:
343       return 0;
344     case 16:
345       return 1;
346     case 32:
347       return 2;
348     case 64:
349       return 3;
350     default:
351       llvm_unreachable("invalid integer type");
352     }
353   }
354   if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty)) {
355     Type *EleTy = VecTy->getElementType();
356     unsigned Size = VecTy->getNumElements();
357     return Size << 16 | encodeVecTypeHint(EleTy);
358   }
359   llvm_unreachable("invalid type");
360 }
361 
362 static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
363                              SPIRV::ModuleAnalysisInfo *MAI) {
364   for (const MDOperand &MDOp : MDN->operands()) {
365     if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
366       Constant *C = CMeta->getValue();
367       if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
368         Inst.addOperand(MCOperand::createImm(Const->getZExtValue()));
369       } else if (auto *CE = dyn_cast<Function>(C)) {
370         Register FuncReg = MAI->getFuncReg(CE->getName().str());
371         assert(FuncReg.isValid());
372         Inst.addOperand(MCOperand::createReg(FuncReg));
373       }
374     }
375   }
376 }
377 
378 void SPIRVAsmPrinter::outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
379                                                     SPIRV::ExecutionMode EM) {
380   MCInst Inst;
381   Inst.setOpcode(SPIRV::OpExecutionMode);
382   Inst.addOperand(MCOperand::createReg(Reg));
383   Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
384   addOpsFromMDNode(Node, Inst, MAI);
385   outputMCInst(Inst);
386 }
387 
388 void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
389   NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
390   if (Node) {
391     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
392       MCInst Inst;
393       Inst.setOpcode(SPIRV::OpExecutionMode);
394       addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
395       outputMCInst(Inst);
396     }
397   }
398   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
399     const Function &F = *FI;
400     if (F.isDeclaration())
401       continue;
402     Register FReg = MAI->getFuncReg(F.getGlobalIdentifier());
403     assert(FReg.isValid());
404     if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
405       outputExecutionModeFromMDNode(FReg, Node,
406                                     SPIRV::ExecutionMode::LocalSize);
407     if (MDNode *Node = F.getMetadata("work_group_size_hint"))
408       outputExecutionModeFromMDNode(FReg, Node,
409                                     SPIRV::ExecutionMode::LocalSizeHint);
410     if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
411       outputExecutionModeFromMDNode(FReg, Node,
412                                     SPIRV::ExecutionMode::SubgroupSize);
413     if (MDNode *Node = F.getMetadata("vec_type_hint")) {
414       MCInst Inst;
415       Inst.setOpcode(SPIRV::OpExecutionMode);
416       Inst.addOperand(MCOperand::createReg(FReg));
417       unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
418       Inst.addOperand(MCOperand::createImm(EM));
419       unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0));
420       Inst.addOperand(MCOperand::createImm(TypeCode));
421       outputMCInst(Inst);
422     }
423   }
424 }
425 
426 void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
427   outputModuleSection(SPIRV::MB_Annotations);
428   // Process llvm.global.annotations special global variable.
429   for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) {
430     if ((*F).getName() != "llvm.global.annotations")
431       continue;
432     const GlobalVariable *V = &(*F);
433     const ConstantArray *CA = cast<ConstantArray>(V->getOperand(0));
434     for (Value *Op : CA->operands()) {
435       ConstantStruct *CS = cast<ConstantStruct>(Op);
436       // The first field of the struct contains a pointer to
437       // the annotated variable.
438       Value *AnnotatedVar = CS->getOperand(0)->stripPointerCasts();
439       if (!isa<Function>(AnnotatedVar))
440         llvm_unreachable("Unsupported value in llvm.global.annotations");
441       Function *Func = cast<Function>(AnnotatedVar);
442       Register Reg = MAI->getFuncReg(Func->getGlobalIdentifier());
443 
444       // The second field contains a pointer to a global annotation string.
445       GlobalVariable *GV =
446           cast<GlobalVariable>(CS->getOperand(1)->stripPointerCasts());
447 
448       StringRef AnnotationString;
449       getConstantStringInfo(GV, AnnotationString);
450       MCInst Inst;
451       Inst.setOpcode(SPIRV::OpDecorate);
452       Inst.addOperand(MCOperand::createReg(Reg));
453       unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
454       Inst.addOperand(MCOperand::createImm(Dec));
455       addStringImm(AnnotationString, Inst);
456       outputMCInst(Inst);
457     }
458   }
459 }
460 
461 void SPIRVAsmPrinter::outputModuleSections() {
462   const Module *M = MMI->getModule();
463   // Get the global subtarget to output module-level info.
464   ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
465   TII = ST->getInstrInfo();
466   MAI = &SPIRVModuleAnalysis::MAI;
467   assert(ST && TII && MAI && M && "Module analysis is required");
468   // Output instructions according to the Logical Layout of a Module:
469   // TODO: 1,2. All OpCapability instructions, then optional OpExtension
470   // instructions.
471   // 3. Optional OpExtInstImport instructions.
472   outputOpExtInstImports(*M);
473   // 4. The single required OpMemoryModel instruction.
474   outputOpMemoryModel();
475   // 5. All entry point declarations, using OpEntryPoint.
476   outputEntryPoints();
477   // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
478   outputExecutionMode(*M);
479   // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
480   // OpSourceContinued, without forward references.
481   outputDebugSourceAndStrings(*M);
482   // 7b. Debug: all OpName and all OpMemberName.
483   outputModuleSection(SPIRV::MB_DebugNames);
484   // 7c. Debug: all OpModuleProcessed instructions.
485   outputModuleSection(SPIRV::MB_DebugModuleProcessed);
486   // 8. All annotation instructions (all decorations).
487   outputAnnotations(*M);
488   // 9. All type declarations (OpTypeXXX instructions), all constant
489   // instructions, and all global variable declarations. This section is
490   // the first section to allow use of: OpLine and OpNoLine debug information;
491   // non-semantic instructions with OpExtInst.
492   outputModuleSection(SPIRV::MB_TypeConstVars);
493   // 10. All function declarations (functions without a body).
494   outputExtFuncDecls();
495   // 11. All function definitions (functions with a body).
496   // This is done in regular function output.
497 }
498 
499 bool SPIRVAsmPrinter::doInitialization(Module &M) {
500   ModuleSectionsEmitted = false;
501   // We need to call the parent's one explicitly.
502   return AsmPrinter::doInitialization(M);
503 }
504 
505 // Force static initialization.
506 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() {
507   RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target());
508   RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target());
509 }
510