xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "TargetInfo/SPIRVTargetInfo.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/CodeGen/MachineModuleInfo.h"
27 #include "llvm/CodeGen/TargetPassConfig.h"
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "spirv-module-analysis"
32 
// Debugging aid. When enabled, processDefInstrs passes MMI to
// GR->buildDepsGraph so that the MIR is dumped together with SPIR-V
// dependency information; otherwise nullptr is passed.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));
37 
// Capabilities that getSymbolicOperandRequirements should not pick when a
// feature lists several enabling capabilities and another available one can
// be used instead. Only "Shader" can currently be named on the command line.
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
45 // Use sets instead of cl::list to check "if contains" condition
46 struct AvoidCapabilitiesSet {
47   SmallSet<SPIRV::Capability::Capability, 4> S;
48   AvoidCapabilitiesSet() {
49     for (auto Cap : AvoidCapabilities)
50       S.insert(Cap);
51   }
52 };
53 
// Pass identifier; LLVM identifies the pass by the address of ID, not its
// value.
char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
// Declaration of the registration function defined by INITIALIZE_PASS below.
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

// Register the analysis under DEBUG_TYPE ("spirv-module-analysis"). The two
// trailing `true` arguments are INITIALIZE_PASS's CFG-only and is-analysis
// flags.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
62 
63 // Retrieve an unsigned from an MDNode with a list of them as operands.
64 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
65                                 unsigned DefaultVal = 0) {
66   if (MdNode && OpIndex < MdNode->getNumOperands()) {
67     const auto &Op = MdNode->getOperand(OpIndex);
68     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
69   }
70   return DefaultVal;
71 }
72 
// Compute what is needed to use symbolic operand `i` of `Category` on ST: at
// most one capability, a list of extensions, and a [min, max] SPIR-V version
// window. The result is {IsSatisfiable, Cap, Exts, MinVer, MaxVer}.
//
// Decision order:
//  1. The operand needs no capability and no extension: it is satisfiable iff
//     the version window admits the target version. An empty target version
//     or an empty max bound counts as compatible.
//  2. The operand needs capabilities and the version window is OK: choose one
//     capability that Reqs reports as available. When several are listed,
//     prefer one not named by -avoid-spirv-capabilities; the last available
//     one is taken if all of them are to be avoided.
//  3. Otherwise fall back to the operand's extensions, if ST can use all of
//     them. Version bounds are dropped in this case (see TODO below).
//
// NOTE(review): when capabilities are required but none could be selected
// (none available, or the version window fails) and the operand lists no
// extensions, llvm::all_of over the empty ReqExts is true. The operand is then
// reported satisfiable with no capability, extension or version bound. Confirm
// this leniency is intended rather than an unsatisfiable result.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  // Built from the command line on the first call and reused afterwards.
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option

  VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  VersionTuple SPIRVVersion = ST.getSPIRVVersion();
  // An unspecified target version is treated as compatible with any bound.
  bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
  bool MaxVerOK =
      ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      // Only the version window constrains this operand.
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, VersionTuple(), VersionTuple()};
    }
    // No capabilities but some extensions: handled by the fallback below.
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      // First non-avoided available capability wins; the last one is used
      // regardless if every earlier candidate is on the avoid list. If
      // UseCaps is empty the loop does nothing and control falls through.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true,
            {},
            ReqExts,
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}
131 
132 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
133   MAI.MaxID = 0;
134   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
135     MAI.MS[i].clear();
136   MAI.RegisterAliasTable.clear();
137   MAI.InstrsToDelete.clear();
138   MAI.FuncMap.clear();
139   MAI.GlobalVarList.clear();
140   MAI.ExtInstSetMap.clear();
141   MAI.Reqs.clear();
142   MAI.Reqs.initAvailableCapabilities(*ST);
143 
144   // TODO: determine memory model and source language from the configuratoin.
145   if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
146     auto MemMD = MemModel->getOperand(0);
147     MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
148         getMetadataUInt(MemMD, 0));
149     MAI.Mem =
150         static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
151   } else {
152     // TODO: Add support for VulkanMemoryModel.
153     MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
154                                 : SPIRV::MemoryModel::GLSL450;
155     if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
156       unsigned PtrSize = ST->getPointerSize();
157       MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
158                  : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
159                                  : SPIRV::AddressingModel::Logical;
160     } else {
161       // TODO: Add support for PhysicalStorageBufferAddress.
162       MAI.Addr = SPIRV::AddressingModel::Logical;
163     }
164   }
165   // Get the OpenCL version number from metadata.
166   // TODO: support other source languages.
167   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
168     MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
169     // Construct version literal in accordance with SPIRV-LLVM-Translator.
170     // TODO: support multiple OCL version metadata.
171     assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
172     auto VersionMD = VerNode->getOperand(0);
173     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
174     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
175     unsigned RevNum = getMetadataUInt(VersionMD, 2);
176     // Prevent Major part of OpenCL version to be 0
177     MAI.SrcLangVersion =
178         (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
179   } else {
180     // If there is no information about OpenCL version we are forced to generate
181     // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
182     // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
183     // Translator avoids potential issues with run-times in a similar manner.
184     if (ST->isOpenCLEnv()) {
185       MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
186       MAI.SrcLangVersion = 100000;
187     } else {
188       MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
189       MAI.SrcLangVersion = 0;
190     }
191   }
192 
193   if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
194     for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
195       MDNode *MD = ExtNode->getOperand(I);
196       if (!MD || MD->getNumOperands() == 0)
197         continue;
198       for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
199         MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
200     }
201   }
202 
203   // Update required capabilities for this memory model, addressing model and
204   // source language.
205   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
206                                  MAI.Mem, *ST);
207   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
208                                  MAI.SrcLang, *ST);
209   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
210                                  MAI.Addr, *ST);
211 
212   if (ST->isOpenCLEnv()) {
213     // TODO: check if it's required by default.
214     MAI.ExtInstSetMap[static_cast<unsigned>(
215         SPIRV::InstructionSet::OpenCL_std)] =
216         Register::index2VirtReg(MAI.getNextID());
217   }
218 }
219 
220 // Collect MI which defines the register in the given machine function.
221 static void collectDefInstr(Register Reg, const MachineFunction *MF,
222                             SPIRV::ModuleAnalysisInfo *MAI,
223                             SPIRV::ModuleSectionType MSType,
224                             bool DoInsert = true) {
225   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
226   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
227   assert(MI && "There should be an instruction that defines the register");
228   MAI->setSkipEmission(MI);
229   if (DoInsert)
230     MAI->MS[MSType].push_back(MI);
231 }
232 
233 void SPIRVModuleAnalysis::collectGlobalEntities(
234     const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
235     SPIRV::ModuleSectionType MSType,
236     std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
237     bool UsePreOrder = false) {
238   DenseSet<const SPIRV::DTSortableEntry *> Visited;
239   for (const auto *E : DepsGraph) {
240     std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
241     // NOTE: here we prefer recursive approach over iterative because
242     // we don't expect depchains long enough to cause SO.
243     RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
244                     &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
245       if (Visited.count(E) || !Pred(E))
246         return;
247       Visited.insert(E);
248 
249       // Traversing deps graph in post-order allows us to get rid of
250       // register aliases preprocessing.
251       // But pre-order is required for correct processing of function
252       // declaration and arguments processing.
253       if (!UsePreOrder)
254         for (auto *S : E->getDeps())
255           RecHoistUtil(S);
256 
257       Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
258       bool IsFirst = true;
259       for (auto &U : *E) {
260         const MachineFunction *MF = U.first;
261         Register Reg = U.second;
262         MAI.setRegisterAlias(MF, Reg, GlobalReg);
263         if (!MF->getRegInfo().getUniqueVRegDef(Reg))
264           continue;
265         collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
266         IsFirst = false;
267         if (E->getIsGV())
268           MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
269       }
270 
271       if (UsePreOrder)
272         for (auto *S : E->getDeps())
273           RecHoistUtil(S);
274     };
275     RecHoistUtil(E);
276   }
277 }
278 
279 // The function initializes global register alias table for types, consts,
280 // global vars and func decls and collects these instruction for output
281 // at module level. Also it collects explicit OpExtension/OpCapability
282 // instructions.
283 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
284   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
285 
286   GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
287 
288   collectGlobalEntities(
289       DepsGraph, SPIRV::MB_TypeConstVars,
290       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
291 
292   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
293     MachineFunction *MF = MMI->getMachineFunction(*F);
294     if (!MF)
295       continue;
296     // Iterate through and collect OpExtension/OpCapability instructions.
297     for (MachineBasicBlock &MBB : *MF) {
298       for (MachineInstr &MI : MBB) {
299         if (MI.getOpcode() == SPIRV::OpExtension) {
300           // Here, OpExtension just has a single enum operand, not a string.
301           auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
302           MAI.Reqs.addExtension(Ext);
303           MAI.setSkipEmission(&MI);
304         } else if (MI.getOpcode() == SPIRV::OpCapability) {
305           auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
306           MAI.Reqs.addCapability(Cap);
307           MAI.setSkipEmission(&MI);
308         }
309       }
310     }
311   }
312 
313   collectGlobalEntities(
314       DepsGraph, SPIRV::MB_ExtFuncDecls,
315       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
316 }
317 
318 // Look for IDs declared with Import linkage, and map the corresponding function
319 // to the register defining that variable (which will usually be the result of
320 // an OpFunction). This lets us call externally imported functions using
321 // the correct ID registers.
322 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
323                                            const Function *F) {
324   if (MI.getOpcode() == SPIRV::OpDecorate) {
325     // If it's got Import linkage.
326     auto Dec = MI.getOperand(1).getImm();
327     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
328       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
329       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
330         // Map imported function name to function ID register.
331         const Function *ImportedFunc =
332             F->getParent()->getFunction(getStringImm(MI, 2));
333         Register Target = MI.getOperand(0).getReg();
334         MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
335       }
336     }
337   } else if (MI.getOpcode() == SPIRV::OpFunction) {
338     // Record all internal OpFunction declarations.
339     Register Reg = MI.defs().begin()->getReg();
340     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
341     assert(GlobalReg.isValid());
342     MAI.FuncMap[F] = GlobalReg;
343   }
344 }
345 
346 // References to a function via function pointers generate virtual
347 // registers without a definition. We are able to resolve this
348 // reference using Globar Register info into an OpFunction instruction
349 // and replace dummy operands by the corresponding global register references.
350 void SPIRVModuleAnalysis::collectFuncPtrs() {
351   for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
352     if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
353       collectFuncPtrs(MI);
354 }
355 
356 void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
357   const MachineOperand *FunUse = &MI->getOperand(2);
358   if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
359     const MachineInstr *FunDefMI = FunDef->getParent();
360     assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
361            "Constant function pointer must refer to function definition");
362     Register FunDefReg = FunDef->getReg();
363     Register GlobalFunDefReg =
364         MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
365     assert(GlobalFunDefReg.isValid() &&
366            "Function definition must refer to a global register");
367     Register FunPtrReg = FunUse->getReg();
368     MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
369   }
370 }
371 
// Fingerprint of an instruction: one hash per operand, in operand order
// (built by instrToSignature).
using InstrSignature = SmallVector<size_t>;
// Signatures already collected; used by collectOtherInstr to drop duplicates.
using InstrTraces = std::set<InstrSignature>;
374 
375 // Returns a representation of an instruction as a vector of MachineOperand
376 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
377 // This creates a signature of the instruction with the same content
378 // that MachineOperand::isIdenticalTo uses for comparison.
379 static InstrSignature instrToSignature(MachineInstr &MI,
380                                        SPIRV::ModuleAnalysisInfo &MAI) {
381   InstrSignature Signature;
382   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
383     const MachineOperand &MO = MI.getOperand(i);
384     size_t h;
385     if (MO.isReg()) {
386       Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
387       // mimic llvm::hash_value(const MachineOperand &MO)
388       h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
389                        MO.isDef());
390     } else {
391       h = hash_value(MO);
392     }
393     Signature.push_back(h);
394   }
395   return Signature;
396 }
397 
398 // Collect the given instruction in the specified MS. We assume global register
399 // numbering has already occurred by this point. We can directly compare reg
400 // arguments when detecting duplicates.
401 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
402                               SPIRV::ModuleSectionType MSType, InstrTraces &IS,
403                               bool Append = true) {
404   MAI.setSkipEmission(&MI);
405   InstrSignature MISign = instrToSignature(MI, MAI);
406   auto FoundMI = IS.insert(MISign);
407   if (!FoundMI.second)
408     return; // insert failed, so we found a duplicate; don't add it to MAI.MS
409   // No duplicates, so add it.
410   if (Append)
411     MAI.MS[MSType].push_back(&MI);
412   else
413     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
414 }
415 
416 // Some global instructions make reference to function-local ID regs, so cannot
417 // be correctly collected until these registers are globally numbered.
418 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
419   InstrTraces IS;
420   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
421     if ((*F).isDeclaration())
422       continue;
423     MachineFunction *MF = MMI->getMachineFunction(*F);
424     assert(MF);
425     for (MachineBasicBlock &MBB : *MF)
426       for (MachineInstr &MI : MBB) {
427         if (MAI.getSkipEmission(&MI))
428           continue;
429         const unsigned OpCode = MI.getOpcode();
430         if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
431           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
432         } else if (OpCode == SPIRV::OpEntryPoint) {
433           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
434         } else if (TII->isDecorationInstr(MI)) {
435           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
436           collectFuncNames(MI, &*F);
437         } else if (TII->isConstantInstr(MI)) {
438           // Now OpSpecConstant*s are not in DT,
439           // but they need to be collected anyway.
440           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
441         } else if (OpCode == SPIRV::OpFunction) {
442           collectFuncNames(MI, &*F);
443         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
444           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
445         }
446       }
447   }
448 }
449 
450 // Number registers in all functions globally from 0 onwards and store
451 // the result in global register alias table. Some registers are already
452 // numbered in collectGlobalEntities.
453 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
454   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
455     if ((*F).isDeclaration())
456       continue;
457     MachineFunction *MF = MMI->getMachineFunction(*F);
458     assert(MF);
459     for (MachineBasicBlock &MBB : *MF) {
460       for (MachineInstr &MI : MBB) {
461         for (MachineOperand &Op : MI.operands()) {
462           if (!Op.isReg())
463             continue;
464           Register Reg = Op.getReg();
465           if (MAI.hasRegisterAlias(MF, Reg))
466             continue;
467           Register NewReg = Register::index2VirtReg(MAI.getNextID());
468           MAI.setRegisterAlias(MF, Reg, NewReg);
469         }
470         if (MI.getOpcode() != SPIRV::OpExtInst)
471           continue;
472         auto Set = MI.getOperand(2).getImm();
473         if (!MAI.ExtInstSetMap.contains(Set))
474           MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
475       }
476     }
477   }
478 }
479 
480 // RequirementHandler implementations.
481 void SPIRV::RequirementHandler::getAndAddRequirements(
482     SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
483     const SPIRVSubtarget &ST) {
484   addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
485 }
486 
487 void SPIRV::RequirementHandler::recursiveAddCapabilities(
488     const CapabilityList &ToPrune) {
489   for (const auto &Cap : ToPrune) {
490     AllCaps.insert(Cap);
491     CapabilityList ImplicitDecls =
492         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
493     recursiveAddCapabilities(ImplicitDecls);
494   }
495 }
496 
497 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
498   for (const auto &Cap : ToAdd) {
499     bool IsNewlyInserted = AllCaps.insert(Cap).second;
500     if (!IsNewlyInserted) // Don't re-add if it's already been declared.
501       continue;
502     CapabilityList ImplicitDecls =
503         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
504     recursiveAddCapabilities(ImplicitDecls);
505     MinimalCaps.push_back(Cap);
506   }
507 }
508 
509 void SPIRV::RequirementHandler::addRequirements(
510     const SPIRV::Requirements &Req) {
511   if (!Req.IsSatisfiable)
512     report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
513 
514   if (Req.Cap.has_value())
515     addCapabilities({Req.Cap.value()});
516 
517   addExtensions(Req.Exts);
518 
519   if (!Req.MinVer.empty()) {
520     if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
521       LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
522                         << " and <= " << MaxVersion << "\n");
523       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
524     }
525 
526     if (MinVersion.empty() || Req.MinVer > MinVersion)
527       MinVersion = Req.MinVer;
528   }
529 
530   if (!Req.MaxVer.empty()) {
531     if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
532       LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
533                         << " and >= " << MinVersion << "\n");
534       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
535     }
536 
537     if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
538       MaxVersion = Req.MaxVer;
539   }
540 }
541 
542 void SPIRV::RequirementHandler::checkSatisfiable(
543     const SPIRVSubtarget &ST) const {
544   // Report as many errors as possible before aborting the compilation.
545   bool IsSatisfiable = true;
546   auto TargetVer = ST.getSPIRVVersion();
547 
548   if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
549     LLVM_DEBUG(
550         dbgs() << "Target SPIR-V version too high for required features\n"
551                << "Required max version: " << MaxVersion << " target version "
552                << TargetVer << "\n");
553     IsSatisfiable = false;
554   }
555 
556   if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
557     LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
558                       << "Required min version: " << MinVersion
559                       << " target version " << TargetVer << "\n");
560     IsSatisfiable = false;
561   }
562 
563   if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
564     LLVM_DEBUG(
565         dbgs()
566         << "Version is too low for some features and too high for others.\n"
567         << "Required SPIR-V min version: " << MinVersion
568         << " required SPIR-V max version " << MaxVersion << "\n");
569     IsSatisfiable = false;
570   }
571 
572   for (auto Cap : MinimalCaps) {
573     if (AvailableCaps.contains(Cap))
574       continue;
575     LLVM_DEBUG(dbgs() << "Capability not supported: "
576                       << getSymbolicOperandMnemonic(
577                              OperandCategory::CapabilityOperand, Cap)
578                       << "\n");
579     IsSatisfiable = false;
580   }
581 
582   for (auto Ext : AllExtensions) {
583     if (ST.canUseExtension(Ext))
584       continue;
585     LLVM_DEBUG(dbgs() << "Extension not supported: "
586                       << getSymbolicOperandMnemonic(
587                              OperandCategory::ExtensionOperand, Ext)
588                       << "\n");
589     IsSatisfiable = false;
590   }
591 
592   if (!IsSatisfiable)
593     report_fatal_error("Unable to meet SPIR-V requirements for this target.");
594 }
595 
596 // Add the given capabilities and all their implicitly defined capabilities too.
597 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
598   for (const auto Cap : ToAdd)
599     if (AvailableCaps.insert(Cap).second)
600       addAvailableCaps(getSymbolicOperandCapabilities(
601           SPIRV::OperandCategory::CapabilityOperand, Cap));
602 }
603 
604 void SPIRV::RequirementHandler::removeCapabilityIf(
605     const Capability::Capability ToRemove,
606     const Capability::Capability IfPresent) {
607   if (AllCaps.contains(IfPresent))
608     AllCaps.erase(ToRemove);
609 }
610 
611 namespace llvm {
612 namespace SPIRV {
613 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
614   if (ST.isOpenCLEnv()) {
615     initAvailableCapabilitiesForOpenCL(ST);
616     return;
617   }
618 
619   if (ST.isVulkanEnv()) {
620     initAvailableCapabilitiesForVulkan(ST);
621     return;
622   }
623 
624   report_fatal_error("Unimplemented environment for SPIR-V generation.");
625 }
626 
627 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
628     const SPIRVSubtarget &ST) {
629   // Add the min requirements for different OpenCL and SPIR-V versions.
630   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
631                     Capability::Int16, Capability::Int8, Capability::Kernel,
632                     Capability::Linkage, Capability::Vector16,
633                     Capability::Groups, Capability::GenericPointer,
634                     Capability::Shader});
635   if (ST.hasOpenCLFullProfile())
636     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
637   if (ST.hasOpenCLImageSupport()) {
638     addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
639                       Capability::Image1D, Capability::SampledBuffer,
640                       Capability::ImageBuffer});
641     if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
642       addAvailableCaps({Capability::ImageReadWrite});
643   }
644   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
645       ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
646     addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
647   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
648     addAvailableCaps({Capability::GroupNonUniform,
649                       Capability::GroupNonUniformVote,
650                       Capability::GroupNonUniformArithmetic,
651                       Capability::GroupNonUniformBallot,
652                       Capability::GroupNonUniformClustered,
653                       Capability::GroupNonUniformShuffle,
654                       Capability::GroupNonUniformShuffleRelative});
655   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
656     addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
657                       Capability::SignedZeroInfNanPreserve,
658                       Capability::RoundingModeRTE,
659                       Capability::RoundingModeRTZ});
660   // TODO: verify if this needs some checks.
661   addAvailableCaps({Capability::Float16, Capability::Float64});
662 
663   // Add capabilities enabled by extensions.
664   for (auto Extension : ST.getAllAvailableExtensions()) {
665     CapabilityList EnabledCapabilities =
666         getCapabilitiesEnabledByExtension(Extension);
667     addAvailableCaps(EnabledCapabilities);
668   }
669 
670   // TODO: add OpenCL extensions.
671 }
672 
673 void RequirementHandler::initAvailableCapabilitiesForVulkan(
674     const SPIRVSubtarget &ST) {
675   addAvailableCaps({Capability::Shader, Capability::Linkage});
676 
677   // Provided by all supported Vulkan versions.
678   addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
679                     Capability::Float64, Capability::GroupNonUniform});
680 }
681 
682 } // namespace SPIRV
683 } // namespace llvm
684 
685 // Add the required capabilities from a decoration instruction (including
686 // BuiltIns).
687 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
688                               SPIRV::RequirementHandler &Reqs,
689                               const SPIRVSubtarget &ST) {
690   int64_t DecOp = MI.getOperand(DecIndex).getImm();
691   auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
692   Reqs.addRequirements(getSymbolicOperandRequirements(
693       SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
694 
695   if (Dec == SPIRV::Decoration::BuiltIn) {
696     int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
697     auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
698     Reqs.addRequirements(getSymbolicOperandRequirements(
699         SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
700   } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
701     int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
702     SPIRV::LinkageType::LinkageType LnkType =
703         static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
704     if (LnkType == SPIRV::LinkageType::LinkOnceODR)
705       Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
706   } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
707              Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
708     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
709   } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
710     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
711   } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
712              Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
713     Reqs.addExtension(
714         SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
715   }
716 }
717 
718 // Add requirements for image handling.
719 static void addOpTypeImageReqs(const MachineInstr &MI,
720                                SPIRV::RequirementHandler &Reqs,
721                                const SPIRVSubtarget &ST) {
722   assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
723   // The operand indices used here are based on the OpTypeImage layout, which
724   // the MachineInstr follows as well.
725   int64_t ImgFormatOp = MI.getOperand(7).getImm();
726   auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
727   Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
728                              ImgFormat, ST);
729 
730   bool IsArrayed = MI.getOperand(4).getImm() == 1;
731   bool IsMultisampled = MI.getOperand(5).getImm() == 1;
732   bool NoSampler = MI.getOperand(6).getImm() == 2;
733   // Add dimension requirements.
734   assert(MI.getOperand(2).isImm());
735   switch (MI.getOperand(2).getImm()) {
736   case SPIRV::Dim::DIM_1D:
737     Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
738                                    : SPIRV::Capability::Sampled1D);
739     break;
740   case SPIRV::Dim::DIM_2D:
741     if (IsMultisampled && NoSampler)
742       Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
743     break;
744   case SPIRV::Dim::DIM_Cube:
745     Reqs.addRequirements(SPIRV::Capability::Shader);
746     if (IsArrayed)
747       Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
748                                      : SPIRV::Capability::SampledCubeArray);
749     break;
750   case SPIRV::Dim::DIM_Rect:
751     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
752                                    : SPIRV::Capability::SampledRect);
753     break;
754   case SPIRV::Dim::DIM_Buffer:
755     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
756                                    : SPIRV::Capability::SampledBuffer);
757     break;
758   case SPIRV::Dim::DIM_SubpassData:
759     Reqs.addRequirements(SPIRV::Capability::InputAttachment);
760     break;
761   }
762 
763   // Has optional access qualifier.
764   // TODO: check if it's OpenCL's kernel.
765   if (MI.getNumOperands() > 8 &&
766       MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
767     Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
768   else
769     Reqs.addRequirements(SPIRV::Capability::ImageBasic);
770 }
771 
// Diagnostic text for an atomic float instruction whose required
// SPV_EXT_shader_atomic_float* extension is unavailable. ExtName must be a
// string literal suffix (e.g. "_add"); it is concatenated at compile time.
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V extension: "     \
  "SPV_EXT_shader_atomic_float" ExtName
