xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision a2fda816eb054d5873be223ef2461741dfcc253c)
1  //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // The analysis collects instructions that should be output at the module level
10  // and performs the global register numbering.
11  //
12  // The results of this analysis are used in AsmPrinter to rename registers
13  // globally and to output required instructions at the module level.
14  //
15  //===----------------------------------------------------------------------===//
16  
17  #include "SPIRVModuleAnalysis.h"
18  #include "MCTargetDesc/SPIRVBaseInfo.h"
19  #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20  #include "SPIRV.h"
21  #include "SPIRVSubtarget.h"
22  #include "SPIRVTargetMachine.h"
23  #include "SPIRVUtils.h"
24  #include "TargetInfo/SPIRVTargetInfo.h"
25  #include "llvm/ADT/STLExtras.h"
26  #include "llvm/CodeGen/MachineModuleInfo.h"
27  #include "llvm/CodeGen/TargetPassConfig.h"
28  
29  using namespace llvm;
30  
31  #define DEBUG_TYPE "spirv-module-analysis"
32  
33  static cl::opt<bool>
34      SPVDumpDeps("spv-dump-deps",
35                  cl::desc("Dump MIR with SPIR-V dependencies info"),
36                  cl::Optional, cl::init(false));
37  
38  char llvm::SPIRVModuleAnalysis::ID = 0;
39  
40  namespace llvm {
41  void initializeSPIRVModuleAnalysisPass(PassRegistry &);
42  } // namespace llvm
43  
44  INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
45                  true)
46  
47  // Retrieve an unsigned from an MDNode with a list of them as operands.
48  static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
49                                  unsigned DefaultVal = 0) {
50    if (MdNode && OpIndex < MdNode->getNumOperands()) {
51      const auto &Op = MdNode->getOperand(OpIndex);
52      return mdconst::extract<ConstantInt>(Op)->getZExtValue();
53    }
54    return DefaultVal;
55  }
56  
57  static SPIRV::Requirements
58  getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
59                                 unsigned i, const SPIRVSubtarget &ST,
60                                 SPIRV::RequirementHandler &Reqs) {
61    unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
62    unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
63    unsigned TargetVer = ST.getSPIRVVersion();
64    bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
65    bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
66    CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
67    ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
68    if (ReqCaps.empty()) {
69      if (ReqExts.empty()) {
70        if (MinVerOK && MaxVerOK)
71          return {true, {}, {}, ReqMinVer, ReqMaxVer};
72        return {false, {}, {}, 0, 0};
73      }
74    } else if (MinVerOK && MaxVerOK) {
75      for (auto Cap : ReqCaps) { // Only need 1 of the capabilities to work.
76        if (Reqs.isCapabilityAvailable(Cap))
77          return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
78      }
79    }
80    // If there are no capabilities, or we can't satisfy the version or
81    // capability requirements, use the list of extensions (if the subtarget
82    // can handle them all).
83    if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
84          return ST.canUseExtension(Ext);
85        })) {
86      return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
87    }
88    return {false, {}, {}, 0, 0};
89  }
90  
91  void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
92    MAI.MaxID = 0;
93    for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
94      MAI.MS[i].clear();
95    MAI.RegisterAliasTable.clear();
96    MAI.InstrsToDelete.clear();
97    MAI.FuncMap.clear();
98    MAI.GlobalVarList.clear();
99    MAI.ExtInstSetMap.clear();
100    MAI.Reqs.clear();
101    MAI.Reqs.initAvailableCapabilities(*ST);
102  
103    // TODO: determine memory model and source language from the configuratoin.
104    if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
105      auto MemMD = MemModel->getOperand(0);
106      MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
107          getMetadataUInt(MemMD, 0));
108      MAI.Mem =
109          static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
110    } else {
111      // TODO: Add support for VulkanMemoryModel.
112      MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
113                                  : SPIRV::MemoryModel::GLSL450;
114      if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
115        unsigned PtrSize = ST->getPointerSize();
116        MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
117                   : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
118                                   : SPIRV::AddressingModel::Logical;
119      } else {
120        // TODO: Add support for PhysicalStorageBufferAddress.
121        MAI.Addr = SPIRV::AddressingModel::Logical;
122      }
123    }
124    // Get the OpenCL version number from metadata.
125    // TODO: support other source languages.
126    if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
127      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
128      // Construct version literal in accordance with SPIRV-LLVM-Translator.
129      // TODO: support multiple OCL version metadata.
130      assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
131      auto VersionMD = VerNode->getOperand(0);
132      unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
133      unsigned MinorNum = getMetadataUInt(VersionMD, 1);
134      unsigned RevNum = getMetadataUInt(VersionMD, 2);
135      MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
136    } else {
137      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
138      MAI.SrcLangVersion = 0;
139    }
140  
141    if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
142      for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
143        MDNode *MD = ExtNode->getOperand(I);
144        if (!MD || MD->getNumOperands() == 0)
145          continue;
146        for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
147          MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
148      }
149    }
150  
151    // Update required capabilities for this memory model, addressing model and
152    // source language.
153    MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
154                                   MAI.Mem, *ST);
155    MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
156                                   MAI.SrcLang, *ST);
157    MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
158                                   MAI.Addr, *ST);
159  
160    if (ST->isOpenCLEnv()) {
161      // TODO: check if it's required by default.
162      MAI.ExtInstSetMap[static_cast<unsigned>(
163          SPIRV::InstructionSet::OpenCL_std)] =
164          Register::index2VirtReg(MAI.getNextID());
165    }
166  }
167  
168  // Collect MI which defines the register in the given machine function.
169  static void collectDefInstr(Register Reg, const MachineFunction *MF,
170                              SPIRV::ModuleAnalysisInfo *MAI,
171                              SPIRV::ModuleSectionType MSType,
172                              bool DoInsert = true) {
173    assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
174    MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
175    assert(MI && "There should be an instruction that defines the register");
176    MAI->setSkipEmission(MI);
177    if (DoInsert)
178      MAI->MS[MSType].push_back(MI);
179  }
180  
181  void SPIRVModuleAnalysis::collectGlobalEntities(
182      const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
183      SPIRV::ModuleSectionType MSType,
184      std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
185      bool UsePreOrder = false) {
186    DenseSet<const SPIRV::DTSortableEntry *> Visited;
187    for (const auto *E : DepsGraph) {
188      std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
189      // NOTE: here we prefer recursive approach over iterative because
190      // we don't expect depchains long enough to cause SO.
191      RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
192                      &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
193        if (Visited.count(E) || !Pred(E))
194          return;
195        Visited.insert(E);
196  
197        // Traversing deps graph in post-order allows us to get rid of
198        // register aliases preprocessing.
199        // But pre-order is required for correct processing of function
200        // declaration and arguments processing.
201        if (!UsePreOrder)
202          for (auto *S : E->getDeps())
203            RecHoistUtil(S);
204  
205        Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
206        bool IsFirst = true;
207        for (auto &U : *E) {
208          const MachineFunction *MF = U.first;
209          Register Reg = U.second;
210          MAI.setRegisterAlias(MF, Reg, GlobalReg);
211          if (!MF->getRegInfo().getUniqueVRegDef(Reg))
212            continue;
213          collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
214          IsFirst = false;
215          if (E->getIsGV())
216            MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
217        }
218  
219        if (UsePreOrder)
220          for (auto *S : E->getDeps())
221            RecHoistUtil(S);
222      };
223      RecHoistUtil(E);
224    }
225  }
226  
227  // The function initializes global register alias table for types, consts,
228  // global vars and func decls and collects these instruction for output
229  // at module level. Also it collects explicit OpExtension/OpCapability
230  // instructions.
231  void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
232    std::vector<SPIRV::DTSortableEntry *> DepsGraph;
233  
234    GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
235  
236    collectGlobalEntities(
237        DepsGraph, SPIRV::MB_TypeConstVars,
238        [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
239  
240    for (auto F = M.begin(), E = M.end(); F != E; ++F) {
241      MachineFunction *MF = MMI->getMachineFunction(*F);
242      if (!MF)
243        continue;
244      // Iterate through and collect OpExtension/OpCapability instructions.
245      for (MachineBasicBlock &MBB : *MF) {
246        for (MachineInstr &MI : MBB) {
247          if (MI.getOpcode() == SPIRV::OpExtension) {
248            // Here, OpExtension just has a single enum operand, not a string.
249            auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
250            MAI.Reqs.addExtension(Ext);
251            MAI.setSkipEmission(&MI);
252          } else if (MI.getOpcode() == SPIRV::OpCapability) {
253            auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
254            MAI.Reqs.addCapability(Cap);
255            MAI.setSkipEmission(&MI);
256          }
257        }
258      }
259    }
260  
261    collectGlobalEntities(
262        DepsGraph, SPIRV::MB_ExtFuncDecls,
263        [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
264  }
265  
266  // Look for IDs declared with Import linkage, and map the corresponding function
267  // to the register defining that variable (which will usually be the result of
268  // an OpFunction). This lets us call externally imported functions using
269  // the correct ID registers.
270  void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
271                                             const Function *F) {
272    if (MI.getOpcode() == SPIRV::OpDecorate) {
273      // If it's got Import linkage.
274      auto Dec = MI.getOperand(1).getImm();
275      if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
276        auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
277        if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
278          // Map imported function name to function ID register.
279          const Function *ImportedFunc =
280              F->getParent()->getFunction(getStringImm(MI, 2));
281          Register Target = MI.getOperand(0).getReg();
282          MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
283        }
284      }
285    } else if (MI.getOpcode() == SPIRV::OpFunction) {
286      // Record all internal OpFunction declarations.
287      Register Reg = MI.defs().begin()->getReg();
288      Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
289      assert(GlobalReg.isValid());
290      MAI.FuncMap[F] = GlobalReg;
291    }
292  }
293  
294  using InstrSignature = SmallVector<size_t>;
295  using InstrTraces = std::set<InstrSignature>;
296  
297  // Returns a representation of an instruction as a vector of MachineOperand
298  // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
299  // This creates a signature of the instruction with the same content
300  // that MachineOperand::isIdenticalTo uses for comparison.
301  static InstrSignature instrToSignature(MachineInstr &MI,
302                                         SPIRV::ModuleAnalysisInfo &MAI) {
303    InstrSignature Signature;
304    for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
305      const MachineOperand &MO = MI.getOperand(i);
306      size_t h;
307      if (MO.isReg()) {
308        Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
309        // mimic llvm::hash_value(const MachineOperand &MO)
310        h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
311                         MO.isDef());
312      } else {
313        h = hash_value(MO);
314      }
315      Signature.push_back(h);
316    }
317    return Signature;
318  }
319  
320  // Collect the given instruction in the specified MS. We assume global register
321  // numbering has already occurred by this point. We can directly compare reg
322  // arguments when detecting duplicates.
323  static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
324                                SPIRV::ModuleSectionType MSType, InstrTraces &IS,
325                                bool Append = true) {
326    MAI.setSkipEmission(&MI);
327    InstrSignature MISign = instrToSignature(MI, MAI);
328    auto FoundMI = IS.insert(MISign);
329    if (!FoundMI.second)
330      return; // insert failed, so we found a duplicate; don't add it to MAI.MS
331    // No duplicates, so add it.
332    if (Append)
333      MAI.MS[MSType].push_back(&MI);
334    else
335      MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
336  }
337  
338  // Some global instructions make reference to function-local ID regs, so cannot
339  // be correctly collected until these registers are globally numbered.
340  void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
341    InstrTraces IS;
342    for (auto F = M.begin(), E = M.end(); F != E; ++F) {
343      if ((*F).isDeclaration())
344        continue;
345      MachineFunction *MF = MMI->getMachineFunction(*F);
346      assert(MF);
347      for (MachineBasicBlock &MBB : *MF)
348        for (MachineInstr &MI : MBB) {
349          if (MAI.getSkipEmission(&MI))
350            continue;
351          const unsigned OpCode = MI.getOpcode();
352          if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
353            collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
354          } else if (OpCode == SPIRV::OpEntryPoint) {
355            collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
356          } else if (TII->isDecorationInstr(MI)) {
357            collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
358            collectFuncNames(MI, &*F);
359          } else if (TII->isConstantInstr(MI)) {
360            // Now OpSpecConstant*s are not in DT,
361            // but they need to be collected anyway.
362            collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
363          } else if (OpCode == SPIRV::OpFunction) {
364            collectFuncNames(MI, &*F);
365          } else if (OpCode == SPIRV::OpTypeForwardPointer) {
366            collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
367          }
368        }
369    }
370  }
371  
372  // Number registers in all functions globally from 0 onwards and store
373  // the result in global register alias table. Some registers are already
374  // numbered in collectGlobalEntities.
375  void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
376    for (auto F = M.begin(), E = M.end(); F != E; ++F) {
377      if ((*F).isDeclaration())
378        continue;
379      MachineFunction *MF = MMI->getMachineFunction(*F);
380      assert(MF);
381      for (MachineBasicBlock &MBB : *MF) {
382        for (MachineInstr &MI : MBB) {
383          for (MachineOperand &Op : MI.operands()) {
384            if (!Op.isReg())
385              continue;
386            Register Reg = Op.getReg();
387            if (MAI.hasRegisterAlias(MF, Reg))
388              continue;
389            Register NewReg = Register::index2VirtReg(MAI.getNextID());
390            MAI.setRegisterAlias(MF, Reg, NewReg);
391          }
392          if (MI.getOpcode() != SPIRV::OpExtInst)
393            continue;
394          auto Set = MI.getOperand(2).getImm();
395          if (!MAI.ExtInstSetMap.contains(Set))
396            MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
397        }
398      }
399    }
400  }
401  
402  // RequirementHandler implementations.
403  void SPIRV::RequirementHandler::getAndAddRequirements(
404      SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
405      const SPIRVSubtarget &ST) {
406    addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
407  }
408  
409  void SPIRV::RequirementHandler::pruneCapabilities(
410      const CapabilityList &ToPrune) {
411    for (const auto &Cap : ToPrune) {
412      AllCaps.insert(Cap);
413      auto FoundIndex = llvm::find(MinimalCaps, Cap);
414      if (FoundIndex != MinimalCaps.end())
415        MinimalCaps.erase(FoundIndex);
416      CapabilityList ImplicitDecls =
417          getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
418      pruneCapabilities(ImplicitDecls);
419    }
420  }
421  
422  void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
423    for (const auto &Cap : ToAdd) {
424      bool IsNewlyInserted = AllCaps.insert(Cap).second;
425      if (!IsNewlyInserted) // Don't re-add if it's already been declared.
426        continue;
427      CapabilityList ImplicitDecls =
428          getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
429      pruneCapabilities(ImplicitDecls);
430      MinimalCaps.push_back(Cap);
431    }
432  }
433  
434  void SPIRV::RequirementHandler::addRequirements(
435      const SPIRV::Requirements &Req) {
436    if (!Req.IsSatisfiable)
437      report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
438  
439    if (Req.Cap.has_value())
440      addCapabilities({Req.Cap.value()});
441  
442    addExtensions(Req.Exts);
443  
444    if (Req.MinVer) {
445      if (MaxVersion && Req.MinVer > MaxVersion) {
446        LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
447                          << " and <= " << MaxVersion << "\n");
448        report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
449      }
450  
451      if (MinVersion == 0 || Req.MinVer > MinVersion)
452        MinVersion = Req.MinVer;
453    }
454  
455    if (Req.MaxVer) {
456      if (MinVersion && Req.MaxVer < MinVersion) {
457        LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
458                          << " and >= " << MinVersion << "\n");
459        report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
460      }
461  
462      if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
463        MaxVersion = Req.MaxVer;
464    }
465  }
466  
467  void SPIRV::RequirementHandler::checkSatisfiable(
468      const SPIRVSubtarget &ST) const {
469    // Report as many errors as possible before aborting the compilation.
470    bool IsSatisfiable = true;
471    auto TargetVer = ST.getSPIRVVersion();
472  
473    if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
474      LLVM_DEBUG(
475          dbgs() << "Target SPIR-V version too high for required features\n"
476                 << "Required max version: " << MaxVersion << " target version "
477                 << TargetVer << "\n");
478      IsSatisfiable = false;
479    }
480  
481    if (MinVersion && TargetVer && MinVersion > TargetVer) {
482      LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
483                        << "Required min version: " << MinVersion
484                        << " target version " << TargetVer << "\n");
485      IsSatisfiable = false;
486    }
487  
488    if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
489      LLVM_DEBUG(
490          dbgs()
491          << "Version is too low for some features and too high for others.\n"
492          << "Required SPIR-V min version: " << MinVersion
493          << " required SPIR-V max version " << MaxVersion << "\n");
494      IsSatisfiable = false;
495    }
496  
497    for (auto Cap : MinimalCaps) {
498      if (AvailableCaps.contains(Cap))
499        continue;
500      LLVM_DEBUG(dbgs() << "Capability not supported: "
501                        << getSymbolicOperandMnemonic(
502                               OperandCategory::CapabilityOperand, Cap)
503                        << "\n");
504      IsSatisfiable = false;
505    }
506  
507    for (auto Ext : AllExtensions) {
508      if (ST.canUseExtension(Ext))
509        continue;
510      LLVM_DEBUG(dbgs() << "Extension not supported: "
511                        << getSymbolicOperandMnemonic(
512                               OperandCategory::ExtensionOperand, Ext)
513                        << "\n");
514      IsSatisfiable = false;
515    }
516  
517    if (!IsSatisfiable)
518      report_fatal_error("Unable to meet SPIR-V requirements for this target.");
519  }
520  
521  // Add the given capabilities and all their implicitly defined capabilities too.
522  void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
523    for (const auto Cap : ToAdd)
524      if (AvailableCaps.insert(Cap).second)
525        addAvailableCaps(getSymbolicOperandCapabilities(
526            SPIRV::OperandCategory::CapabilityOperand, Cap));
527  }
528  
529  void SPIRV::RequirementHandler::removeCapabilityIf(
530      const Capability::Capability ToRemove,
531      const Capability::Capability IfPresent) {
532    if (AllCaps.contains(IfPresent))
533      AllCaps.erase(ToRemove);
534  }
535  
536  namespace llvm {
537  namespace SPIRV {
538  void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
539    if (ST.isOpenCLEnv()) {
540      initAvailableCapabilitiesForOpenCL(ST);
541      return;
542    }
543  
544    if (ST.isVulkanEnv()) {
545      initAvailableCapabilitiesForVulkan(ST);
546      return;
547    }
548  
549    report_fatal_error("Unimplemented environment for SPIR-V generation.");
550  }
551  
552  void RequirementHandler::initAvailableCapabilitiesForOpenCL(
553      const SPIRVSubtarget &ST) {
554    // Add the min requirements for different OpenCL and SPIR-V versions.
555    addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
556                      Capability::Int16, Capability::Int8, Capability::Kernel,
557                      Capability::Linkage, Capability::Vector16,
558                      Capability::Groups, Capability::GenericPointer,
559                      Capability::Shader});
560    if (ST.hasOpenCLFullProfile())
561      addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
562    if (ST.hasOpenCLImageSupport()) {
563      addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
564                        Capability::Image1D, Capability::SampledBuffer,
565                        Capability::ImageBuffer});
566      if (ST.isAtLeastOpenCLVer(20))
567        addAvailableCaps({Capability::ImageReadWrite});
568    }
569    if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
570      addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
571    if (ST.isAtLeastSPIRVVer(13))
572      addAvailableCaps({Capability::GroupNonUniform,
573                        Capability::GroupNonUniformVote,
574                        Capability::GroupNonUniformArithmetic,
575                        Capability::GroupNonUniformBallot,
576                        Capability::GroupNonUniformClustered,
577                        Capability::GroupNonUniformShuffle,
578                        Capability::GroupNonUniformShuffleRelative});
579    if (ST.isAtLeastSPIRVVer(14))
580      addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
581                        Capability::SignedZeroInfNanPreserve,
582                        Capability::RoundingModeRTE,
583                        Capability::RoundingModeRTZ});
584    // TODO: verify if this needs some checks.
585    addAvailableCaps({Capability::Float16, Capability::Float64});
586  
587    // Add capabilities enabled by extensions.
588    for (auto Extension : ST.getAllAvailableExtensions()) {
589      CapabilityList EnabledCapabilities =
590          getCapabilitiesEnabledByExtension(Extension);
591      addAvailableCaps(EnabledCapabilities);
592    }
593  
594    // TODO: add OpenCL extensions.
595  }
596  
597  void RequirementHandler::initAvailableCapabilitiesForVulkan(
598      const SPIRVSubtarget &ST) {
599    addAvailableCaps({Capability::Shader, Capability::Linkage});
600  
601    // Provided by all supported Vulkan versions.
602    addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
603                      Capability::Float64});
604  }
605  
606  } // namespace SPIRV
607  } // namespace llvm
608  
609  // Add the required capabilities from a decoration instruction (including
610  // BuiltIns).
611  static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
612                                SPIRV::RequirementHandler &Reqs,
613                                const SPIRVSubtarget &ST) {
614    int64_t DecOp = MI.getOperand(DecIndex).getImm();
615    auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
616    Reqs.addRequirements(getSymbolicOperandRequirements(
617        SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
618  
619    if (Dec == SPIRV::Decoration::BuiltIn) {
620      int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
621      auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
622      Reqs.addRequirements(getSymbolicOperandRequirements(
623          SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
624    }
625  }
626  
627  // Add requirements for image handling.
628  static void addOpTypeImageReqs(const MachineInstr &MI,
629                                 SPIRV::RequirementHandler &Reqs,
630                                 const SPIRVSubtarget &ST) {
631    assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
632    // The operand indices used here are based on the OpTypeImage layout, which
633    // the MachineInstr follows as well.
634    int64_t ImgFormatOp = MI.getOperand(7).getImm();
635    auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
636    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
637                               ImgFormat, ST);
638  
639    bool IsArrayed = MI.getOperand(4).getImm() == 1;
640    bool IsMultisampled = MI.getOperand(5).getImm() == 1;
641    bool NoSampler = MI.getOperand(6).getImm() == 2;
642    // Add dimension requirements.
643    assert(MI.getOperand(2).isImm());
644    switch (MI.getOperand(2).getImm()) {
645    case SPIRV::Dim::DIM_1D:
646      Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
647                                     : SPIRV::Capability::Sampled1D);
648      break;
649    case SPIRV::Dim::DIM_2D:
650      if (IsMultisampled && NoSampler)
651        Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
652      break;
653    case SPIRV::Dim::DIM_Cube:
654      Reqs.addRequirements(SPIRV::Capability::Shader);
655      if (IsArrayed)
656        Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
657                                       : SPIRV::Capability::SampledCubeArray);
658      break;
659    case SPIRV::Dim::DIM_Rect:
660      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
661                                     : SPIRV::Capability::SampledRect);
662      break;
663    case SPIRV::Dim::DIM_Buffer:
664      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
665                                     : SPIRV::Capability::SampledBuffer);
666      break;
667    case SPIRV::Dim::DIM_SubpassData:
668      Reqs.addRequirements(SPIRV::Capability::InputAttachment);
669      break;
670    }
671  
672    // Has optional access qualifier.
673    // TODO: check if it's OpenCL's kernel.
674    if (MI.getNumOperands() > 8 &&
675        MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
676      Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
677    else
678      Reqs.addRequirements(SPIRV::Capability::ImageBasic);
679  }
680  
// Inspect a single machine instruction and record in Reqs the SPIR-V
// capabilities and extensions implied by its opcode (and, for some opcodes,
// by its operands or by the types they reference) for subtarget ST.
// Opcodes that are not listed in the switch contribute no requirements.
void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    // Operand 0 is the addressing model, operand 1 the memory model; each may
    // carry its own requirements.
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    // Operand 0 is the execution model of the entry point.
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    // Operand 1 holds the execution mode (operand 0 is presumably the entry
    // point it applies to).
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    // Operand 1 is the bit width; non-32-bit widths need a dedicated
    // capability (other widths add nothing here).
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    // Operand 1 is the bit width.
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    // Operand 2 is the component count; 8- and 16-wide vectors need Vector16.
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    // Operand 1 is the storage class, operand 2 the pointee type register.
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
    // capability.
    if (!ST.isOpenCLEnv())
      break;
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpBitFieldInsert:
  case SPIRV::OpBitFieldSExtract:
  case SPIRV::OpBitFieldUExtract:
    // Without SPV_KHR_bit_instructions these ops are requested via Shader;
    // with it, the extension and the BitInstructions capability are used.
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
      Reqs.addCapability(SPIRV::Capability::Shader);
      break;
    }
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
    Reqs.addCapability(SPIRV::Capability::BitInstructions);
    break;
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  // The second argument of addOpDecorateReqs is presumably the index of the
  // decoration operand; member decorations carry an extra member-index
  // operand, hence 2 instead of 1.
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    Reqs.addCapability(SPIRV::Capability::ImageBasic);
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    // 64-bit integer atomics need Int64Atomics. The type is read from
    // operand 1 of InstrPtr: the atomic itself for most opcodes, or, for
    // OpAtomicStore (which has no result type), the instruction defining the
    // stored value in operand 3.
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    // Operand 3 holds the GroupOperation immediate; the required capability
    // depends on which group operation is used. Unlisted values add nothing.
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  case SPIRV::OpAssumeTrueKHR:
  case SPIRV::OpExpectKHR:
    // Requirements are only recorded when SPV_KHR_expect_assume may be used;
    // otherwise nothing is added for these opcodes.
    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
      Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
    }
    break;
  default:
    break;
  }

  // If we require capability Shader, then we can remove the requirement for
  // the BitInstructions capability, since Shader is a superset capability
  // of BitInstructions.
  // NOTE(review): this runs once per instruction; presumably it could be done
  // once by the caller after all instructions have been processed.
  Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
                          SPIRV::Capability::Shader);
}
928  
929  static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
930                          MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
931    // Collect requirements for existing instructions.
932    for (auto F = M.begin(), E = M.end(); F != E; ++F) {
933      MachineFunction *MF = MMI->getMachineFunction(*F);
934      if (!MF)
935        continue;
936      for (const MachineBasicBlock &MBB : *MF)
937        for (const MachineInstr &MI : MBB)
938          addInstrRequirements(MI, MAI.Reqs, ST);
939    }
940    // Collect requirements for OpExecutionMode instructions.
941    auto Node = M.getNamedMetadata("spirv.ExecutionMode");
942    if (Node) {
943      for (unsigned i = 0; i < Node->getNumOperands(); i++) {
944        MDNode *MDN = cast<MDNode>(Node->getOperand(i));
945        const MDOperand &MDOp = MDN->getOperand(1);
946        if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
947          Constant *C = CMeta->getValue();
948          if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
949            auto EM = Const->getZExtValue();
950            MAI.Reqs.getAndAddRequirements(
951                SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
952          }
953        }
954      }
955    }
956    for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
957      const Function &F = *FI;
958      if (F.isDeclaration())
959        continue;
960      if (F.getMetadata("reqd_work_group_size"))
961        MAI.Reqs.getAndAddRequirements(
962            SPIRV::OperandCategory::ExecutionModeOperand,
963            SPIRV::ExecutionMode::LocalSize, ST);
964      if (F.getFnAttribute("hlsl.numthreads").isValid()) {
965        MAI.Reqs.getAndAddRequirements(
966            SPIRV::OperandCategory::ExecutionModeOperand,
967            SPIRV::ExecutionMode::LocalSize, ST);
968      }
969      if (F.getMetadata("work_group_size_hint"))
970        MAI.Reqs.getAndAddRequirements(
971            SPIRV::OperandCategory::ExecutionModeOperand,
972            SPIRV::ExecutionMode::LocalSizeHint, ST);
973      if (F.getMetadata("intel_reqd_sub_group_size"))
974        MAI.Reqs.getAndAddRequirements(
975            SPIRV::OperandCategory::ExecutionModeOperand,
976            SPIRV::ExecutionMode::SubgroupSize, ST);
977      if (F.getMetadata("vec_type_hint"))
978        MAI.Reqs.getAndAddRequirements(
979            SPIRV::OperandCategory::ExecutionModeOperand,
980            SPIRV::ExecutionMode::VecTypeHint, ST);
981  
982      if (F.hasOptNone() &&
983          ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
984        // Output OpCapability OptNoneINTEL.
985        MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
986        MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
987      }
988    }
989  }
990  
991  static unsigned getFastMathFlags(const MachineInstr &I) {
992    unsigned Flags = SPIRV::FPFastMathMode::None;
993    if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
994      Flags |= SPIRV::FPFastMathMode::NotNaN;
995    if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
996      Flags |= SPIRV::FPFastMathMode::NotInf;
997    if (I.getFlag(MachineInstr::MIFlag::FmNsz))
998      Flags |= SPIRV::FPFastMathMode::NSZ;
999    if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1000      Flags |= SPIRV::FPFastMathMode::AllowRecip;
1001    if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1002      Flags |= SPIRV::FPFastMathMode::Fast;
1003    return Flags;
1004  }
1005  
1006  static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1007                                     const SPIRVInstrInfo &TII,
1008                                     SPIRV::RequirementHandler &Reqs) {
1009    if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1010        getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1011                                       SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1012            .IsSatisfiable) {
1013      buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1014                      SPIRV::Decoration::NoSignedWrap, {});
1015    }
1016    if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1017        getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1018                                       SPIRV::Decoration::NoUnsignedWrap, ST,
1019                                       Reqs)
1020            .IsSatisfiable) {
1021      buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1022                      SPIRV::Decoration::NoUnsignedWrap, {});
1023    }
1024    if (!TII.canUseFastMathFlags(I))
1025      return;
1026    unsigned FMFlags = getFastMathFlags(I);
1027    if (FMFlags == SPIRV::FPFastMathMode::None)
1028      return;
1029    Register DstReg = I.getOperand(0).getReg();
1030    buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1031  }
1032  
1033  // Walk all functions and add decorations related to MI flags.
1034  static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1035                             MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1036                             SPIRV::ModuleAnalysisInfo &MAI) {
1037    for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1038      MachineFunction *MF = MMI->getMachineFunction(*F);
1039      if (!MF)
1040        continue;
1041      for (auto &MBB : *MF)
1042        for (auto &MI : MBB)
1043          handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1044    }
1045  }
1046  
// Out-of-class definition of the static MAI member: a single
// ModuleAnalysisInfo instance holding the results of this analysis.
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1048  
1049  void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1050    AU.addRequired<TargetPassConfig>();
1051    AU.addRequired<MachineModuleInfoWrapperPass>();
1052  }
1053  
1054  bool SPIRVModuleAnalysis::runOnModule(Module &M) {
1055    SPIRVTargetMachine &TM =
1056        getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
1057    ST = TM.getSubtargetImpl();
1058    GR = ST->getSPIRVGlobalRegistry();
1059    TII = ST->getInstrInfo();
1060  
1061    MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
1062  
1063    setBaseInfo(M);
1064  
1065    addDecorations(M, *TII, MMI, *ST, MAI);
1066  
1067    collectReqs(M, MAI, MMI, *ST);
1068  
1069    // Process type/const/global var/func decl instructions, number their
1070    // destination registers from 0 to N, collect Extensions and Capabilities.
1071    processDefInstrs(M);
1072  
1073    // Number rest of registers from N+1 onwards.
1074    numberRegistersGlobally(M);
1075  
1076    // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
1077    processOtherInstrs(M);
1078  
1079    // If there are no entry points, we need the Linkage capability.
1080    if (MAI.MS[SPIRV::MB_EntryPoints].empty())
1081      MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
1082  
1083    return false;
1084  }
1085