xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/CodeGen/MachineModuleInfo.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "spirv-module-analysis"
31 
// Command-line flag "spv-dump-deps". It defaults to false; according to its
// description it requests a dump of MIR with SPIR-V dependencies info.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Hidden command-line list "avoid-spirv-capabilities": capabilities that
// should not be selected when another capability can enable the same feature.
// "Shader" is the only value accepted by the option.
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
44 // Use sets instead of cl::list to check "if contains" condition
45 struct AvoidCapabilitiesSet {
46   SmallSet<SPIRV::Capability::Capability, 4> S;
47   AvoidCapabilitiesSet() { S.insert_range(AvoidCapabilities); }
48 };
49 
// Static pass identifier, initialised to 0.
char llvm::SPIRVModuleAnalysis::ID = 0;

// Registers the pass under DEBUG_TYPE with the display name
// "SPIRV module analysis".
// NOTE(review): the two trailing `true` flags presumably mark the pass as
// CFG-only and as an analysis -- confirm against INITIALIZE_PASS.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
54 
55 // Retrieve an unsigned from an MDNode with a list of them as operands.
56 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
57                                 unsigned DefaultVal = 0) {
58   if (MdNode && OpIndex < MdNode->getNumOperands()) {
59     const auto &Op = MdNode->getOperand(OpIndex);
60     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
61   }
62   return DefaultVal;
63 }
64 
65 static SPIRV::Requirements
66 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
67                                unsigned i, const SPIRVSubtarget &ST,
68                                SPIRV::RequirementHandler &Reqs) {
69   // A set of capabilities to avoid if there is another option.
70   AvoidCapabilitiesSet AvoidCaps;
71   if (!ST.isShader())
72     AvoidCaps.S.insert(SPIRV::Capability::Shader);
73   else
74     AvoidCaps.S.insert(SPIRV::Capability::Kernel);
75 
76   VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
77   VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
78   VersionTuple SPIRVVersion = ST.getSPIRVVersion();
79   bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
80   bool MaxVerOK =
81       ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
82   CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
83   ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
84   if (ReqCaps.empty()) {
85     if (ReqExts.empty()) {
86       if (MinVerOK && MaxVerOK)
87         return {true, {}, {}, ReqMinVer, ReqMaxVer};
88       return {false, {}, {}, VersionTuple(), VersionTuple()};
89     }
90   } else if (MinVerOK && MaxVerOK) {
91     if (ReqCaps.size() == 1) {
92       auto Cap = ReqCaps[0];
93       if (Reqs.isCapabilityAvailable(Cap)) {
94         ReqExts.append(getSymbolicOperandExtensions(
95             SPIRV::OperandCategory::CapabilityOperand, Cap));
96         return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
97       }
98     } else {
99       // By SPIR-V specification: "If an instruction, enumerant, or other
100       // feature specifies multiple enabling capabilities, only one such
101       // capability needs to be declared to use the feature." However, one
102       // capability may be preferred over another. We use command line
103       // argument(s) and AvoidCapabilities to avoid selection of certain
104       // capabilities if there are other options.
105       CapabilityList UseCaps;
106       for (auto Cap : ReqCaps)
107         if (Reqs.isCapabilityAvailable(Cap))
108           UseCaps.push_back(Cap);
109       for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110         auto Cap = UseCaps[i];
111         if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
112           ReqExts.append(getSymbolicOperandExtensions(
113               SPIRV::OperandCategory::CapabilityOperand, Cap));
114           return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
115         }
116       }
117     }
118   }
119   // If there are no capabilities, or we can't satisfy the version or
120   // capability requirements, use the list of extensions (if the subtarget
121   // can handle them all).
122   if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
123         return ST.canUseExtension(Ext);
124       })) {
125     return {true,
126             {},
127             ReqExts,
128             VersionTuple(),
129             VersionTuple()}; // TODO: add versions to extensions.
130   }
131   return {false, {}, {}, VersionTuple(), VersionTuple()};
132 }
133 
134 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
135   MAI.MaxID = 0;
136   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
137     MAI.MS[i].clear();
138   MAI.RegisterAliasTable.clear();
139   MAI.InstrsToDelete.clear();
140   MAI.FuncMap.clear();
141   MAI.GlobalVarList.clear();
142   MAI.ExtInstSetMap.clear();
143   MAI.Reqs.clear();
144   MAI.Reqs.initAvailableCapabilities(*ST);
145 
146   // TODO: determine memory model and source language from the configuratoin.
147   if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
148     auto MemMD = MemModel->getOperand(0);
149     MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
150         getMetadataUInt(MemMD, 0));
151     MAI.Mem =
152         static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
153   } else {
154     // TODO: Add support for VulkanMemoryModel.
155     MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
156                              : SPIRV::MemoryModel::OpenCL;
157     if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
158       unsigned PtrSize = ST->getPointerSize();
159       MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
160                  : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
161                                  : SPIRV::AddressingModel::Logical;
162     } else {
163       // TODO: Add support for PhysicalStorageBufferAddress.
164       MAI.Addr = SPIRV::AddressingModel::Logical;
165     }
166   }
167   // Get the OpenCL version number from metadata.
168   // TODO: support other source languages.
169   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
170     MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
171     // Construct version literal in accordance with SPIRV-LLVM-Translator.
172     // TODO: support multiple OCL version metadata.
173     assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
174     auto VersionMD = VerNode->getOperand(0);
175     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
176     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
177     unsigned RevNum = getMetadataUInt(VersionMD, 2);
178     // Prevent Major part of OpenCL version to be 0
179     MAI.SrcLangVersion =
180         (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
181   } else {
182     // If there is no information about OpenCL version we are forced to generate
183     // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
184     // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
185     // Translator avoids potential issues with run-times in a similar manner.
186     if (!ST->isShader()) {
187       MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
188       MAI.SrcLangVersion = 100000;
189     } else {
190       MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
191       MAI.SrcLangVersion = 0;
192     }
193   }
194 
195   if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
196     for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
197       MDNode *MD = ExtNode->getOperand(I);
198       if (!MD || MD->getNumOperands() == 0)
199         continue;
200       for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
201         MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
202     }
203   }
204 
205   // Update required capabilities for this memory model, addressing model and
206   // source language.
207   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
208                                  MAI.Mem, *ST);
209   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
210                                  MAI.SrcLang, *ST);
211   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
212                                  MAI.Addr, *ST);
213 
214   if (!ST->isShader()) {
215     // TODO: check if it's required by default.
216     MAI.ExtInstSetMap[static_cast<unsigned>(
217         SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
218   }
219 }
220 
221 // Appends the signature of the decoration instructions that decorate R to
222 // Signature.
223 static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
224                                     InstrSignature &Signature) {
225   for (MachineInstr &UseMI : MRI.use_instructions(R)) {
226     // We don't handle OpDecorateId because getting the register alias for the
227     // ID can cause problems, and we do not need it for now.
228     if (UseMI.getOpcode() != SPIRV::OpDecorate &&
229         UseMI.getOpcode() != SPIRV::OpMemberDecorate)
230       continue;
231 
232     for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
233       const MachineOperand &MO = UseMI.getOperand(I);
234       if (MO.isReg())
235         continue;
236       Signature.push_back(hash_value(MO));
237     }
238   }
239 }
240 
241 // Returns a representation of an instruction as a vector of MachineOperand
242 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
243 // This creates a signature of the instruction with the same content
244 // that MachineOperand::isIdenticalTo uses for comparison.
245 static InstrSignature instrToSignature(const MachineInstr &MI,
246                                        SPIRV::ModuleAnalysisInfo &MAI,
247                                        bool UseDefReg) {
248   Register DefReg;
249   InstrSignature Signature{MI.getOpcode()};
250   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
251     const MachineOperand &MO = MI.getOperand(i);
252     size_t h;
253     if (MO.isReg()) {
254       if (!UseDefReg && MO.isDef()) {
255         assert(!DefReg.isValid() && "Multiple def registers.");
256         DefReg = MO.getReg();
257         continue;
258       }
259       Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
260       if (!RegAlias.isValid()) {
261         LLVM_DEBUG({
262           dbgs() << "Unexpectedly, no global id found for the operand ";
263           MO.print(dbgs());
264           dbgs() << "\nInstruction: ";
265           MI.print(dbgs());
266           dbgs() << "\n";
267         });
268         report_fatal_error("All v-regs must have been mapped to global id's");
269       }
270       // mimic llvm::hash_value(const MachineOperand &MO)
271       h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
272                        MO.isDef());
273     } else {
274       h = hash_value(MO);
275     }
276     Signature.push_back(h);
277   }
278 
279   if (DefReg.isValid()) {
280     // Decorations change the semantics of the current instruction. So two
281     // identical instruction with different decorations cannot be merged. That
282     // is why we add the decorations to the signature.
283     appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
284   }
285   return Signature;
286 }
287 
288 bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
289                                         const MachineInstr &MI) {
290   unsigned Opcode = MI.getOpcode();
291   switch (Opcode) {
292   case SPIRV::OpTypeForwardPointer:
293     // omit now, collect later
294     return false;
295   case SPIRV::OpVariable:
296     return static_cast<SPIRV::StorageClass::StorageClass>(
297                MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
298   case SPIRV::OpFunction:
299   case SPIRV::OpFunctionParameter:
300     return true;
301   }
302   if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
303     Register DefReg = MI.getOperand(0).getReg();
304     for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
305       if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
306         continue;
307       // it's a dummy definition, FP constant refers to a function,
308       // and this is resolved in another way; let's skip this definition
309       assert(UseMI.getOperand(2).isReg() &&
310              UseMI.getOperand(2).getReg() == DefReg);
311       MAI.setSkipEmission(&MI);
312       return false;
313     }
314   }
315   return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
316          TII->isInlineAsmDefInstr(MI);
317 }
318 
319 // This is a special case of a function pointer refering to a possibly
320 // forward function declaration. The operand is a dummy OpUndef that
321 // requires a special treatment.
322 void SPIRVModuleAnalysis::visitFunPtrUse(
323     Register OpReg, InstrGRegsMap &SignatureToGReg,
324     std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
325     const MachineInstr &MI) {
326   const MachineOperand *OpFunDef =
327       GR->getFunctionDefinitionByUse(&MI.getOperand(2));
328   assert(OpFunDef && OpFunDef->isReg());
329   // find the actual function definition and number it globally in advance
330   const MachineInstr *OpDefMI = OpFunDef->getParent();
331   assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
332   const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
333   const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
334   do {
335     visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
336     OpDefMI = OpDefMI->getNextNode();
337   } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
338                        OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
339   // associate the function pointer with the newly assigned global number
340   MCRegister GlobalFunDefReg =
341       MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
342   assert(GlobalFunDefReg.isValid() &&
343          "Function definition must refer to a global register");
344   MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
345 }
346 
347 // Depth first recursive traversal of dependencies. Repeated visits are guarded
348 // by MAI.hasRegisterAlias().
349 void SPIRVModuleAnalysis::visitDecl(
350     const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
351     std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
352     const MachineInstr &MI) {
353   unsigned Opcode = MI.getOpcode();
354 
355   // Process each operand of the instruction to resolve dependencies
356   for (const MachineOperand &MO : MI.operands()) {
357     if (!MO.isReg() || MO.isDef())
358       continue;
359     Register OpReg = MO.getReg();
360     // Handle function pointers special case
361     if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
362         MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
363       visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
364       continue;
365     }
366     // Skip already processed instructions
367     if (MAI.hasRegisterAlias(MF, MO.getReg()))
368       continue;
369     // Recursively visit dependencies
370     if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
371       if (isDeclSection(MRI, *OpDefMI))
372         visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
373       continue;
374     }
375     // Handle the unexpected case of no unique definition for the SPIR-V
376     // instruction
377     LLVM_DEBUG({
378       dbgs() << "Unexpectedly, no unique definition for the operand ";
379       MO.print(dbgs());
380       dbgs() << "\nInstruction: ";
381       MI.print(dbgs());
382       dbgs() << "\n";
383     });
384     report_fatal_error(
385         "No unique definition is found for the virtual register");
386   }
387 
388   MCRegister GReg;
389   bool IsFunDef = false;
390   if (TII->isSpecConstantInstr(MI)) {
391     GReg = MAI.getNextIDRegister();
392     MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
393   } else if (Opcode == SPIRV::OpFunction ||
394              Opcode == SPIRV::OpFunctionParameter) {
395     GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
396   } else if (Opcode == SPIRV::OpTypeStruct ||
397              Opcode == SPIRV::OpConstantComposite) {
398     GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
399     const MachineInstr *NextInstr = MI.getNextNode();
400     while (NextInstr &&
401            ((Opcode == SPIRV::OpTypeStruct &&
402              NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
403             (Opcode == SPIRV::OpConstantComposite &&
404              NextInstr->getOpcode() ==
405                  SPIRV::OpConstantCompositeContinuedINTEL))) {
406       MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
407       MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
408       MAI.setSkipEmission(NextInstr);
409       NextInstr = NextInstr->getNextNode();
410     }
411   } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
412              TII->isInlineAsmDefInstr(MI)) {
413     GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
414   } else if (Opcode == SPIRV::OpVariable) {
415     GReg = handleVariable(MF, MI, GlobalToGReg);
416   } else {
417     LLVM_DEBUG({
418       dbgs() << "\nInstruction: ";
419       MI.print(dbgs());
420       dbgs() << "\n";
421     });
422     llvm_unreachable("Unexpected instruction is visited");
423   }
424   MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
425   if (!IsFunDef)
426     MAI.setSkipEmission(&MI);
427 }
428 
429 MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
430     const MachineFunction *MF, const MachineInstr &MI,
431     std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
432   const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
433   assert(GObj && "Unregistered global definition");
434   const Function *F = dyn_cast<Function>(GObj);
435   if (!F)
436     F = dyn_cast<Argument>(GObj)->getParent();
437   assert(F && "Expected a reference to a function or an argument");
438   IsFunDef = !F->isDeclaration();
439   auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
440   if (!Inserted)
441     return It->second;
442   MCRegister GReg = MAI.getNextIDRegister();
443   It->second = GReg;
444   if (!IsFunDef)
445     MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
446   return GReg;
447 }
448 
449 MCRegister
450 SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
451                                               InstrGRegsMap &SignatureToGReg) {
452   InstrSignature MISign = instrToSignature(MI, MAI, false);
453   auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
454   if (!Inserted)
455     return It->second;
456   MCRegister GReg = MAI.getNextIDRegister();
457   It->second = GReg;
458   MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
459   return GReg;
460 }
461 
462 MCRegister SPIRVModuleAnalysis::handleVariable(
463     const MachineFunction *MF, const MachineInstr &MI,
464     std::map<const Value *, unsigned> &GlobalToGReg) {
465   MAI.GlobalVarList.push_back(&MI);
466   const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
467   assert(GObj && "Unregistered global definition");
468   auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
469   if (!Inserted)
470     return It->second;
471   MCRegister GReg = MAI.getNextIDRegister();
472   It->second = GReg;
473   MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
474   return GReg;
475 }
476 
477 void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
478   InstrGRegsMap SignatureToGReg;
479   std::map<const Value *, unsigned> GlobalToGReg;
480   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
481     MachineFunction *MF = MMI->getMachineFunction(*F);
482     if (!MF)
483       continue;
484     const MachineRegisterInfo &MRI = MF->getRegInfo();
485     unsigned PastHeader = 0;
486     for (MachineBasicBlock &MBB : *MF) {
487       for (MachineInstr &MI : MBB) {
488         if (MI.getNumOperands() == 0)
489           continue;
490         unsigned Opcode = MI.getOpcode();
491         if (Opcode == SPIRV::OpFunction) {
492           if (PastHeader == 0) {
493             PastHeader = 1;
494             continue;
495           }
496         } else if (Opcode == SPIRV::OpFunctionParameter) {
497           if (PastHeader < 2)
498             continue;
499         } else if (PastHeader > 0) {
500           PastHeader = 2;
501         }
502 
503         const MachineOperand &DefMO = MI.getOperand(0);
504         switch (Opcode) {
505         case SPIRV::OpExtension:
506           MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
507           MAI.setSkipEmission(&MI);
508           break;
509         case SPIRV::OpCapability:
510           MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
511           MAI.setSkipEmission(&MI);
512           if (PastHeader > 0)
513             PastHeader = 2;
514           break;
515         default:
516           if (DefMO.isReg() && isDeclSection(MRI, MI) &&
517               !MAI.hasRegisterAlias(MF, DefMO.getReg()))
518             visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
519         }
520       }
521     }
522   }
523 }
524 
525 // Look for IDs declared with Import linkage, and map the corresponding function
526 // to the register defining that variable (which will usually be the result of
527 // an OpFunction). This lets us call externally imported functions using
528 // the correct ID registers.
529 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
530                                            const Function *F) {
531   if (MI.getOpcode() == SPIRV::OpDecorate) {
532     // If it's got Import linkage.
533     auto Dec = MI.getOperand(1).getImm();
534     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
535       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
536       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
537         // Map imported function name to function ID register.
538         const Function *ImportedFunc =
539             F->getParent()->getFunction(getStringImm(MI, 2));
540         Register Target = MI.getOperand(0).getReg();
541         MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
542       }
543     }
544   } else if (MI.getOpcode() == SPIRV::OpFunction) {
545     // Record all internal OpFunction declarations.
546     Register Reg = MI.defs().begin()->getReg();
547     MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
548     assert(GlobalReg.isValid());
549     MAI.FuncMap[F] = GlobalReg;
550   }
551 }
552 
553 // Collect the given instruction in the specified MS. We assume global register
554 // numbering has already occurred by this point. We can directly compare reg
555 // arguments when detecting duplicates.
556 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
557                               SPIRV::ModuleSectionType MSType, InstrTraces &IS,
558                               bool Append = true) {
559   MAI.setSkipEmission(&MI);
560   InstrSignature MISign = instrToSignature(MI, MAI, true);
561   auto FoundMI = IS.insert(MISign);
562   if (!FoundMI.second)
563     return; // insert failed, so we found a duplicate; don't add it to MAI.MS
564   // No duplicates, so add it.
565   if (Append)
566     MAI.MS[MSType].push_back(&MI);
567   else
568     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
569 }
570 
571 // Some global instructions make reference to function-local ID regs, so cannot
572 // be correctly collected until these registers are globally numbered.
573 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
574   InstrTraces IS;
575   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
576     if ((*F).isDeclaration())
577       continue;
578     MachineFunction *MF = MMI->getMachineFunction(*F);
579     assert(MF);
580 
581     for (MachineBasicBlock &MBB : *MF)
582       for (MachineInstr &MI : MBB) {
583         if (MAI.getSkipEmission(&MI))
584           continue;
585         const unsigned OpCode = MI.getOpcode();
586         if (OpCode == SPIRV::OpString) {
587           collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
588         } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
589                    MI.getOperand(2).getImm() ==
590                        SPIRV::InstructionSet::
591                            NonSemantic_Shader_DebugInfo_100) {
592           MachineOperand Ins = MI.getOperand(3);
593           namespace NS = SPIRV::NonSemanticExtInst;
594           static constexpr int64_t GlobalNonSemanticDITy[] = {
595               NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
596               NS::DebugTypeBasic, NS::DebugTypePointer};
597           bool IsGlobalDI = false;
598           for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
599             IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
600           if (IsGlobalDI)
601             collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
602         } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
603           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
604         } else if (OpCode == SPIRV::OpEntryPoint) {
605           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
606         } else if (TII->isAliasingInstr(MI)) {
607           collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
608         } else if (TII->isDecorationInstr(MI)) {
609           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
610           collectFuncNames(MI, &*F);
611         } else if (TII->isConstantInstr(MI)) {
612           // Now OpSpecConstant*s are not in DT,
613           // but they need to be collected anyway.
614           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
615         } else if (OpCode == SPIRV::OpFunction) {
616           collectFuncNames(MI, &*F);
617         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
618           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
619         }
620       }
621   }
622 }
623 
624 // Number registers in all functions globally from 0 onwards and store
625 // the result in global register alias table. Some registers are already
626 // numbered.
627 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
628   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
629     if ((*F).isDeclaration())
630       continue;
631     MachineFunction *MF = MMI->getMachineFunction(*F);
632     assert(MF);
633     for (MachineBasicBlock &MBB : *MF) {
634       for (MachineInstr &MI : MBB) {
635         for (MachineOperand &Op : MI.operands()) {
636           if (!Op.isReg())
637             continue;
638           Register Reg = Op.getReg();
639           if (MAI.hasRegisterAlias(MF, Reg))
640             continue;
641           MCRegister NewReg = MAI.getNextIDRegister();
642           MAI.setRegisterAlias(MF, Reg, NewReg);
643         }
644         if (MI.getOpcode() != SPIRV::OpExtInst)
645           continue;
646         auto Set = MI.getOperand(2).getImm();
647         auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
648         if (Inserted)
649           It->second = MAI.getNextIDRegister();
650       }
651     }
652   }
653 }
654 
655 // RequirementHandler implementations.
656 void SPIRV::RequirementHandler::getAndAddRequirements(
657     SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
658     const SPIRVSubtarget &ST) {
659   addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
660 }
661 
662 void SPIRV::RequirementHandler::recursiveAddCapabilities(
663     const CapabilityList &ToPrune) {
664   for (const auto &Cap : ToPrune) {
665     AllCaps.insert(Cap);
666     CapabilityList ImplicitDecls =
667         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
668     recursiveAddCapabilities(ImplicitDecls);
669   }
670 }
671 
672 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
673   for (const auto &Cap : ToAdd) {
674     bool IsNewlyInserted = AllCaps.insert(Cap).second;
675     if (!IsNewlyInserted) // Don't re-add if it's already been declared.
676       continue;
677     CapabilityList ImplicitDecls =
678         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
679     recursiveAddCapabilities(ImplicitDecls);
680     MinimalCaps.push_back(Cap);
681   }
682 }
683 
684 void SPIRV::RequirementHandler::addRequirements(
685     const SPIRV::Requirements &Req) {
686   if (!Req.IsSatisfiable)
687     report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
688 
689   if (Req.Cap.has_value())
690     addCapabilities({Req.Cap.value()});
691 
692   addExtensions(Req.Exts);
693 
694   if (!Req.MinVer.empty()) {
695     if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
696       LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
697                         << " and <= " << MaxVersion << "\n");
698       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
699     }
700 
701     if (MinVersion.empty() || Req.MinVer > MinVersion)
702       MinVersion = Req.MinVer;
703   }
704 
705   if (!Req.MaxVer.empty()) {
706     if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
707       LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
708                         << " and >= " << MinVersion << "\n");
709       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
710     }
711 
712     if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
713       MaxVersion = Req.MaxVer;
714   }
715 }
716 
717 void SPIRV::RequirementHandler::checkSatisfiable(
718     const SPIRVSubtarget &ST) const {
719   // Report as many errors as possible before aborting the compilation.
720   bool IsSatisfiable = true;
721   auto TargetVer = ST.getSPIRVVersion();
722 
723   if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
724     LLVM_DEBUG(
725         dbgs() << "Target SPIR-V version too high for required features\n"
726                << "Required max version: " << MaxVersion << " target version "
727                << TargetVer << "\n");
728     IsSatisfiable = false;
729   }
730 
731   if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
732     LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
733                       << "Required min version: " << MinVersion
734                       << " target version " << TargetVer << "\n");
735     IsSatisfiable = false;
736   }
737 
738   if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
739     LLVM_DEBUG(
740         dbgs()
741         << "Version is too low for some features and too high for others.\n"
742         << "Required SPIR-V min version: " << MinVersion
743         << " required SPIR-V max version " << MaxVersion << "\n");
744     IsSatisfiable = false;
745   }
746 
747   for (auto Cap : MinimalCaps) {
748     if (AvailableCaps.contains(Cap))
749       continue;
750     LLVM_DEBUG(dbgs() << "Capability not supported: "
751                       << getSymbolicOperandMnemonic(
752                              OperandCategory::CapabilityOperand, Cap)
753                       << "\n");
754     IsSatisfiable = false;
755   }
756 
757   for (auto Ext : AllExtensions) {
758     if (ST.canUseExtension(Ext))
759       continue;
760     LLVM_DEBUG(dbgs() << "Extension not supported: "
761                       << getSymbolicOperandMnemonic(
762                              OperandCategory::ExtensionOperand, Ext)
763                       << "\n");
764     IsSatisfiable = false;
765   }
766 
767   if (!IsSatisfiable)
768     report_fatal_error("Unable to meet SPIR-V requirements for this target.");
769 }
770 
771 // Add the given capabilities and all their implicitly defined capabilities too.
772 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
773   for (const auto Cap : ToAdd)
774     if (AvailableCaps.insert(Cap).second)
775       addAvailableCaps(getSymbolicOperandCapabilities(
776           SPIRV::OperandCategory::CapabilityOperand, Cap));
777 }
778 
779 void SPIRV::RequirementHandler::removeCapabilityIf(
780     const Capability::Capability ToRemove,
781     const Capability::Capability IfPresent) {
782   if (AllCaps.contains(IfPresent))
783     AllCaps.erase(ToRemove);
784 }
785 
786 namespace llvm {
787 namespace SPIRV {
788 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
789   // Provided by both all supported Vulkan versions and OpenCl.
790   addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
791                     Capability::Int16});
792 
793   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
794     addAvailableCaps({Capability::GroupNonUniform,
795                       Capability::GroupNonUniformVote,
796                       Capability::GroupNonUniformArithmetic,
797                       Capability::GroupNonUniformBallot,
798                       Capability::GroupNonUniformClustered,
799                       Capability::GroupNonUniformShuffle,
800                       Capability::GroupNonUniformShuffleRelative});
801 
802   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
803     addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
804                       Capability::DotProductInput4x8Bit,
805                       Capability::DotProductInput4x8BitPacked,
806                       Capability::DemoteToHelperInvocation});
807 
808   // Add capabilities enabled by extensions.
809   for (auto Extension : ST.getAllAvailableExtensions()) {
810     CapabilityList EnabledCapabilities =
811         getCapabilitiesEnabledByExtension(Extension);
812     addAvailableCaps(EnabledCapabilities);
813   }
814 
815   if (!ST.isShader()) {
816     initAvailableCapabilitiesForOpenCL(ST);
817     return;
818   }
819 
820   if (ST.isShader()) {
821     initAvailableCapabilitiesForVulkan(ST);
822     return;
823   }
824 
825   report_fatal_error("Unimplemented environment for SPIR-V generation.");
826 }
827 
828 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
829     const SPIRVSubtarget &ST) {
830   // Add the min requirements for different OpenCL and SPIR-V versions.
831   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
832                     Capability::Kernel, Capability::Vector16,
833                     Capability::Groups, Capability::GenericPointer,
834                     Capability::StorageImageWriteWithoutFormat,
835                     Capability::StorageImageReadWithoutFormat});
836   if (ST.hasOpenCLFullProfile())
837     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
838   if (ST.hasOpenCLImageSupport()) {
839     addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
840                       Capability::Image1D, Capability::SampledBuffer,
841                       Capability::ImageBuffer});
842     if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
843       addAvailableCaps({Capability::ImageReadWrite});
844   }
845   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
846       ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
847     addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
848   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
849     addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
850                       Capability::SignedZeroInfNanPreserve,
851                       Capability::RoundingModeRTE,
852                       Capability::RoundingModeRTZ});
853   // TODO: verify if this needs some checks.
854   addAvailableCaps({Capability::Float16, Capability::Float64});
855 
856   // TODO: add OpenCL extensions.
857 }
858 
859 void RequirementHandler::initAvailableCapabilitiesForVulkan(
860     const SPIRVSubtarget &ST) {
861 
862   // Core in Vulkan 1.1 and earlier.
863   addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
864                     Capability::GroupNonUniform, Capability::Image1D,
865                     Capability::SampledBuffer, Capability::ImageBuffer,
866                     Capability::UniformBufferArrayDynamicIndexing,
867                     Capability::SampledImageArrayDynamicIndexing,
868                     Capability::StorageBufferArrayDynamicIndexing,
869                     Capability::StorageImageArrayDynamicIndexing});
870 
871   // Became core in Vulkan 1.2
872   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
873     addAvailableCaps(
874         {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
875          Capability::InputAttachmentArrayDynamicIndexingEXT,
876          Capability::UniformTexelBufferArrayDynamicIndexingEXT,
877          Capability::StorageTexelBufferArrayDynamicIndexingEXT,
878          Capability::UniformBufferArrayNonUniformIndexingEXT,
879          Capability::SampledImageArrayNonUniformIndexingEXT,
880          Capability::StorageBufferArrayNonUniformIndexingEXT,
881          Capability::StorageImageArrayNonUniformIndexingEXT,
882          Capability::InputAttachmentArrayNonUniformIndexingEXT,
883          Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
884          Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
885   }
886 
887   // Became core in Vulkan 1.3
888   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
889     addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
890                       Capability::StorageImageReadWithoutFormat});
891 }
892 
893 } // namespace SPIRV
894 } // namespace llvm
895 
896 // Add the required capabilities from a decoration instruction (including
897 // BuiltIns).
898 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
899                               SPIRV::RequirementHandler &Reqs,
900                               const SPIRVSubtarget &ST) {
901   int64_t DecOp = MI.getOperand(DecIndex).getImm();
902   auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
903   Reqs.addRequirements(getSymbolicOperandRequirements(
904       SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
905 
906   if (Dec == SPIRV::Decoration::BuiltIn) {
907     int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
908     auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
909     Reqs.addRequirements(getSymbolicOperandRequirements(
910         SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
911   } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
912     int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
913     SPIRV::LinkageType::LinkageType LnkType =
914         static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
915     if (LnkType == SPIRV::LinkageType::LinkOnceODR)
916       Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
917   } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
918              Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
919     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
920   } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
921     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
922   } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
923              Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
924     Reqs.addExtension(
925         SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
926   } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
927     Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
928   } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
929     Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
930     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
931   }
932 }
933 
934 // Add requirements for image handling.
935 static void addOpTypeImageReqs(const MachineInstr &MI,
936                                SPIRV::RequirementHandler &Reqs,
937                                const SPIRVSubtarget &ST) {
938   assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
939   // The operand indices used here are based on the OpTypeImage layout, which
940   // the MachineInstr follows as well.
941   int64_t ImgFormatOp = MI.getOperand(7).getImm();
942   auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
943   Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
944                              ImgFormat, ST);
945 
946   bool IsArrayed = MI.getOperand(4).getImm() == 1;
947   bool IsMultisampled = MI.getOperand(5).getImm() == 1;
948   bool NoSampler = MI.getOperand(6).getImm() == 2;
949   // Add dimension requirements.
950   assert(MI.getOperand(2).isImm());
951   switch (MI.getOperand(2).getImm()) {
952   case SPIRV::Dim::DIM_1D:
953     Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
954                                    : SPIRV::Capability::Sampled1D);
955     break;
956   case SPIRV::Dim::DIM_2D:
957     if (IsMultisampled && NoSampler)
958       Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
959     break;
960   case SPIRV::Dim::DIM_Cube:
961     Reqs.addRequirements(SPIRV::Capability::Shader);
962     if (IsArrayed)
963       Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
964                                      : SPIRV::Capability::SampledCubeArray);
965     break;
966   case SPIRV::Dim::DIM_Rect:
967     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
968                                    : SPIRV::Capability::SampledRect);
969     break;
970   case SPIRV::Dim::DIM_Buffer:
971     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
972                                    : SPIRV::Capability::SampledBuffer);
973     break;
974   case SPIRV::Dim::DIM_SubpassData:
975     Reqs.addRequirements(SPIRV::Capability::InputAttachment);
976     break;
977   }
978 
979   // Has optional access qualifier.
980   if (!ST.isShader()) {
981     if (MI.getNumOperands() > 8 &&
982         MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
983       Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
984     else
985       Reqs.addRequirements(SPIRV::Capability::ImageBasic);
986   }
987 }
988 
989 // Add requirements for handling atomic float instructions
990 #define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
991   "The atomic float instruction requires the following SPIR-V "                \
992   "extension: SPV_EXT_shader_atomic_float" ExtName
993 static void AddAtomicFloatRequirements(const MachineInstr &MI,
994                                        SPIRV::RequirementHandler &Reqs,
995                                        const SPIRVSubtarget &ST) {
996   assert(MI.getOperand(1).isReg() &&
997          "Expect register operand in atomic float instruction");
998   Register TypeReg = MI.getOperand(1).getReg();
999   SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1000   if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1001     report_fatal_error("Result type of an atomic float instruction must be a "
1002                        "floating-point type scalar");
1003 
1004   unsigned BitWidth = TypeDef->getOperand(1).getImm();
1005   unsigned Op = MI.getOpcode();
1006   if (Op == SPIRV::OpAtomicFAddEXT) {
1007     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1008       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
1009     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1010     switch (BitWidth) {
1011     case 16:
1012       if (!ST.canUseExtension(
1013               SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1014         report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1015       Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1016       Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1017       break;
1018     case 32:
1019       Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1020       break;
1021     case 64:
1022       Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1023       break;
1024     default:
1025       report_fatal_error(
1026           "Unexpected floating-point type width in atomic float instruction");
1027     }
1028   } else {
1029     if (!ST.canUseExtension(
1030             SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1031       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1032     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1033     switch (BitWidth) {
1034     case 16:
1035       Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1036       break;
1037     case 32:
1038       Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1039       break;
1040     case 64:
1041       Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1042       break;
1043     default:
1044       report_fatal_error(
1045           "Unexpected floating-point type width in atomic float instruction");
1046     }
1047   }
1048 }
1049 
1050 bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1051   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1052     return false;
1053   uint32_t Dim = ImageInst->getOperand(2).getImm();
1054   uint32_t Sampled = ImageInst->getOperand(6).getImm();
1055   return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1056 }
1057 
1058 bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1059   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1060     return false;
1061   uint32_t Dim = ImageInst->getOperand(2).getImm();
1062   uint32_t Sampled = ImageInst->getOperand(6).getImm();
1063   return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1064 }
1065 
1066 bool isSampledImage(MachineInstr *ImageInst) {
1067   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1068     return false;
1069   uint32_t Dim = ImageInst->getOperand(2).getImm();
1070   uint32_t Sampled = ImageInst->getOperand(6).getImm();
1071   return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1072 }
1073 
1074 bool isInputAttachment(MachineInstr *ImageInst) {
1075   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1076     return false;
1077   uint32_t Dim = ImageInst->getOperand(2).getImm();
1078   uint32_t Sampled = ImageInst->getOperand(6).getImm();
1079   return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1080 }
1081 
1082 bool isStorageImage(MachineInstr *ImageInst) {
1083   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1084     return false;
1085   uint32_t Dim = ImageInst->getOperand(2).getImm();
1086   uint32_t Sampled = ImageInst->getOperand(6).getImm();
1087   return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1088 }
1089 
1090 bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1091   if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1092     return false;
1093 
1094   const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1095   Register ImageReg = SampledImageInst->getOperand(1).getReg();
1096   auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1097   return isSampledImage(ImageInst);
1098 }
1099 
1100 bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1101   for (const auto &MI : MRI.reg_instructions(Reg)) {
1102     if (MI.getOpcode() != SPIRV::OpDecorate)
1103       continue;
1104 
1105     uint32_t Dec = MI.getOperand(1).getImm();
1106     if (Dec == SPIRV::Decoration::NonUniformEXT)
1107       return true;
1108   }
1109   return false;
1110 }
1111 
1112 void addOpAccessChainReqs(const MachineInstr &Instr,
1113                           SPIRV::RequirementHandler &Handler,
1114                           const SPIRVSubtarget &Subtarget) {
1115   const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1116   // Get the result type. If it is an image type, then the shader uses
1117   // descriptor indexing. The appropriate capabilities will be added based
1118   // on the specifics of the image.
1119   Register ResTypeReg = Instr.getOperand(1).getReg();
1120   MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1121 
1122   assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1123   uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1124   if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1125       StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1126       StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1127     return;
1128   }
1129 
1130   Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1131   MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1132   if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1133       PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1134       PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1135     return;
1136   }
1137 
1138   bool IsNonUniform =
1139       hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1140   if (isUniformTexelBuffer(PointeeType)) {
1141     if (IsNonUniform)
1142       Handler.addRequirements(
1143           SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1144     else
1145       Handler.addRequirements(
1146           SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1147   } else if (isInputAttachment(PointeeType)) {
1148     if (IsNonUniform)
1149       Handler.addRequirements(
1150           SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1151     else
1152       Handler.addRequirements(
1153           SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1154   } else if (isStorageTexelBuffer(PointeeType)) {
1155     if (IsNonUniform)
1156       Handler.addRequirements(
1157           SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1158     else
1159       Handler.addRequirements(
1160           SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1161   } else if (isSampledImage(PointeeType) ||
1162              isCombinedImageSampler(PointeeType) ||
1163              PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1164     if (IsNonUniform)
1165       Handler.addRequirements(
1166           SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1167     else
1168       Handler.addRequirements(
1169           SPIRV::Capability::SampledImageArrayDynamicIndexing);
1170   } else if (isStorageImage(PointeeType)) {
1171     if (IsNonUniform)
1172       Handler.addRequirements(
1173           SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1174     else
1175       Handler.addRequirements(
1176           SPIRV::Capability::StorageImageArrayDynamicIndexing);
1177   }
1178 }
1179 
1180 static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1181   if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1182     return false;
1183   assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1184   return TypeInst->getOperand(7).getImm() == 0;
1185 }
1186 
1187 static void AddDotProductRequirements(const MachineInstr &MI,
1188                                       SPIRV::RequirementHandler &Reqs,
1189                                       const SPIRVSubtarget &ST) {
1190   if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1191     Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1192   Reqs.addCapability(SPIRV::Capability::DotProduct);
1193 
1194   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1195   assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1196   // We do not consider what the previous instruction is. This is just used
1197   // to get the input register and to check the type.
1198   const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1199   assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1200   Register InputReg = Input->getOperand(1).getReg();
1201 
1202   SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
1203   if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1204     assert(TypeDef->getOperand(1).getImm() == 32);
1205     Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1206   } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1207     SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1208     assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1209     if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1210       assert(TypeDef->getOperand(2).getImm() == 4 &&
1211              "Dot operand of 8-bit integer type requires 4 components");
1212       Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1213     } else {
1214       Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1215     }
1216   }
1217 }
1218 
1219 void addInstrRequirements(const MachineInstr &MI,
1220                           SPIRV::RequirementHandler &Reqs,
1221                           const SPIRVSubtarget &ST) {
1222   switch (MI.getOpcode()) {
1223   case SPIRV::OpMemoryModel: {
1224     int64_t Addr = MI.getOperand(0).getImm();
1225     Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1226                                Addr, ST);
1227     int64_t Mem = MI.getOperand(1).getImm();
1228     Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1229                                ST);
1230     break;
1231   }
1232   case SPIRV::OpEntryPoint: {
1233     int64_t Exe = MI.getOperand(0).getImm();
1234     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1235                                Exe, ST);
1236     break;
1237   }
1238   case SPIRV::OpExecutionMode:
1239   case SPIRV::OpExecutionModeId: {
1240     int64_t Exe = MI.getOperand(1).getImm();
1241     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1242                                Exe, ST);
1243     break;
1244   }
1245   case SPIRV::OpTypeMatrix:
1246     Reqs.addCapability(SPIRV::Capability::Matrix);
1247     break;
1248   case SPIRV::OpTypeInt: {
1249     unsigned BitWidth = MI.getOperand(1).getImm();
1250     if (BitWidth == 64)
1251       Reqs.addCapability(SPIRV::Capability::Int64);
1252     else if (BitWidth == 16)
1253       Reqs.addCapability(SPIRV::Capability::Int16);
1254     else if (BitWidth == 8)
1255       Reqs.addCapability(SPIRV::Capability::Int8);
1256     break;
1257   }
1258   case SPIRV::OpTypeFloat: {
1259     unsigned BitWidth = MI.getOperand(1).getImm();
1260     if (BitWidth == 64)
1261       Reqs.addCapability(SPIRV::Capability::Float64);
1262     else if (BitWidth == 16)
1263       Reqs.addCapability(SPIRV::Capability::Float16);
1264     break;
1265   }
1266   case SPIRV::OpTypeVector: {
1267     unsigned NumComponents = MI.getOperand(2).getImm();
1268     if (NumComponents == 8 || NumComponents == 16)
1269       Reqs.addCapability(SPIRV::Capability::Vector16);
1270     break;
1271   }
1272   case SPIRV::OpTypePointer: {
1273     auto SC = MI.getOperand(1).getImm();
1274     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1275                                ST);
1276     // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1277     // capability.
1278     if (ST.isShader())
1279       break;
1280     assert(MI.getOperand(2).isReg());
1281     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1282     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1283     if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1284         TypeDef->getOperand(1).getImm() == 16)
1285       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1286     break;
1287   }
1288   case SPIRV::OpExtInst: {
1289     if (MI.getOperand(2).getImm() ==
1290         static_cast<int64_t>(
1291             SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1292       Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1293     }
1294     break;
1295   }
1296   case SPIRV::OpAliasDomainDeclINTEL:
1297   case SPIRV::OpAliasScopeDeclINTEL:
1298   case SPIRV::OpAliasScopeListDeclINTEL: {
1299     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1300     Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1301     break;
1302   }
1303   case SPIRV::OpBitReverse:
1304   case SPIRV::OpBitFieldInsert:
1305   case SPIRV::OpBitFieldSExtract:
1306   case SPIRV::OpBitFieldUExtract:
1307     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1308       Reqs.addCapability(SPIRV::Capability::Shader);
1309       break;
1310     }
1311     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1312     Reqs.addCapability(SPIRV::Capability::BitInstructions);
1313     break;
1314   case SPIRV::OpTypeRuntimeArray:
1315     Reqs.addCapability(SPIRV::Capability::Shader);
1316     break;
1317   case SPIRV::OpTypeOpaque:
1318   case SPIRV::OpTypeEvent:
1319     Reqs.addCapability(SPIRV::Capability::Kernel);
1320     break;
1321   case SPIRV::OpTypePipe:
1322   case SPIRV::OpTypeReserveId:
1323     Reqs.addCapability(SPIRV::Capability::Pipes);
1324     break;
1325   case SPIRV::OpTypeDeviceEvent:
1326   case SPIRV::OpTypeQueue:
1327   case SPIRV::OpBuildNDRange:
1328     Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1329     break;
1330   case SPIRV::OpDecorate:
1331   case SPIRV::OpDecorateId:
1332   case SPIRV::OpDecorateString:
1333     addOpDecorateReqs(MI, 1, Reqs, ST);
1334     break;
1335   case SPIRV::OpMemberDecorate:
1336   case SPIRV::OpMemberDecorateString:
1337     addOpDecorateReqs(MI, 2, Reqs, ST);
1338     break;
1339   case SPIRV::OpInBoundsPtrAccessChain:
1340     Reqs.addCapability(SPIRV::Capability::Addresses);
1341     break;
1342   case SPIRV::OpConstantSampler:
1343     Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1344     break;
1345   case SPIRV::OpInBoundsAccessChain:
1346   case SPIRV::OpAccessChain:
1347     addOpAccessChainReqs(MI, Reqs, ST);
1348     break;
1349   case SPIRV::OpTypeImage:
1350     addOpTypeImageReqs(MI, Reqs, ST);
1351     break;
1352   case SPIRV::OpTypeSampler:
1353     if (!ST.isShader()) {
1354       Reqs.addCapability(SPIRV::Capability::ImageBasic);
1355     }
1356     break;
1357   case SPIRV::OpTypeForwardPointer:
1358     // TODO: check if it's OpenCL's kernel.
1359     Reqs.addCapability(SPIRV::Capability::Addresses);
1360     break;
1361   case SPIRV::OpAtomicFlagTestAndSet:
1362   case SPIRV::OpAtomicLoad:
1363   case SPIRV::OpAtomicStore:
1364   case SPIRV::OpAtomicExchange:
1365   case SPIRV::OpAtomicCompareExchange:
1366   case SPIRV::OpAtomicIIncrement:
1367   case SPIRV::OpAtomicIDecrement:
1368   case SPIRV::OpAtomicIAdd:
1369   case SPIRV::OpAtomicISub:
1370   case SPIRV::OpAtomicUMin:
1371   case SPIRV::OpAtomicUMax:
1372   case SPIRV::OpAtomicSMin:
1373   case SPIRV::OpAtomicSMax:
1374   case SPIRV::OpAtomicAnd:
1375   case SPIRV::OpAtomicOr:
1376   case SPIRV::OpAtomicXor: {
1377     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1378     const MachineInstr *InstrPtr = &MI;
1379     if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1380       assert(MI.getOperand(3).isReg());
1381       InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1382       assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1383     }
1384     assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1385     Register TypeReg = InstrPtr->getOperand(1).getReg();
1386     SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1387     if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1388       unsigned BitWidth = TypeDef->getOperand(1).getImm();
1389       if (BitWidth == 64)
1390         Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1391     }
1392     break;
1393   }
1394   case SPIRV::OpGroupNonUniformIAdd:
1395   case SPIRV::OpGroupNonUniformFAdd:
1396   case SPIRV::OpGroupNonUniformIMul:
1397   case SPIRV::OpGroupNonUniformFMul:
1398   case SPIRV::OpGroupNonUniformSMin:
1399   case SPIRV::OpGroupNonUniformUMin:
1400   case SPIRV::OpGroupNonUniformFMin:
1401   case SPIRV::OpGroupNonUniformSMax:
1402   case SPIRV::OpGroupNonUniformUMax:
1403   case SPIRV::OpGroupNonUniformFMax:
1404   case SPIRV::OpGroupNonUniformBitwiseAnd:
1405   case SPIRV::OpGroupNonUniformBitwiseOr:
1406   case SPIRV::OpGroupNonUniformBitwiseXor:
1407   case SPIRV::OpGroupNonUniformLogicalAnd:
1408   case SPIRV::OpGroupNonUniformLogicalOr:
1409   case SPIRV::OpGroupNonUniformLogicalXor: {
1410     assert(MI.getOperand(3).isImm());
1411     int64_t GroupOp = MI.getOperand(3).getImm();
1412     switch (GroupOp) {
1413     case SPIRV::GroupOperation::Reduce:
1414     case SPIRV::GroupOperation::InclusiveScan:
1415     case SPIRV::GroupOperation::ExclusiveScan:
1416       Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1417       break;
1418     case SPIRV::GroupOperation::ClusteredReduce:
1419       Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1420       break;
1421     case SPIRV::GroupOperation::PartitionedReduceNV:
1422     case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1423     case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1424       Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1425       break;
1426     }
1427     break;
1428   }
1429   case SPIRV::OpGroupNonUniformShuffle:
1430   case SPIRV::OpGroupNonUniformShuffleXor:
1431     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1432     break;
1433   case SPIRV::OpGroupNonUniformShuffleUp:
1434   case SPIRV::OpGroupNonUniformShuffleDown:
1435     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1436     break;
1437   case SPIRV::OpGroupAll:
1438   case SPIRV::OpGroupAny:
1439   case SPIRV::OpGroupBroadcast:
1440   case SPIRV::OpGroupIAdd:
1441   case SPIRV::OpGroupFAdd:
1442   case SPIRV::OpGroupFMin:
1443   case SPIRV::OpGroupUMin:
1444   case SPIRV::OpGroupSMin:
1445   case SPIRV::OpGroupFMax:
1446   case SPIRV::OpGroupUMax:
1447   case SPIRV::OpGroupSMax:
1448     Reqs.addCapability(SPIRV::Capability::Groups);
1449     break;
1450   case SPIRV::OpGroupNonUniformElect:
1451     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1452     break;
1453   case SPIRV::OpGroupNonUniformAll:
1454   case SPIRV::OpGroupNonUniformAny:
1455   case SPIRV::OpGroupNonUniformAllEqual:
1456     Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1457     break;
1458   case SPIRV::OpGroupNonUniformBroadcast:
1459   case SPIRV::OpGroupNonUniformBroadcastFirst:
1460   case SPIRV::OpGroupNonUniformBallot:
1461   case SPIRV::OpGroupNonUniformInverseBallot:
1462   case SPIRV::OpGroupNonUniformBallotBitExtract:
1463   case SPIRV::OpGroupNonUniformBallotBitCount:
1464   case SPIRV::OpGroupNonUniformBallotFindLSB:
1465   case SPIRV::OpGroupNonUniformBallotFindMSB:
1466     Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1467     break;
1468   case SPIRV::OpSubgroupShuffleINTEL:
1469   case SPIRV::OpSubgroupShuffleDownINTEL:
1470   case SPIRV::OpSubgroupShuffleUpINTEL:
1471   case SPIRV::OpSubgroupShuffleXorINTEL:
1472     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1473       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1474       Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1475     }
1476     break;
1477   case SPIRV::OpSubgroupBlockReadINTEL:
1478   case SPIRV::OpSubgroupBlockWriteINTEL:
1479     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1480       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1481       Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1482     }
1483     break;
1484   case SPIRV::OpSubgroupImageBlockReadINTEL:
1485   case SPIRV::OpSubgroupImageBlockWriteINTEL:
1486     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1487       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1488       Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1489     }
1490     break;
1491   case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1492   case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1493     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1494       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1495       Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1496     }
1497     break;
1498   case SPIRV::OpAssumeTrueKHR:
1499   case SPIRV::OpExpectKHR:
1500     if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1501       Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1502       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1503     }
1504     break;
1505   case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1506   case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1507     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1508       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1509       Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1510     }
1511     break;
1512   case SPIRV::OpConstantFunctionPointerINTEL:
1513     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1514       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1515       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1516     }
1517     break;
1518   case SPIRV::OpGroupNonUniformRotateKHR:
1519     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1520       report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1521                          "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1522                          false);
1523     Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1524     Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1525     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1526     break;
1527   case SPIRV::OpGroupIMulKHR:
1528   case SPIRV::OpGroupFMulKHR:
1529   case SPIRV::OpGroupBitwiseAndKHR:
1530   case SPIRV::OpGroupBitwiseOrKHR:
1531   case SPIRV::OpGroupBitwiseXorKHR:
1532   case SPIRV::OpGroupLogicalAndKHR:
1533   case SPIRV::OpGroupLogicalOrKHR:
1534   case SPIRV::OpGroupLogicalXorKHR:
1535     if (ST.canUseExtension(
1536             SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1537       Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1538       Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1539     }
1540     break;
1541   case SPIRV::OpReadClockKHR:
1542     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1543       report_fatal_error("OpReadClockKHR instruction requires the "
1544                          "following SPIR-V extension: SPV_KHR_shader_clock",
1545                          false);
1546     Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1547     Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1548     break;
1549   case SPIRV::OpFunctionPointerCallINTEL:
1550     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1551       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1552       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1553     }
1554     break;
1555   case SPIRV::OpAtomicFAddEXT:
1556   case SPIRV::OpAtomicFMinEXT:
1557   case SPIRV::OpAtomicFMaxEXT:
1558     AddAtomicFloatRequirements(MI, Reqs, ST);
1559     break;
1560   case SPIRV::OpConvertBF16ToFINTEL:
1561   case SPIRV::OpConvertFToBF16INTEL:
1562     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1563       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1564       Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1565     }
1566     break;
1567   case SPIRV::OpVariableLengthArrayINTEL:
1568   case SPIRV::OpSaveMemoryINTEL:
1569   case SPIRV::OpRestoreMemoryINTEL:
1570     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1571       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1572       Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1573     }
1574     break;
1575   case SPIRV::OpAsmTargetINTEL:
1576   case SPIRV::OpAsmINTEL:
1577   case SPIRV::OpAsmCallINTEL:
1578     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1579       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1580       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1581     }
1582     break;
1583   case SPIRV::OpTypeCooperativeMatrixKHR:
1584     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1585       report_fatal_error(
1586           "OpTypeCooperativeMatrixKHR type requires the "
1587           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1588           false);
1589     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1590     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1591     break;
1592   case SPIRV::OpArithmeticFenceEXT:
1593     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1594       report_fatal_error("OpArithmeticFenceEXT requires the "
1595                          "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1596                          false);
1597     Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1598     Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1599     break;
1600   case SPIRV::OpControlBarrierArriveINTEL:
1601   case SPIRV::OpControlBarrierWaitINTEL:
1602     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1603       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1604       Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1605     }
1606     break;
1607   case SPIRV::OpCooperativeMatrixMulAddKHR: {
1608     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1609       report_fatal_error("Cooperative matrix instructions require the "
1610                          "following SPIR-V extension: "
1611                          "SPV_KHR_cooperative_matrix",
1612                          false);
1613     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1614     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1615     constexpr unsigned MulAddMaxSize = 6;
1616     if (MI.getNumOperands() != MulAddMaxSize)
1617       break;
1618     const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1619     if (CoopOperands &
1620         SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1621       if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1622         report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1623                            "require the following SPIR-V extension: "
1624                            "SPV_INTEL_joint_matrix",
1625                            false);
1626       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1627       Reqs.addCapability(
1628           SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1629     }
1630     if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1631                            MatrixAAndBBFloat16ComponentsINTEL ||
1632         CoopOperands &
1633             SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1634         CoopOperands & SPIRV::CooperativeMatrixOperands::
1635                            MatrixResultBFloat16ComponentsINTEL) {
1636       if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1637         report_fatal_error("***BF16ComponentsINTEL type interpretations "
1638                            "require the following SPIR-V extension: "
1639                            "SPV_INTEL_joint_matrix",
1640                            false);
1641       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1642       Reqs.addCapability(
1643           SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1644     }
1645     break;
1646   }
1647   case SPIRV::OpCooperativeMatrixLoadKHR:
1648   case SPIRV::OpCooperativeMatrixStoreKHR:
1649   case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1650   case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1651   case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1652     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1653       report_fatal_error("Cooperative matrix instructions require the "
1654                          "following SPIR-V extension: "
1655                          "SPV_KHR_cooperative_matrix",
1656                          false);
1657     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1658     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1659 
1660     // Check Layout operand in case if it's not a standard one and add the
1661     // appropriate capability.
1662     std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1663         {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1664         {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1665         {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1666         {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1667         {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1668 
1669     const auto OpCode = MI.getOpcode();
1670     const unsigned LayoutNum = LayoutToInstMap[OpCode];
1671     Register RegLayout = MI.getOperand(LayoutNum).getReg();
1672     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1673     MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1674     if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1675       const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1676       if (LayoutVal ==
1677           static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1678         if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1679           report_fatal_error("PackedINTEL layout require the following SPIR-V "
1680                              "extension: SPV_INTEL_joint_matrix",
1681                              false);
1682         Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1683         Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1684       }
1685     }
1686 
1687     // Nothing to do.
1688     if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1689         OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1690       break;
1691 
1692     std::string InstName;
1693     switch (OpCode) {
1694     case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1695       InstName = "OpCooperativeMatrixPrefetchINTEL";
1696       break;
1697     case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1698       InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1699       break;
1700     case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1701       InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1702       break;
1703     }
1704 
1705     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1706       const std::string ErrorMsg =
1707           InstName + " instruction requires the "
1708                      "following SPIR-V extension: SPV_INTEL_joint_matrix";
1709       report_fatal_error(ErrorMsg.c_str(), false);
1710     }
1711     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1712     if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1713       Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1714       break;
1715     }
1716     Reqs.addCapability(
1717         SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1718     break;
1719   }
1720   case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1721     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1722       report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1723                          "instructions require the following SPIR-V extension: "
1724                          "SPV_INTEL_joint_matrix",
1725                          false);
1726     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1727     Reqs.addCapability(
1728         SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1729     break;
1730   case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1731     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1732       report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1733                          "following SPIR-V extension: SPV_INTEL_joint_matrix",
1734                          false);
1735     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1736     Reqs.addCapability(
1737         SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1738     break;
1739   case SPIRV::OpConvertHandleToImageINTEL:
1740   case SPIRV::OpConvertHandleToSamplerINTEL:
1741   case SPIRV::OpConvertHandleToSampledImageINTEL:
1742     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
1743       report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
1744                          "instructions require the following SPIR-V extension: "
1745                          "SPV_INTEL_bindless_images",
1746                          false);
1747     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
1748     Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
1749     break;
1750   case SPIRV::OpSubgroup2DBlockLoadINTEL:
1751   case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
1752   case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
1753   case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
1754   case SPIRV::OpSubgroup2DBlockStoreINTEL: {
1755     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
1756       report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
1757                          "Prefetch/Store]INTEL instructions require the "
1758                          "following SPIR-V extension: SPV_INTEL_2d_block_io",
1759                          false);
1760     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
1761     Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
1762 
1763     const auto OpCode = MI.getOpcode();
1764     if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
1765       Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
1766       break;
1767     }
1768     if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
1769       Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
1770       break;
1771     }
1772     break;
1773   }
1774   case SPIRV::OpKill: {
1775     Reqs.addCapability(SPIRV::Capability::Shader);
1776   } break;
1777   case SPIRV::OpDemoteToHelperInvocation:
1778     Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
1779 
1780     if (ST.canUseExtension(
1781             SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
1782       if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
1783         Reqs.addExtension(
1784             SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
1785     }
1786     break;
1787   case SPIRV::OpSDot:
1788   case SPIRV::OpUDot:
1789   case SPIRV::OpSUDot:
1790   case SPIRV::OpSDotAccSat:
1791   case SPIRV::OpUDotAccSat:
1792   case SPIRV::OpSUDotAccSat:
1793     AddDotProductRequirements(MI, Reqs, ST);
1794     break;
1795   case SPIRV::OpImageRead: {
1796     Register ImageReg = MI.getOperand(2).getReg();
1797     SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1798         ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1799     // OpImageRead and OpImageWrite can use Unknown Image Formats
1800     // when the Kernel capability is declared. In the OpenCL environment we are
1801     // not allowed to produce
1802     // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
1803     // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
1804 
1805     if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
1806       Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
1807     break;
1808   }
1809   case SPIRV::OpImageWrite: {
1810     Register ImageReg = MI.getOperand(0).getReg();
1811     SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1812         ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1813     // OpImageRead and OpImageWrite can use Unknown Image Formats
1814     // when the Kernel capability is declared. In the OpenCL environment we are
1815     // not allowed to produce
1816     // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
1817     // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
1818 
1819     if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
1820       Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
1821     break;
1822   }
1823   case SPIRV::OpTypeStructContinuedINTEL:
1824   case SPIRV::OpConstantCompositeContinuedINTEL:
1825   case SPIRV::OpSpecConstantCompositeContinuedINTEL:
1826   case SPIRV::OpCompositeConstructContinuedINTEL: {
1827     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
1828       report_fatal_error(
1829           "Continued instructions require the "
1830           "following SPIR-V extension: SPV_INTEL_long_composites",
1831           false);
1832     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
1833     Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
1834     break;
1835   }
1836   case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
1837     if (!ST.canUseExtension(
1838             SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
1839       report_fatal_error(
1840           "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
1841           "following SPIR-V "
1842           "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
1843           false);
1844     Reqs.addExtension(
1845         SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
1846     Reqs.addCapability(
1847         SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
1848     break;
1849   }
1850   case SPIRV::OpBitwiseFunctionINTEL: {
1851     if (!ST.canUseExtension(
1852             SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
1853       report_fatal_error(
1854           "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
1855           "extension: SPV_INTEL_ternary_bitwise_function",
1856           false);
1857     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
1858     Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
1859     break;
1860   }
1861 
1862   default:
1863     break;
1864   }
1865 
1866   // If we require capability Shader, then we can remove the requirement for
1867   // the BitInstructions capability, since Shader is a superset capability
1868   // of BitInstructions.
1869   Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
1870                           SPIRV::Capability::Shader);
1871 }
1872 
1873 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1874                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1875   // Collect requirements for existing instructions.
1876   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1877     MachineFunction *MF = MMI->getMachineFunction(*F);
1878     if (!MF)
1879       continue;
1880     for (const MachineBasicBlock &MBB : *MF)
1881       for (const MachineInstr &MI : MBB)
1882         addInstrRequirements(MI, MAI.Reqs, ST);
1883   }
1884   // Collect requirements for OpExecutionMode instructions.
1885   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1886   if (Node) {
1887     bool RequireFloatControls = false, RequireFloatControls2 = false,
1888          VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
1889     bool HasFloatControls2 =
1890         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
1891     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1892       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
1893       const MDOperand &MDOp = MDN->getOperand(1);
1894       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1895         Constant *C = CMeta->getValue();
1896         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1897           auto EM = Const->getZExtValue();
1898           // SPV_KHR_float_controls is not available until v1.4:
1899           // add SPV_KHR_float_controls if the version is too low
1900           switch (EM) {
1901           case SPIRV::ExecutionMode::DenormPreserve:
1902           case SPIRV::ExecutionMode::DenormFlushToZero:
1903           case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1904           case SPIRV::ExecutionMode::RoundingModeRTE:
1905           case SPIRV::ExecutionMode::RoundingModeRTZ:
1906             RequireFloatControls = VerLower14;
1907             MAI.Reqs.getAndAddRequirements(
1908                 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1909             break;
1910           case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
1911           case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
1912           case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
1913           case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
1914             if (HasFloatControls2) {
1915               RequireFloatControls2 = true;
1916               MAI.Reqs.getAndAddRequirements(
1917                   SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1918             }
1919             break;
1920           default:
1921             MAI.Reqs.getAndAddRequirements(
1922                 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1923           }
1924         }
1925       }
1926     }
1927     if (RequireFloatControls &&
1928         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1929       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1930     if (RequireFloatControls2)
1931       MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
1932   }
1933   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1934     const Function &F = *FI;
1935     if (F.isDeclaration())
1936       continue;
1937     if (F.getMetadata("reqd_work_group_size"))
1938       MAI.Reqs.getAndAddRequirements(
1939           SPIRV::OperandCategory::ExecutionModeOperand,
1940           SPIRV::ExecutionMode::LocalSize, ST);
1941     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1942       MAI.Reqs.getAndAddRequirements(
1943           SPIRV::OperandCategory::ExecutionModeOperand,
1944           SPIRV::ExecutionMode::LocalSize, ST);
1945     }
1946     if (F.getMetadata("work_group_size_hint"))
1947       MAI.Reqs.getAndAddRequirements(
1948           SPIRV::OperandCategory::ExecutionModeOperand,
1949           SPIRV::ExecutionMode::LocalSizeHint, ST);
1950     if (F.getMetadata("intel_reqd_sub_group_size"))
1951       MAI.Reqs.getAndAddRequirements(
1952           SPIRV::OperandCategory::ExecutionModeOperand,
1953           SPIRV::ExecutionMode::SubgroupSize, ST);
1954     if (F.getMetadata("vec_type_hint"))
1955       MAI.Reqs.getAndAddRequirements(
1956           SPIRV::OperandCategory::ExecutionModeOperand,
1957           SPIRV::ExecutionMode::VecTypeHint, ST);
1958 
1959     if (F.hasOptNone()) {
1960       if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1961         MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1962         MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1963       } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
1964         MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
1965         MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
1966       }
1967     }
1968   }
1969 }
1970 
1971 static unsigned getFastMathFlags(const MachineInstr &I) {
1972   unsigned Flags = SPIRV::FPFastMathMode::None;
1973   if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1974     Flags |= SPIRV::FPFastMathMode::NotNaN;
1975   if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1976     Flags |= SPIRV::FPFastMathMode::NotInf;
1977   if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1978     Flags |= SPIRV::FPFastMathMode::NSZ;
1979   if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1980     Flags |= SPIRV::FPFastMathMode::AllowRecip;
1981   if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1982     Flags |= SPIRV::FPFastMathMode::Fast;
1983   return Flags;
1984 }
1985 
1986 static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
1987   if (ST.isKernel())
1988     return true;
1989   if (ST.getSPIRVVersion() < VersionTuple(1, 2))
1990     return false;
1991   return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1992 }
1993 
1994 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1995                                    const SPIRVInstrInfo &TII,
1996                                    SPIRV::RequirementHandler &Reqs) {
1997   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1998       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1999                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2000           .IsSatisfiable) {
2001     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2002                     SPIRV::Decoration::NoSignedWrap, {});
2003   }
2004   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2005       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2006                                      SPIRV::Decoration::NoUnsignedWrap, ST,
2007                                      Reqs)
2008           .IsSatisfiable) {
2009     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2010                     SPIRV::Decoration::NoUnsignedWrap, {});
2011   }
2012   if (!TII.canUseFastMathFlags(I))
2013     return;
2014   unsigned FMFlags = getFastMathFlags(I);
2015   if (FMFlags == SPIRV::FPFastMathMode::None)
2016     return;
2017 
2018   if (isFastMathMathModeAvailable(ST)) {
2019     Register DstReg = I.getOperand(0).getReg();
2020     buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2021                     {FMFlags});
2022   }
2023 }
2024 
2025 // Walk all functions and add decorations related to MI flags.
2026 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2027                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2028                            SPIRV::ModuleAnalysisInfo &MAI) {
2029   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2030     MachineFunction *MF = MMI->getMachineFunction(*F);
2031     if (!MF)
2032       continue;
2033     for (auto &MBB : *MF)
2034       for (auto &MI : MBB)
2035         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
2036   }
2037 }
2038 
2039 static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2040                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2041                         SPIRV::ModuleAnalysisInfo &MAI) {
2042   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2043     MachineFunction *MF = MMI->getMachineFunction(*F);
2044     if (!MF)
2045       continue;
2046     MachineRegisterInfo &MRI = MF->getRegInfo();
2047     for (auto &MBB : *MF) {
2048       if (!MBB.hasName() || MBB.empty())
2049         continue;
2050       // Emit basic block names.
2051       Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
2052       MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2053       buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
2054       MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2055       MAI.setRegisterAlias(MF, Reg, GlobalReg);
2056     }
2057   }
2058 }
2059 
2060 // patching Instruction::PHI to SPIRV::OpPhi
2061 static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2062                       const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2063   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2064     MachineFunction *MF = MMI->getMachineFunction(*F);
2065     if (!MF)
2066       continue;
2067     for (auto &MBB : *MF) {
2068       for (MachineInstr &MI : MBB.phis()) {
2069         MI.setDesc(TII.get(SPIRV::OpPhi));
2070         Register ResTypeReg = GR->getSPIRVTypeID(
2071             GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2072         MI.insert(MI.operands_begin() + 1,
2073                   {MachineOperand::CreateReg(ResTypeReg, false)});
2074       }
2075     }
2076 
2077     MF->getProperties().setNoPHIs();
2078   }
2079 }
2080 
2081 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
2082 
2083 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
2084   AU.addRequired<TargetPassConfig>();
2085   AU.addRequired<MachineModuleInfoWrapperPass>();
2086 }
2087 
2088 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
2089   SPIRVTargetMachine &TM =
2090       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
2091   ST = TM.getSubtargetImpl();
2092   GR = ST->getSPIRVGlobalRegistry();
2093   TII = ST->getInstrInfo();
2094 
2095   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
2096 
2097   setBaseInfo(M);
2098 
2099   patchPhis(M, GR, *TII, MMI);
2100 
2101   addMBBNames(M, *TII, MMI, *ST, MAI);
2102   addDecorations(M, *TII, MMI, *ST, MAI);
2103 
2104   collectReqs(M, MAI, MMI, *ST);
2105 
2106   // Process type/const/global var/func decl instructions, number their
2107   // destination registers from 0 to N, collect Extensions and Capabilities.
2108   collectReqs(M, MAI, MMI, *ST);
2109   collectDeclarations(M);
2110 
2111   // Number rest of registers from N+1 onwards.
2112   numberRegistersGlobally(M);
2113 
2114   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2115   processOtherInstrs(M);
2116 
2117   // If there are no entry points, we need the Linkage capability.
2118   if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2119     MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
2120 
2121   // Set maximum ID used.
2122   GR->setBound(MAI.MaxID);
2123 
2124   return false;
2125 }
2126