776 static void AddAtomicFloatRequirements(const MachineInstr &MI,
777                                        SPIRV::RequirementHandler &Reqs,
778                                        const SPIRVSubtarget &ST) {
779   assert(MI.getOperand(1).isReg() &&
780          "Expect register operand in atomic float instruction");
781   Register TypeReg = MI.getOperand(1).getReg();
782   SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
783   if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
784     report_fatal_error("Result type of an atomic float instruction must be a "
785                        "floating-point type scalar");
786 
787   unsigned BitWidth = TypeDef->getOperand(1).getImm();
788   unsigned Op = MI.getOpcode();
789   if (Op == SPIRV::OpAtomicFAddEXT) {
790     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
791       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
792     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
793     switch (BitWidth) {
794     case 16:
795       if (!ST.canUseExtension(
796               SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
797         report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
798       Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
799       Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
800       break;
801     case 32:
802       Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
803       break;
804     case 64:
805       Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
806       break;
807     default:
808       report_fatal_error(
809           "Unexpected floating-point type width in atomic float instruction");
810     }
811   } else {
812     if (!ST.canUseExtension(
813             SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
814       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
815     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
816     switch (BitWidth) {
817     case 16:
818       Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
819       break;
820     case 32:
821       Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
822       break;
823     case 64:
824       Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
825       break;
826     default:
827       report_fatal_error(
828           "Unexpected floating-point type width in atomic float instruction");
829     }
830   }
831 }
832 
833 void addInstrRequirements(const MachineInstr &MI,
834                           SPIRV::RequirementHandler &Reqs,
835                           const SPIRVSubtarget &ST) {
836   switch (MI.getOpcode()) {
837   case SPIRV::OpMemoryModel: {
838     int64_t Addr = MI.getOperand(0).getImm();
839     Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
840                                Addr, ST);
841     int64_t Mem = MI.getOperand(1).getImm();
842     Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
843                                ST);
844     break;
845   }
846   case SPIRV::OpEntryPoint: {
847     int64_t Exe = MI.getOperand(0).getImm();
848     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
849                                Exe, ST);
850     break;
851   }
852   case SPIRV::OpExecutionMode:
853   case SPIRV::OpExecutionModeId: {
854     int64_t Exe = MI.getOperand(1).getImm();
855     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
856                                Exe, ST);
857     break;
858   }
859   case SPIRV::OpTypeMatrix:
860     Reqs.addCapability(SPIRV::Capability::Matrix);
861     break;
862   case SPIRV::OpTypeInt: {
863     unsigned BitWidth = MI.getOperand(1).getImm();
864     if (BitWidth == 64)
865       Reqs.addCapability(SPIRV::Capability::Int64);
866     else if (BitWidth == 16)
867       Reqs.addCapability(SPIRV::Capability::Int16);
868     else if (BitWidth == 8)
869       Reqs.addCapability(SPIRV::Capability::Int8);
870     break;
871   }
872   case SPIRV::OpTypeFloat: {
873     unsigned BitWidth = MI.getOperand(1).getImm();
874     if (BitWidth == 64)
875       Reqs.addCapability(SPIRV::Capability::Float64);
876     else if (BitWidth == 16)
877       Reqs.addCapability(SPIRV::Capability::Float16);
878     break;
879   }
880   case SPIRV::OpTypeVector: {
881     unsigned NumComponents = MI.getOperand(2).getImm();
882     if (NumComponents == 8 || NumComponents == 16)
883       Reqs.addCapability(SPIRV::Capability::Vector16);
884     break;
885   }
886   case SPIRV::OpTypePointer: {
887     auto SC = MI.getOperand(1).getImm();
888     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
889                                ST);
890     // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
891     // capability.
892     if (!ST.isOpenCLEnv())
893       break;
894     assert(MI.getOperand(2).isReg());
895     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
896     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
897     if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
898         TypeDef->getOperand(1).getImm() == 16)
899       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
900     break;
901   }
902   case SPIRV::OpBitReverse:
903   case SPIRV::OpBitFieldInsert:
904   case SPIRV::OpBitFieldSExtract:
905   case SPIRV::OpBitFieldUExtract:
906     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
907       Reqs.addCapability(SPIRV::Capability::Shader);
908       break;
909     }
910     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
911     Reqs.addCapability(SPIRV::Capability::BitInstructions);
912     break;
913   case SPIRV::OpTypeRuntimeArray:
914     Reqs.addCapability(SPIRV::Capability::Shader);
915     break;
916   case SPIRV::OpTypeOpaque:
917   case SPIRV::OpTypeEvent:
918     Reqs.addCapability(SPIRV::Capability::Kernel);
919     break;
920   case SPIRV::OpTypePipe:
921   case SPIRV::OpTypeReserveId:
922     Reqs.addCapability(SPIRV::Capability::Pipes);
923     break;
924   case SPIRV::OpTypeDeviceEvent:
925   case SPIRV::OpTypeQueue:
926   case SPIRV::OpBuildNDRange:
927     Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
928     break;
929   case SPIRV::OpDecorate:
930   case SPIRV::OpDecorateId:
931   case SPIRV::OpDecorateString:
932     addOpDecorateReqs(MI, 1, Reqs, ST);
933     break;
934   case SPIRV::OpMemberDecorate:
935   case SPIRV::OpMemberDecorateString:
936     addOpDecorateReqs(MI, 2, Reqs, ST);
937     break;
938   case SPIRV::OpInBoundsPtrAccessChain:
939     Reqs.addCapability(SPIRV::Capability::Addresses);
940     break;
941   case SPIRV::OpConstantSampler:
942     Reqs.addCapability(SPIRV::Capability::LiteralSampler);
943     break;
944   case SPIRV::OpTypeImage:
945     addOpTypeImageReqs(MI, Reqs, ST);
946     break;
947   case SPIRV::OpTypeSampler:
948     Reqs.addCapability(SPIRV::Capability::ImageBasic);
949     break;
950   case SPIRV::OpTypeForwardPointer:
951     // TODO: check if it's OpenCL's kernel.
952     Reqs.addCapability(SPIRV::Capability::Addresses);
953     break;
954   case SPIRV::OpAtomicFlagTestAndSet:
955   case SPIRV::OpAtomicLoad:
956   case SPIRV::OpAtomicStore:
957   case SPIRV::OpAtomicExchange:
958   case SPIRV::OpAtomicCompareExchange:
959   case SPIRV::OpAtomicIIncrement:
960   case SPIRV::OpAtomicIDecrement:
961   case SPIRV::OpAtomicIAdd:
962   case SPIRV::OpAtomicISub:
963   case SPIRV::OpAtomicUMin:
964   case SPIRV::OpAtomicUMax:
965   case SPIRV::OpAtomicSMin:
966   case SPIRV::OpAtomicSMax:
967   case SPIRV::OpAtomicAnd:
968   case SPIRV::OpAtomicOr:
969   case SPIRV::OpAtomicXor: {
970     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
971     const MachineInstr *InstrPtr = &MI;
972     if (MI.getOpcode() == SPIRV::OpAtomicStore) {
973       assert(MI.getOperand(3).isReg());
974       InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
975       assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
976     }
977     assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
978     Register TypeReg = InstrPtr->getOperand(1).getReg();
979     SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
980     if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
981       unsigned BitWidth = TypeDef->getOperand(1).getImm();
982       if (BitWidth == 64)
983         Reqs.addCapability(SPIRV::Capability::Int64Atomics);
984     }
985     break;
986   }
987   case SPIRV::OpGroupNonUniformIAdd:
988   case SPIRV::OpGroupNonUniformFAdd:
989   case SPIRV::OpGroupNonUniformIMul:
990   case SPIRV::OpGroupNonUniformFMul:
991   case SPIRV::OpGroupNonUniformSMin:
992   case SPIRV::OpGroupNonUniformUMin:
993   case SPIRV::OpGroupNonUniformFMin:
994   case SPIRV::OpGroupNonUniformSMax:
995   case SPIRV::OpGroupNonUniformUMax:
996   case SPIRV::OpGroupNonUniformFMax:
997   case SPIRV::OpGroupNonUniformBitwiseAnd:
998   case SPIRV::OpGroupNonUniformBitwiseOr:
999   case SPIRV::OpGroupNonUniformBitwiseXor:
1000   case SPIRV::OpGroupNonUniformLogicalAnd:
1001   case SPIRV::OpGroupNonUniformLogicalOr:
1002   case SPIRV::OpGroupNonUniformLogicalXor: {
1003     assert(MI.getOperand(3).isImm());
1004     int64_t GroupOp = MI.getOperand(3).getImm();
1005     switch (GroupOp) {
1006     case SPIRV::GroupOperation::Reduce:
1007     case SPIRV::GroupOperation::InclusiveScan:
1008     case SPIRV::GroupOperation::ExclusiveScan:
1009       Reqs.addCapability(SPIRV::Capability::Kernel);
1010       Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1011       Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1012       break;
1013     case SPIRV::GroupOperation::ClusteredReduce:
1014       Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1015       break;
1016     case SPIRV::GroupOperation::PartitionedReduceNV:
1017     case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1018     case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1019       Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1020       break;
1021     }
1022     break;
1023   }
1024   case SPIRV::OpGroupNonUniformShuffle:
1025   case SPIRV::OpGroupNonUniformShuffleXor:
1026     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1027     break;
1028   case SPIRV::OpGroupNonUniformShuffleUp:
1029   case SPIRV::OpGroupNonUniformShuffleDown:
1030     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1031     break;
1032   case SPIRV::OpGroupAll:
1033   case SPIRV::OpGroupAny:
1034   case SPIRV::OpGroupBroadcast:
1035   case SPIRV::OpGroupIAdd:
1036   case SPIRV::OpGroupFAdd:
1037   case SPIRV::OpGroupFMin:
1038   case SPIRV::OpGroupUMin:
1039   case SPIRV::OpGroupSMin:
1040   case SPIRV::OpGroupFMax:
1041   case SPIRV::OpGroupUMax:
1042   case SPIRV::OpGroupSMax:
1043     Reqs.addCapability(SPIRV::Capability::Groups);
1044     break;
1045   case SPIRV::OpGroupNonUniformElect:
1046     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1047     break;
1048   case SPIRV::OpGroupNonUniformAll:
1049   case SPIRV::OpGroupNonUniformAny:
1050   case SPIRV::OpGroupNonUniformAllEqual:
1051     Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1052     break;
1053   case SPIRV::OpGroupNonUniformBroadcast:
1054   case SPIRV::OpGroupNonUniformBroadcastFirst:
1055   case SPIRV::OpGroupNonUniformBallot:
1056   case SPIRV::OpGroupNonUniformInverseBallot:
1057   case SPIRV::OpGroupNonUniformBallotBitExtract:
1058   case SPIRV::OpGroupNonUniformBallotBitCount:
1059   case SPIRV::OpGroupNonUniformBallotFindLSB:
1060   case SPIRV::OpGroupNonUniformBallotFindMSB:
1061     Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1062     break;
1063   case SPIRV::OpSubgroupShuffleINTEL:
1064   case SPIRV::OpSubgroupShuffleDownINTEL:
1065   case SPIRV::OpSubgroupShuffleUpINTEL:
1066   case SPIRV::OpSubgroupShuffleXorINTEL:
1067     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1068       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1069       Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1070     }
1071     break;
1072   case SPIRV::OpSubgroupBlockReadINTEL:
1073   case SPIRV::OpSubgroupBlockWriteINTEL:
1074     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1075       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1076       Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1077     }
1078     break;
1079   case SPIRV::OpSubgroupImageBlockReadINTEL:
1080   case SPIRV::OpSubgroupImageBlockWriteINTEL:
1081     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1082       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1083       Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1084     }
1085     break;
1086   case SPIRV::OpAssumeTrueKHR:
1087   case SPIRV::OpExpectKHR:
1088     if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1089       Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1090       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1091     }
1092     break;
1093   case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1094   case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1095     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1096       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1097       Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1098     }
1099     break;
1100   case SPIRV::OpConstantFunctionPointerINTEL:
1101     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1102       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1103       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1104     }
1105     break;
1106   case SPIRV::OpGroupNonUniformRotateKHR:
1107     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1108       report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1109                          "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1110                          false);
1111     Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1112     Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1113     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1114     break;
1115   case SPIRV::OpGroupIMulKHR:
1116   case SPIRV::OpGroupFMulKHR:
1117   case SPIRV::OpGroupBitwiseAndKHR:
1118   case SPIRV::OpGroupBitwiseOrKHR:
1119   case SPIRV::OpGroupBitwiseXorKHR:
1120   case SPIRV::OpGroupLogicalAndKHR:
1121   case SPIRV::OpGroupLogicalOrKHR:
1122   case SPIRV::OpGroupLogicalXorKHR:
1123     if (ST.canUseExtension(
1124             SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1125       Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1126       Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1127     }
1128     break;
1129   case SPIRV::OpReadClockKHR:
1130     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1131       report_fatal_error("OpReadClockKHR instruction requires the "
1132                          "following SPIR-V extension: SPV_KHR_shader_clock",
1133                          false);
1134     Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1135     Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1136     break;
1137   case SPIRV::OpFunctionPointerCallINTEL:
1138     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1139       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1140       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1141     }
1142     break;
1143   case SPIRV::OpAtomicFAddEXT:
1144   case SPIRV::OpAtomicFMinEXT:
1145   case SPIRV::OpAtomicFMaxEXT:
1146     AddAtomicFloatRequirements(MI, Reqs, ST);
1147     break;
1148   case SPIRV::OpConvertBF16ToFINTEL:
1149   case SPIRV::OpConvertFToBF16INTEL:
1150     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1151       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1152       Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1153     }
1154     break;
1155   case SPIRV::OpVariableLengthArrayINTEL:
1156   case SPIRV::OpSaveMemoryINTEL:
1157   case SPIRV::OpRestoreMemoryINTEL:
1158     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1159       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1160       Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1161     }
1162     break;
1163   case SPIRV::OpAsmTargetINTEL:
1164   case SPIRV::OpAsmINTEL:
1165   case SPIRV::OpAsmCallINTEL:
1166     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1167       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1168       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1169     }
1170     break;
1171   case SPIRV::OpTypeCooperativeMatrixKHR:
1172     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1173       report_fatal_error(
1174           "OpTypeCooperativeMatrixKHR type requires the "
1175           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1176           false);
1177     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1178     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1179     break;
1180   default:
1181     break;
1182   }
1183 
1184   // If we require capability Shader, then we can remove the requirement for
1185   // the BitInstructions capability, since Shader is a superset capability
1186   // of BitInstructions.
1187   Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
1188                           SPIRV::Capability::Shader);
1189 }
1190 
1191 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1192                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1193   // Collect requirements for existing instructions.
1194   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1195     MachineFunction *MF = MMI->getMachineFunction(*F);
1196     if (!MF)
1197       continue;
1198     for (const MachineBasicBlock &MBB : *MF)
1199       for (const MachineInstr &MI : MBB)
1200         addInstrRequirements(MI, MAI.Reqs, ST);
1201   }
1202   // Collect requirements for OpExecutionMode instructions.
1203   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1204   if (Node) {
1205     // SPV_KHR_float_controls is not available until v1.4
1206     bool RequireFloatControls = false,
1207          VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
1208     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1209       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
1210       const MDOperand &MDOp = MDN->getOperand(1);
1211       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1212         Constant *C = CMeta->getValue();
1213         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1214           auto EM = Const->getZExtValue();
1215           MAI.Reqs.getAndAddRequirements(
1216               SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1217           // add SPV_KHR_float_controls if the version is too low
1218           switch (EM) {
1219           case SPIRV::ExecutionMode::DenormPreserve:
1220           case SPIRV::ExecutionMode::DenormFlushToZero:
1221           case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1222           case SPIRV::ExecutionMode::RoundingModeRTE:
1223           case SPIRV::ExecutionMode::RoundingModeRTZ:
1224             RequireFloatControls = VerLower14;
1225             break;
1226           }
1227         }
1228       }
1229     }
1230     if (RequireFloatControls &&
1231         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1232       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1233   }
1234   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1235     const Function &F = *FI;
1236     if (F.isDeclaration())
1237       continue;
1238     if (F.getMetadata("reqd_work_group_size"))
1239       MAI.Reqs.getAndAddRequirements(
1240           SPIRV::OperandCategory::ExecutionModeOperand,
1241           SPIRV::ExecutionMode::LocalSize, ST);
1242     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1243       MAI.Reqs.getAndAddRequirements(
1244           SPIRV::OperandCategory::ExecutionModeOperand,
1245           SPIRV::ExecutionMode::LocalSize, ST);
1246     }
1247     if (F.getMetadata("work_group_size_hint"))
1248       MAI.Reqs.getAndAddRequirements(
1249           SPIRV::OperandCategory::ExecutionModeOperand,
1250           SPIRV::ExecutionMode::LocalSizeHint, ST);
1251     if (F.getMetadata("intel_reqd_sub_group_size"))
1252       MAI.Reqs.getAndAddRequirements(
1253           SPIRV::OperandCategory::ExecutionModeOperand,
1254           SPIRV::ExecutionMode::SubgroupSize, ST);
1255     if (F.getMetadata("vec_type_hint"))
1256       MAI.Reqs.getAndAddRequirements(
1257           SPIRV::OperandCategory::ExecutionModeOperand,
1258           SPIRV::ExecutionMode::VecTypeHint, ST);
1259 
1260     if (F.hasOptNone() &&
1261         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1262       // Output OpCapability OptNoneINTEL.
1263       MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1264       MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1265     }
1266   }
1267 }
1268 
1269 static unsigned getFastMathFlags(const MachineInstr &I) {
1270   unsigned Flags = SPIRV::FPFastMathMode::None;
1271   if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1272     Flags |= SPIRV::FPFastMathMode::NotNaN;
1273   if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1274     Flags |= SPIRV::FPFastMathMode::NotInf;
1275   if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1276     Flags |= SPIRV::FPFastMathMode::NSZ;
1277   if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1278     Flags |= SPIRV::FPFastMathMode::AllowRecip;
1279   if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1280     Flags |= SPIRV::FPFastMathMode::Fast;
1281   return Flags;
1282 }
1283 
1284 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1285                                    const SPIRVInstrInfo &TII,
1286                                    SPIRV::RequirementHandler &Reqs) {
1287   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1288       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1289                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1290           .IsSatisfiable) {
1291     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1292                     SPIRV::Decoration::NoSignedWrap, {});
1293   }
1294   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1295       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1296                                      SPIRV::Decoration::NoUnsignedWrap, ST,
1297                                      Reqs)
1298           .IsSatisfiable) {
1299     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1300                     SPIRV::Decoration::NoUnsignedWrap, {});
1301   }
1302   if (!TII.canUseFastMathFlags(I))
1303     return;
1304   unsigned FMFlags = getFastMathFlags(I);
1305   if (FMFlags == SPIRV::FPFastMathMode::None)
1306     return;
1307   Register DstReg = I.getOperand(0).getReg();
1308   buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1309 }
1310 
1311 // Walk all functions and add decorations related to MI flags.
1312 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1313                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1314                            SPIRV::ModuleAnalysisInfo &MAI) {
1315   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1316     MachineFunction *MF = MMI->getMachineFunction(*F);
1317     if (!MF)
1318       continue;
1319     for (auto &MBB : *MF)
1320       for (auto &MI : MBB)
1321         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1322   }
1323 }
1324 
1325 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1326 
1327 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1328   AU.addRequired<TargetPassConfig>();
1329   AU.addRequired<MachineModuleInfoWrapperPass>();
1330 }
1331 
1332 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
1333   SPIRVTargetMachine &TM =
1334       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
1335   ST = TM.getSubtargetImpl();
1336   GR = ST->getSPIRVGlobalRegistry();
1337   TII = ST->getInstrInfo();
1338 
1339   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
1340 
1341   setBaseInfo(M);
1342 
1343   addDecorations(M, *TII, MMI, *ST, MAI);
1344 
1345   collectReqs(M, MAI, MMI, *ST);
1346 
1347   // Process type/const/global var/func decl instructions, number their
1348   // destination registers from 0 to N, collect Extensions and Capabilities.
1349   processDefInstrs(M);
1350 
1351   // Number rest of registers from N+1 onwards.
1352   numberRegistersGlobally(M);
1353 
1354   // Update references to OpFunction instructions to use Global Registers
1355   if (GR->hasConstFunPtr())
1356     collectFuncPtrs();
1357 
1358   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
1359   processOtherInstrs(M);
1360 
1361   // If there are no entry points, we need the Linkage capability.
1362   if (MAI.MS[SPIRV::MB_EntryPoints].empty())
1363     MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
1364 
1365   // Set maximum ID used.
1366   GR->setBound(MAI.MaxID);
1367 
1368   return false;
1369 }
1370