xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision a90b9d0159070121c221b966469c3e36d912bf82)
1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "TargetInfo/SPIRVTargetInfo.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/CodeGen/MachineModuleInfo.h"
27 #include "llvm/CodeGen/TargetPassConfig.h"
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "spirv-module-analysis"
32 
33 static cl::opt<bool>
34     SPVDumpDeps("spv-dump-deps",
35                 cl::desc("Dump MIR with SPIR-V dependencies info"),
36                 cl::Optional, cl::init(false));
37 
38 char llvm::SPIRVModuleAnalysis::ID = 0;
39 
40 namespace llvm {
41 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
42 } // namespace llvm
43 
44 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
45                 true)
46 
47 // Retrieve an unsigned from an MDNode with a list of them as operands.
48 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
49                                 unsigned DefaultVal = 0) {
50   if (MdNode && OpIndex < MdNode->getNumOperands()) {
51     const auto &Op = MdNode->getOperand(OpIndex);
52     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
53   }
54   return DefaultVal;
55 }
56 
57 static SPIRV::Requirements
58 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
59                                unsigned i, const SPIRVSubtarget &ST,
60                                SPIRV::RequirementHandler &Reqs) {
61   unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
62   unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
63   unsigned TargetVer = ST.getSPIRVVersion();
64   bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
65   bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
66   CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
67   ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
68   if (ReqCaps.empty()) {
69     if (ReqExts.empty()) {
70       if (MinVerOK && MaxVerOK)
71         return {true, {}, {}, ReqMinVer, ReqMaxVer};
72       return {false, {}, {}, 0, 0};
73     }
74   } else if (MinVerOK && MaxVerOK) {
75     for (auto Cap : ReqCaps) { // Only need 1 of the capabilities to work.
76       if (Reqs.isCapabilityAvailable(Cap))
77         return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
78     }
79   }
80   // If there are no capabilities, or we can't satisfy the version or
81   // capability requirements, use the list of extensions (if the subtarget
82   // can handle them all).
83   if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
84         return ST.canUseExtension(Ext);
85       })) {
86     return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
87   }
88   return {false, {}, {}, 0, 0};
89 }
90 
91 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
92   MAI.MaxID = 0;
93   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
94     MAI.MS[i].clear();
95   MAI.RegisterAliasTable.clear();
96   MAI.InstrsToDelete.clear();
97   MAI.FuncMap.clear();
98   MAI.GlobalVarList.clear();
99   MAI.ExtInstSetMap.clear();
100   MAI.Reqs.clear();
101   MAI.Reqs.initAvailableCapabilities(*ST);
102 
103   // TODO: determine memory model and source language from the configuratoin.
104   if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
105     auto MemMD = MemModel->getOperand(0);
106     MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
107         getMetadataUInt(MemMD, 0));
108     MAI.Mem =
109         static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
110   } else {
111     // TODO: Add support for VulkanMemoryModel.
112     MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
113                                 : SPIRV::MemoryModel::GLSL450;
114     if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
115       unsigned PtrSize = ST->getPointerSize();
116       MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
117                  : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
118                                  : SPIRV::AddressingModel::Logical;
119     } else {
120       // TODO: Add support for PhysicalStorageBufferAddress.
121       MAI.Addr = SPIRV::AddressingModel::Logical;
122     }
123   }
124   // Get the OpenCL version number from metadata.
125   // TODO: support other source languages.
126   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
127     MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
128     // Construct version literal in accordance with SPIRV-LLVM-Translator.
129     // TODO: support multiple OCL version metadata.
130     assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
131     auto VersionMD = VerNode->getOperand(0);
132     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
133     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
134     unsigned RevNum = getMetadataUInt(VersionMD, 2);
135     MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
136   } else {
137     MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
138     MAI.SrcLangVersion = 0;
139   }
140 
141   if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
142     for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
143       MDNode *MD = ExtNode->getOperand(I);
144       if (!MD || MD->getNumOperands() == 0)
145         continue;
146       for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
147         MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
148     }
149   }
150 
151   // Update required capabilities for this memory model, addressing model and
152   // source language.
153   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
154                                  MAI.Mem, *ST);
155   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
156                                  MAI.SrcLang, *ST);
157   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
158                                  MAI.Addr, *ST);
159 
160   if (ST->isOpenCLEnv()) {
161     // TODO: check if it's required by default.
162     MAI.ExtInstSetMap[static_cast<unsigned>(
163         SPIRV::InstructionSet::OpenCL_std)] =
164         Register::index2VirtReg(MAI.getNextID());
165   }
166 }
167 
168 // Collect MI which defines the register in the given machine function.
169 static void collectDefInstr(Register Reg, const MachineFunction *MF,
170                             SPIRV::ModuleAnalysisInfo *MAI,
171                             SPIRV::ModuleSectionType MSType,
172                             bool DoInsert = true) {
173   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
174   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
175   assert(MI && "There should be an instruction that defines the register");
176   MAI->setSkipEmission(MI);
177   if (DoInsert)
178     MAI->MS[MSType].push_back(MI);
179 }
180 
181 void SPIRVModuleAnalysis::collectGlobalEntities(
182     const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
183     SPIRV::ModuleSectionType MSType,
184     std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
185     bool UsePreOrder = false) {
186   DenseSet<const SPIRV::DTSortableEntry *> Visited;
187   for (const auto *E : DepsGraph) {
188     std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
189     // NOTE: here we prefer recursive approach over iterative because
190     // we don't expect depchains long enough to cause SO.
191     RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
192                     &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
193       if (Visited.count(E) || !Pred(E))
194         return;
195       Visited.insert(E);
196 
197       // Traversing deps graph in post-order allows us to get rid of
198       // register aliases preprocessing.
199       // But pre-order is required for correct processing of function
200       // declaration and arguments processing.
201       if (!UsePreOrder)
202         for (auto *S : E->getDeps())
203           RecHoistUtil(S);
204 
205       Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
206       bool IsFirst = true;
207       for (auto &U : *E) {
208         const MachineFunction *MF = U.first;
209         Register Reg = U.second;
210         MAI.setRegisterAlias(MF, Reg, GlobalReg);
211         if (!MF->getRegInfo().getUniqueVRegDef(Reg))
212           continue;
213         collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
214         IsFirst = false;
215         if (E->getIsGV())
216           MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
217       }
218 
219       if (UsePreOrder)
220         for (auto *S : E->getDeps())
221           RecHoistUtil(S);
222     };
223     RecHoistUtil(E);
224   }
225 }
226 
227 // The function initializes global register alias table for types, consts,
228 // global vars and func decls and collects these instruction for output
229 // at module level. Also it collects explicit OpExtension/OpCapability
230 // instructions.
231 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
232   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
233 
234   GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
235 
236   collectGlobalEntities(
237       DepsGraph, SPIRV::MB_TypeConstVars,
238       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
239 
240   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
241     MachineFunction *MF = MMI->getMachineFunction(*F);
242     if (!MF)
243       continue;
244     // Iterate through and collect OpExtension/OpCapability instructions.
245     for (MachineBasicBlock &MBB : *MF) {
246       for (MachineInstr &MI : MBB) {
247         if (MI.getOpcode() == SPIRV::OpExtension) {
248           // Here, OpExtension just has a single enum operand, not a string.
249           auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
250           MAI.Reqs.addExtension(Ext);
251           MAI.setSkipEmission(&MI);
252         } else if (MI.getOpcode() == SPIRV::OpCapability) {
253           auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
254           MAI.Reqs.addCapability(Cap);
255           MAI.setSkipEmission(&MI);
256         }
257       }
258     }
259   }
260 
261   collectGlobalEntities(
262       DepsGraph, SPIRV::MB_ExtFuncDecls,
263       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
264 }
265 
266 // Look for IDs declared with Import linkage, and map the corresponding function
267 // to the register defining that variable (which will usually be the result of
268 // an OpFunction). This lets us call externally imported functions using
269 // the correct ID registers.
270 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
271                                            const Function *F) {
272   if (MI.getOpcode() == SPIRV::OpDecorate) {
273     // If it's got Import linkage.
274     auto Dec = MI.getOperand(1).getImm();
275     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
276       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
277       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
278         // Map imported function name to function ID register.
279         const Function *ImportedFunc =
280             F->getParent()->getFunction(getStringImm(MI, 2));
281         Register Target = MI.getOperand(0).getReg();
282         MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
283       }
284     }
285   } else if (MI.getOpcode() == SPIRV::OpFunction) {
286     // Record all internal OpFunction declarations.
287     Register Reg = MI.defs().begin()->getReg();
288     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
289     assert(GlobalReg.isValid());
290     MAI.FuncMap[F] = GlobalReg;
291   }
292 }
293 
294 using InstrSignature = SmallVector<size_t>;
295 using InstrTraces = std::set<InstrSignature>;
296 
297 // Returns a representation of an instruction as a vector of MachineOperand
298 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
299 // This creates a signature of the instruction with the same content
300 // that MachineOperand::isIdenticalTo uses for comparison.
301 static InstrSignature instrToSignature(MachineInstr &MI,
302                                        SPIRV::ModuleAnalysisInfo &MAI) {
303   InstrSignature Signature;
304   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
305     const MachineOperand &MO = MI.getOperand(i);
306     size_t h;
307     if (MO.isReg()) {
308       Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
309       // mimic llvm::hash_value(const MachineOperand &MO)
310       h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
311                        MO.isDef());
312     } else {
313       h = hash_value(MO);
314     }
315     Signature.push_back(h);
316   }
317   return Signature;
318 }
319 
320 // Collect the given instruction in the specified MS. We assume global register
321 // numbering has already occurred by this point. We can directly compare reg
322 // arguments when detecting duplicates.
323 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
324                               SPIRV::ModuleSectionType MSType, InstrTraces &IS,
325                               bool Append = true) {
326   MAI.setSkipEmission(&MI);
327   InstrSignature MISign = instrToSignature(MI, MAI);
328   auto FoundMI = IS.insert(MISign);
329   if (!FoundMI.second)
330     return; // insert failed, so we found a duplicate; don't add it to MAI.MS
331   // No duplicates, so add it.
332   if (Append)
333     MAI.MS[MSType].push_back(&MI);
334   else
335     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
336 }
337 
338 // Some global instructions make reference to function-local ID regs, so cannot
339 // be correctly collected until these registers are globally numbered.
340 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
341   InstrTraces IS;
342   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
343     if ((*F).isDeclaration())
344       continue;
345     MachineFunction *MF = MMI->getMachineFunction(*F);
346     assert(MF);
347     for (MachineBasicBlock &MBB : *MF)
348       for (MachineInstr &MI : MBB) {
349         if (MAI.getSkipEmission(&MI))
350           continue;
351         const unsigned OpCode = MI.getOpcode();
352         if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
353           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
354         } else if (OpCode == SPIRV::OpEntryPoint) {
355           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
356         } else if (TII->isDecorationInstr(MI)) {
357           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
358           collectFuncNames(MI, &*F);
359         } else if (TII->isConstantInstr(MI)) {
360           // Now OpSpecConstant*s are not in DT,
361           // but they need to be collected anyway.
362           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
363         } else if (OpCode == SPIRV::OpFunction) {
364           collectFuncNames(MI, &*F);
365         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
366           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
367         }
368       }
369   }
370 }
371 
372 // Number registers in all functions globally from 0 onwards and store
373 // the result in global register alias table. Some registers are already
374 // numbered in collectGlobalEntities.
375 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
376   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
377     if ((*F).isDeclaration())
378       continue;
379     MachineFunction *MF = MMI->getMachineFunction(*F);
380     assert(MF);
381     for (MachineBasicBlock &MBB : *MF) {
382       for (MachineInstr &MI : MBB) {
383         for (MachineOperand &Op : MI.operands()) {
384           if (!Op.isReg())
385             continue;
386           Register Reg = Op.getReg();
387           if (MAI.hasRegisterAlias(MF, Reg))
388             continue;
389           Register NewReg = Register::index2VirtReg(MAI.getNextID());
390           MAI.setRegisterAlias(MF, Reg, NewReg);
391         }
392         if (MI.getOpcode() != SPIRV::OpExtInst)
393           continue;
394         auto Set = MI.getOperand(2).getImm();
395         if (!MAI.ExtInstSetMap.contains(Set))
396           MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
397       }
398     }
399   }
400 }
401 
402 // RequirementHandler implementations.
403 void SPIRV::RequirementHandler::getAndAddRequirements(
404     SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
405     const SPIRVSubtarget &ST) {
406   addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
407 }
408 
409 void SPIRV::RequirementHandler::pruneCapabilities(
410     const CapabilityList &ToPrune) {
411   for (const auto &Cap : ToPrune) {
412     AllCaps.insert(Cap);
413     auto FoundIndex = llvm::find(MinimalCaps, Cap);
414     if (FoundIndex != MinimalCaps.end())
415       MinimalCaps.erase(FoundIndex);
416     CapabilityList ImplicitDecls =
417         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
418     pruneCapabilities(ImplicitDecls);
419   }
420 }
421 
422 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
423   for (const auto &Cap : ToAdd) {
424     bool IsNewlyInserted = AllCaps.insert(Cap).second;
425     if (!IsNewlyInserted) // Don't re-add if it's already been declared.
426       continue;
427     CapabilityList ImplicitDecls =
428         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
429     pruneCapabilities(ImplicitDecls);
430     MinimalCaps.push_back(Cap);
431   }
432 }
433 
434 void SPIRV::RequirementHandler::addRequirements(
435     const SPIRV::Requirements &Req) {
436   if (!Req.IsSatisfiable)
437     report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
438 
439   if (Req.Cap.has_value())
440     addCapabilities({Req.Cap.value()});
441 
442   addExtensions(Req.Exts);
443 
444   if (Req.MinVer) {
445     if (MaxVersion && Req.MinVer > MaxVersion) {
446       LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
447                         << " and <= " << MaxVersion << "\n");
448       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
449     }
450 
451     if (MinVersion == 0 || Req.MinVer > MinVersion)
452       MinVersion = Req.MinVer;
453   }
454 
455   if (Req.MaxVer) {
456     if (MinVersion && Req.MaxVer < MinVersion) {
457       LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
458                         << " and >= " << MinVersion << "\n");
459       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
460     }
461 
462     if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
463       MaxVersion = Req.MaxVer;
464   }
465 }
466 
467 void SPIRV::RequirementHandler::checkSatisfiable(
468     const SPIRVSubtarget &ST) const {
469   // Report as many errors as possible before aborting the compilation.
470   bool IsSatisfiable = true;
471   auto TargetVer = ST.getSPIRVVersion();
472 
473   if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
474     LLVM_DEBUG(
475         dbgs() << "Target SPIR-V version too high for required features\n"
476                << "Required max version: " << MaxVersion << " target version "
477                << TargetVer << "\n");
478     IsSatisfiable = false;
479   }
480 
481   if (MinVersion && TargetVer && MinVersion > TargetVer) {
482     LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
483                       << "Required min version: " << MinVersion
484                       << " target version " << TargetVer << "\n");
485     IsSatisfiable = false;
486   }
487 
488   if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
489     LLVM_DEBUG(
490         dbgs()
491         << "Version is too low for some features and too high for others.\n"
492         << "Required SPIR-V min version: " << MinVersion
493         << " required SPIR-V max version " << MaxVersion << "\n");
494     IsSatisfiable = false;
495   }
496 
497   for (auto Cap : MinimalCaps) {
498     if (AvailableCaps.contains(Cap))
499       continue;
500     LLVM_DEBUG(dbgs() << "Capability not supported: "
501                       << getSymbolicOperandMnemonic(
502                              OperandCategory::CapabilityOperand, Cap)
503                       << "\n");
504     IsSatisfiable = false;
505   }
506 
507   for (auto Ext : AllExtensions) {
508     if (ST.canUseExtension(Ext))
509       continue;
510     LLVM_DEBUG(dbgs() << "Extension not supported: "
511                       << getSymbolicOperandMnemonic(
512                              OperandCategory::ExtensionOperand, Ext)
513                       << "\n");
514     IsSatisfiable = false;
515   }
516 
517   if (!IsSatisfiable)
518     report_fatal_error("Unable to meet SPIR-V requirements for this target.");
519 }
520 
521 // Add the given capabilities and all their implicitly defined capabilities too.
522 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
523   for (const auto Cap : ToAdd)
524     if (AvailableCaps.insert(Cap).second)
525       addAvailableCaps(getSymbolicOperandCapabilities(
526           SPIRV::OperandCategory::CapabilityOperand, Cap));
527 }
528 
529 void SPIRV::RequirementHandler::removeCapabilityIf(
530     const Capability::Capability ToRemove,
531     const Capability::Capability IfPresent) {
532   if (AllCaps.contains(IfPresent))
533     AllCaps.erase(ToRemove);
534 }
535 
536 namespace llvm {
537 namespace SPIRV {
538 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
539   if (ST.isOpenCLEnv()) {
540     initAvailableCapabilitiesForOpenCL(ST);
541     return;
542   }
543 
544   if (ST.isVulkanEnv()) {
545     initAvailableCapabilitiesForVulkan(ST);
546     return;
547   }
548 
549   report_fatal_error("Unimplemented environment for SPIR-V generation.");
550 }
551 
552 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
553     const SPIRVSubtarget &ST) {
554   // Add the min requirements for different OpenCL and SPIR-V versions.
555   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
556                     Capability::Int16, Capability::Int8, Capability::Kernel,
557                     Capability::Linkage, Capability::Vector16,
558                     Capability::Groups, Capability::GenericPointer,
559                     Capability::Shader});
560   if (ST.hasOpenCLFullProfile())
561     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
562   if (ST.hasOpenCLImageSupport()) {
563     addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
564                       Capability::Image1D, Capability::SampledBuffer,
565                       Capability::ImageBuffer});
566     if (ST.isAtLeastOpenCLVer(20))
567       addAvailableCaps({Capability::ImageReadWrite});
568   }
569   if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
570     addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
571   if (ST.isAtLeastSPIRVVer(13))
572     addAvailableCaps({Capability::GroupNonUniform,
573                       Capability::GroupNonUniformVote,
574                       Capability::GroupNonUniformArithmetic,
575                       Capability::GroupNonUniformBallot,
576                       Capability::GroupNonUniformClustered,
577                       Capability::GroupNonUniformShuffle,
578                       Capability::GroupNonUniformShuffleRelative});
579   if (ST.isAtLeastSPIRVVer(14))
580     addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
581                       Capability::SignedZeroInfNanPreserve,
582                       Capability::RoundingModeRTE,
583                       Capability::RoundingModeRTZ});
584   // TODO: verify if this needs some checks.
585   addAvailableCaps({Capability::Float16, Capability::Float64});
586 
587   // Add capabilities enabled by extensions.
588   for (auto Extension : ST.getAllAvailableExtensions()) {
589     CapabilityList EnabledCapabilities =
590         getCapabilitiesEnabledByExtension(Extension);
591     addAvailableCaps(EnabledCapabilities);
592   }
593 
594   // TODO: add OpenCL extensions.
595 }
596 
597 void RequirementHandler::initAvailableCapabilitiesForVulkan(
598     const SPIRVSubtarget &ST) {
599   addAvailableCaps({Capability::Shader, Capability::Linkage});
600 
601   // Provided by all supported Vulkan versions.
602   addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
603                     Capability::Float64});
604 }
605 
606 } // namespace SPIRV
607 } // namespace llvm
608 
609 // Add the required capabilities from a decoration instruction (including
610 // BuiltIns).
611 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
612                               SPIRV::RequirementHandler &Reqs,
613                               const SPIRVSubtarget &ST) {
614   int64_t DecOp = MI.getOperand(DecIndex).getImm();
615   auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
616   Reqs.addRequirements(getSymbolicOperandRequirements(
617       SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
618 
619   if (Dec == SPIRV::Decoration::BuiltIn) {
620     int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
621     auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
622     Reqs.addRequirements(getSymbolicOperandRequirements(
623         SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
624   }
625 }
626 
627 // Add requirements for image handling.
628 static void addOpTypeImageReqs(const MachineInstr &MI,
629                                SPIRV::RequirementHandler &Reqs,
630                                const SPIRVSubtarget &ST) {
631   assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
632   // The operand indices used here are based on the OpTypeImage layout, which
633   // the MachineInstr follows as well.
634   int64_t ImgFormatOp = MI.getOperand(7).getImm();
635   auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
636   Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
637                              ImgFormat, ST);
638 
639   bool IsArrayed = MI.getOperand(4).getImm() == 1;
640   bool IsMultisampled = MI.getOperand(5).getImm() == 1;
641   bool NoSampler = MI.getOperand(6).getImm() == 2;
642   // Add dimension requirements.
643   assert(MI.getOperand(2).isImm());
644   switch (MI.getOperand(2).getImm()) {
645   case SPIRV::Dim::DIM_1D:
646     Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
647                                    : SPIRV::Capability::Sampled1D);
648     break;
649   case SPIRV::Dim::DIM_2D:
650     if (IsMultisampled && NoSampler)
651       Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
652     break;
653   case SPIRV::Dim::DIM_Cube:
654     Reqs.addRequirements(SPIRV::Capability::Shader);
655     if (IsArrayed)
656       Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
657                                      : SPIRV::Capability::SampledCubeArray);
658     break;
659   case SPIRV::Dim::DIM_Rect:
660     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
661                                    : SPIRV::Capability::SampledRect);
662     break;
663   case SPIRV::Dim::DIM_Buffer:
664     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
665                                    : SPIRV::Capability::SampledBuffer);
666     break;
667   case SPIRV::Dim::DIM_SubpassData:
668     Reqs.addRequirements(SPIRV::Capability::InputAttachment);
669     break;
670   }
671 
672   // Has optional access qualifier.
673   // TODO: check if it's OpenCL's kernel.
674   if (MI.getNumOperands() > 8 &&
675       MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
676     Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
677   else
678     Reqs.addRequirements(SPIRV::Capability::ImageBasic);
679 }
680 
681 void addInstrRequirements(const MachineInstr &MI,
682                           SPIRV::RequirementHandler &Reqs,
683                           const SPIRVSubtarget &ST) {
684   switch (MI.getOpcode()) {
685   case SPIRV::OpMemoryModel: {
686     int64_t Addr = MI.getOperand(0).getImm();
687     Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
688                                Addr, ST);
689     int64_t Mem = MI.getOperand(1).getImm();
690     Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
691                                ST);
692     break;
693   }
694   case SPIRV::OpEntryPoint: {
695     int64_t Exe = MI.getOperand(0).getImm();
696     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
697                                Exe, ST);
698     break;
699   }
700   case SPIRV::OpExecutionMode:
701   case SPIRV::OpExecutionModeId: {
702     int64_t Exe = MI.getOperand(1).getImm();
703     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
704                                Exe, ST);
705     break;
706   }
707   case SPIRV::OpTypeMatrix:
708     Reqs.addCapability(SPIRV::Capability::Matrix);
709     break;
710   case SPIRV::OpTypeInt: {
711     unsigned BitWidth = MI.getOperand(1).getImm();
712     if (BitWidth == 64)
713       Reqs.addCapability(SPIRV::Capability::Int64);
714     else if (BitWidth == 16)
715       Reqs.addCapability(SPIRV::Capability::Int16);
716     else if (BitWidth == 8)
717       Reqs.addCapability(SPIRV::Capability::Int8);
718     break;
719   }
720   case SPIRV::OpTypeFloat: {
721     unsigned BitWidth = MI.getOperand(1).getImm();
722     if (BitWidth == 64)
723       Reqs.addCapability(SPIRV::Capability::Float64);
724     else if (BitWidth == 16)
725       Reqs.addCapability(SPIRV::Capability::Float16);
726     break;
727   }
728   case SPIRV::OpTypeVector: {
729     unsigned NumComponents = MI.getOperand(2).getImm();
730     if (NumComponents == 8 || NumComponents == 16)
731       Reqs.addCapability(SPIRV::Capability::Vector16);
732     break;
733   }
734   case SPIRV::OpTypePointer: {
735     auto SC = MI.getOperand(1).getImm();
736     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
737                                ST);
738     // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
739     // capability.
740     if (!ST.isOpenCLEnv())
741       break;
742     assert(MI.getOperand(2).isReg());
743     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
744     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
745     if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
746         TypeDef->getOperand(1).getImm() == 16)
747       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
748     break;
749   }
750   case SPIRV::OpBitReverse:
751   case SPIRV::OpBitFieldInsert:
752   case SPIRV::OpBitFieldSExtract:
753   case SPIRV::OpBitFieldUExtract:
754     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
755       Reqs.addCapability(SPIRV::Capability::Shader);
756       break;
757     }
758     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
759     Reqs.addCapability(SPIRV::Capability::BitInstructions);
760     break;
761   case SPIRV::OpTypeRuntimeArray:
762     Reqs.addCapability(SPIRV::Capability::Shader);
763     break;
764   case SPIRV::OpTypeOpaque:
765   case SPIRV::OpTypeEvent:
766     Reqs.addCapability(SPIRV::Capability::Kernel);
767     break;
768   case SPIRV::OpTypePipe:
769   case SPIRV::OpTypeReserveId:
770     Reqs.addCapability(SPIRV::Capability::Pipes);
771     break;
772   case SPIRV::OpTypeDeviceEvent:
773   case SPIRV::OpTypeQueue:
774   case SPIRV::OpBuildNDRange:
775     Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
776     break;
777   case SPIRV::OpDecorate:
778   case SPIRV::OpDecorateId:
779   case SPIRV::OpDecorateString:
780     addOpDecorateReqs(MI, 1, Reqs, ST);
781     break;
782   case SPIRV::OpMemberDecorate:
783   case SPIRV::OpMemberDecorateString:
784     addOpDecorateReqs(MI, 2, Reqs, ST);
785     break;
786   case SPIRV::OpInBoundsPtrAccessChain:
787     Reqs.addCapability(SPIRV::Capability::Addresses);
788     break;
789   case SPIRV::OpConstantSampler:
790     Reqs.addCapability(SPIRV::Capability::LiteralSampler);
791     break;
792   case SPIRV::OpTypeImage:
793     addOpTypeImageReqs(MI, Reqs, ST);
794     break;
795   case SPIRV::OpTypeSampler:
796     Reqs.addCapability(SPIRV::Capability::ImageBasic);
797     break;
798   case SPIRV::OpTypeForwardPointer:
799     // TODO: check if it's OpenCL's kernel.
800     Reqs.addCapability(SPIRV::Capability::Addresses);
801     break;
802   case SPIRV::OpAtomicFlagTestAndSet:
803   case SPIRV::OpAtomicLoad:
804   case SPIRV::OpAtomicStore:
805   case SPIRV::OpAtomicExchange:
806   case SPIRV::OpAtomicCompareExchange:
807   case SPIRV::OpAtomicIIncrement:
808   case SPIRV::OpAtomicIDecrement:
809   case SPIRV::OpAtomicIAdd:
810   case SPIRV::OpAtomicISub:
811   case SPIRV::OpAtomicUMin:
812   case SPIRV::OpAtomicUMax:
813   case SPIRV::OpAtomicSMin:
814   case SPIRV::OpAtomicSMax:
815   case SPIRV::OpAtomicAnd:
816   case SPIRV::OpAtomicOr:
817   case SPIRV::OpAtomicXor: {
818     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
819     const MachineInstr *InstrPtr = &MI;
820     if (MI.getOpcode() == SPIRV::OpAtomicStore) {
821       assert(MI.getOperand(3).isReg());
822       InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
823       assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
824     }
825     assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
826     Register TypeReg = InstrPtr->getOperand(1).getReg();
827     SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
828     if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
829       unsigned BitWidth = TypeDef->getOperand(1).getImm();
830       if (BitWidth == 64)
831         Reqs.addCapability(SPIRV::Capability::Int64Atomics);
832     }
833     break;
834   }
835   case SPIRV::OpGroupNonUniformIAdd:
836   case SPIRV::OpGroupNonUniformFAdd:
837   case SPIRV::OpGroupNonUniformIMul:
838   case SPIRV::OpGroupNonUniformFMul:
839   case SPIRV::OpGroupNonUniformSMin:
840   case SPIRV::OpGroupNonUniformUMin:
841   case SPIRV::OpGroupNonUniformFMin:
842   case SPIRV::OpGroupNonUniformSMax:
843   case SPIRV::OpGroupNonUniformUMax:
844   case SPIRV::OpGroupNonUniformFMax:
845   case SPIRV::OpGroupNonUniformBitwiseAnd:
846   case SPIRV::OpGroupNonUniformBitwiseOr:
847   case SPIRV::OpGroupNonUniformBitwiseXor:
848   case SPIRV::OpGroupNonUniformLogicalAnd:
849   case SPIRV::OpGroupNonUniformLogicalOr:
850   case SPIRV::OpGroupNonUniformLogicalXor: {
851     assert(MI.getOperand(3).isImm());
852     int64_t GroupOp = MI.getOperand(3).getImm();
853     switch (GroupOp) {
854     case SPIRV::GroupOperation::Reduce:
855     case SPIRV::GroupOperation::InclusiveScan:
856     case SPIRV::GroupOperation::ExclusiveScan:
857       Reqs.addCapability(SPIRV::Capability::Kernel);
858       Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
859       Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
860       break;
861     case SPIRV::GroupOperation::ClusteredReduce:
862       Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
863       break;
864     case SPIRV::GroupOperation::PartitionedReduceNV:
865     case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
866     case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
867       Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
868       break;
869     }
870     break;
871   }
872   case SPIRV::OpGroupNonUniformShuffle:
873   case SPIRV::OpGroupNonUniformShuffleXor:
874     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
875     break;
876   case SPIRV::OpGroupNonUniformShuffleUp:
877   case SPIRV::OpGroupNonUniformShuffleDown:
878     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
879     break;
880   case SPIRV::OpGroupAll:
881   case SPIRV::OpGroupAny:
882   case SPIRV::OpGroupBroadcast:
883   case SPIRV::OpGroupIAdd:
884   case SPIRV::OpGroupFAdd:
885   case SPIRV::OpGroupFMin:
886   case SPIRV::OpGroupUMin:
887   case SPIRV::OpGroupSMin:
888   case SPIRV::OpGroupFMax:
889   case SPIRV::OpGroupUMax:
890   case SPIRV::OpGroupSMax:
891     Reqs.addCapability(SPIRV::Capability::Groups);
892     break;
893   case SPIRV::OpGroupNonUniformElect:
894     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
895     break;
896   case SPIRV::OpGroupNonUniformAll:
897   case SPIRV::OpGroupNonUniformAny:
898   case SPIRV::OpGroupNonUniformAllEqual:
899     Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
900     break;
901   case SPIRV::OpGroupNonUniformBroadcast:
902   case SPIRV::OpGroupNonUniformBroadcastFirst:
903   case SPIRV::OpGroupNonUniformBallot:
904   case SPIRV::OpGroupNonUniformInverseBallot:
905   case SPIRV::OpGroupNonUniformBallotBitExtract:
906   case SPIRV::OpGroupNonUniformBallotBitCount:
907   case SPIRV::OpGroupNonUniformBallotFindLSB:
908   case SPIRV::OpGroupNonUniformBallotFindMSB:
909     Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
910     break;
911   case SPIRV::OpAssumeTrueKHR:
912   case SPIRV::OpExpectKHR:
913     if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
914       Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
915       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
916     }
917     break;
918   default:
919     break;
920   }
921 
922   // If we require capability Shader, then we can remove the requirement for
923   // the BitInstructions capability, since Shader is a superset capability
924   // of BitInstructions.
925   Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
926                           SPIRV::Capability::Shader);
927 }
928 
929 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
930                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
931   // Collect requirements for existing instructions.
932   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
933     MachineFunction *MF = MMI->getMachineFunction(*F);
934     if (!MF)
935       continue;
936     for (const MachineBasicBlock &MBB : *MF)
937       for (const MachineInstr &MI : MBB)
938         addInstrRequirements(MI, MAI.Reqs, ST);
939   }
940   // Collect requirements for OpExecutionMode instructions.
941   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
942   if (Node) {
943     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
944       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
945       const MDOperand &MDOp = MDN->getOperand(1);
946       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
947         Constant *C = CMeta->getValue();
948         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
949           auto EM = Const->getZExtValue();
950           MAI.Reqs.getAndAddRequirements(
951               SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
952         }
953       }
954     }
955   }
956   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
957     const Function &F = *FI;
958     if (F.isDeclaration())
959       continue;
960     if (F.getMetadata("reqd_work_group_size"))
961       MAI.Reqs.getAndAddRequirements(
962           SPIRV::OperandCategory::ExecutionModeOperand,
963           SPIRV::ExecutionMode::LocalSize, ST);
964     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
965       MAI.Reqs.getAndAddRequirements(
966           SPIRV::OperandCategory::ExecutionModeOperand,
967           SPIRV::ExecutionMode::LocalSize, ST);
968     }
969     if (F.getMetadata("work_group_size_hint"))
970       MAI.Reqs.getAndAddRequirements(
971           SPIRV::OperandCategory::ExecutionModeOperand,
972           SPIRV::ExecutionMode::LocalSizeHint, ST);
973     if (F.getMetadata("intel_reqd_sub_group_size"))
974       MAI.Reqs.getAndAddRequirements(
975           SPIRV::OperandCategory::ExecutionModeOperand,
976           SPIRV::ExecutionMode::SubgroupSize, ST);
977     if (F.getMetadata("vec_type_hint"))
978       MAI.Reqs.getAndAddRequirements(
979           SPIRV::OperandCategory::ExecutionModeOperand,
980           SPIRV::ExecutionMode::VecTypeHint, ST);
981 
982     if (F.hasOptNone() &&
983         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
984       // Output OpCapability OptNoneINTEL.
985       MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
986       MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
987     }
988   }
989 }
990 
991 static unsigned getFastMathFlags(const MachineInstr &I) {
992   unsigned Flags = SPIRV::FPFastMathMode::None;
993   if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
994     Flags |= SPIRV::FPFastMathMode::NotNaN;
995   if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
996     Flags |= SPIRV::FPFastMathMode::NotInf;
997   if (I.getFlag(MachineInstr::MIFlag::FmNsz))
998     Flags |= SPIRV::FPFastMathMode::NSZ;
999   if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1000     Flags |= SPIRV::FPFastMathMode::AllowRecip;
1001   if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1002     Flags |= SPIRV::FPFastMathMode::Fast;
1003   return Flags;
1004 }
1005 
1006 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1007                                    const SPIRVInstrInfo &TII,
1008                                    SPIRV::RequirementHandler &Reqs) {
1009   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1010       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1011                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1012           .IsSatisfiable) {
1013     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1014                     SPIRV::Decoration::NoSignedWrap, {});
1015   }
1016   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1017       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1018                                      SPIRV::Decoration::NoUnsignedWrap, ST,
1019                                      Reqs)
1020           .IsSatisfiable) {
1021     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1022                     SPIRV::Decoration::NoUnsignedWrap, {});
1023   }
1024   if (!TII.canUseFastMathFlags(I))
1025     return;
1026   unsigned FMFlags = getFastMathFlags(I);
1027   if (FMFlags == SPIRV::FPFastMathMode::None)
1028     return;
1029   Register DstReg = I.getOperand(0).getReg();
1030   buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1031 }
1032 
1033 // Walk all functions and add decorations related to MI flags.
1034 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1035                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1036                            SPIRV::ModuleAnalysisInfo &MAI) {
1037   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1038     MachineFunction *MF = MMI->getMachineFunction(*F);
1039     if (!MF)
1040       continue;
1041     for (auto &MBB : *MF)
1042       for (auto &MI : MBB)
1043         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1044   }
1045 }
1046 
1047 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1048 
// Declare the analyses this pass depends on. runOnModule fetches both:
//  * TargetPassConfig, to reach the SPIR-V target machine and subtarget;
//  * MachineModuleInfoWrapperPass, to map each IR function to its
//    MachineFunction.
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
1053 
1054 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
1055   SPIRVTargetMachine &TM =
1056       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
1057   ST = TM.getSubtargetImpl();
1058   GR = ST->getSPIRVGlobalRegistry();
1059   TII = ST->getInstrInfo();
1060 
1061   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
1062 
1063   setBaseInfo(M);
1064 
1065   addDecorations(M, *TII, MMI, *ST, MAI);
1066 
1067   collectReqs(M, MAI, MMI, *ST);
1068 
1069   // Process type/const/global var/func decl instructions, number their
1070   // destination registers from 0 to N, collect Extensions and Capabilities.
1071   processDefInstrs(M);
1072 
1073   // Number rest of registers from N+1 onwards.
1074   numberRegistersGlobally(M);
1075 
1076   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
1077   processOtherInstrs(M);
1078 
1079   // If there are no entry points, we need the Linkage capability.
1080   if (MAI.MS[SPIRV::MB_EntryPoints].empty())
1081     MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
1082 
1083   return false;
1084 }
1085