1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/CodeGen/MachineModuleInfo.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27
28 using namespace llvm;
29
30 #define DEBUG_TYPE "spirv-module-analysis"
31
32 static cl::opt<bool>
33 SPVDumpDeps("spv-dump-deps",
34 cl::desc("Dump MIR with SPIR-V dependencies info"),
35 cl::Optional, cl::init(false));
36
37 static cl::list<SPIRV::Capability::Capability>
38 AvoidCapabilities("avoid-spirv-capabilities",
39 cl::desc("SPIR-V capabilities to avoid if there are "
40 "other options enabling a feature"),
41 cl::ZeroOrMore, cl::Hidden,
42 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
43 "SPIR-V Shader capability")));
44 // Use sets instead of cl::list to check "if contains" condition
45 struct AvoidCapabilitiesSet {
46 SmallSet<SPIRV::Capability::Capability, 4> S;
AvoidCapabilitiesSetAvoidCapabilitiesSet47 AvoidCapabilitiesSet() { S.insert_range(AvoidCapabilities); }
48 };
49
50 char llvm::SPIRVModuleAnalysis::ID = 0;
51
52 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
53 true)
54
55 // Retrieve an unsigned from an MDNode with a list of them as operands.
getMetadataUInt(MDNode * MdNode,unsigned OpIndex,unsigned DefaultVal=0)56 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
57 unsigned DefaultVal = 0) {
58 if (MdNode && OpIndex < MdNode->getNumOperands()) {
59 const auto &Op = MdNode->getOperand(OpIndex);
60 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
61 }
62 return DefaultVal;
63 }
64
65 static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,unsigned i,const SPIRVSubtarget & ST,SPIRV::RequirementHandler & Reqs)66 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
67 unsigned i, const SPIRVSubtarget &ST,
68 SPIRV::RequirementHandler &Reqs) {
69 // A set of capabilities to avoid if there is another option.
70 AvoidCapabilitiesSet AvoidCaps;
71 if (!ST.isShader())
72 AvoidCaps.S.insert(SPIRV::Capability::Shader);
73 else
74 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
75
76 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
77 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
78 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
79 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
80 bool MaxVerOK =
81 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
82 CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
83 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
84 if (ReqCaps.empty()) {
85 if (ReqExts.empty()) {
86 if (MinVerOK && MaxVerOK)
87 return {true, {}, {}, ReqMinVer, ReqMaxVer};
88 return {false, {}, {}, VersionTuple(), VersionTuple()};
89 }
90 } else if (MinVerOK && MaxVerOK) {
91 if (ReqCaps.size() == 1) {
92 auto Cap = ReqCaps[0];
93 if (Reqs.isCapabilityAvailable(Cap)) {
94 ReqExts.append(getSymbolicOperandExtensions(
95 SPIRV::OperandCategory::CapabilityOperand, Cap));
96 return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
97 }
98 } else {
99 // By SPIR-V specification: "If an instruction, enumerant, or other
100 // feature specifies multiple enabling capabilities, only one such
101 // capability needs to be declared to use the feature." However, one
102 // capability may be preferred over another. We use command line
103 // argument(s) and AvoidCapabilities to avoid selection of certain
104 // capabilities if there are other options.
105 CapabilityList UseCaps;
106 for (auto Cap : ReqCaps)
107 if (Reqs.isCapabilityAvailable(Cap))
108 UseCaps.push_back(Cap);
109 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110 auto Cap = UseCaps[i];
111 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
112 ReqExts.append(getSymbolicOperandExtensions(
113 SPIRV::OperandCategory::CapabilityOperand, Cap));
114 return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
115 }
116 }
117 }
118 }
119 // If there are no capabilities, or we can't satisfy the version or
120 // capability requirements, use the list of extensions (if the subtarget
121 // can handle them all).
122 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
123 return ST.canUseExtension(Ext);
124 })) {
125 return {true,
126 {},
127 ReqExts,
128 VersionTuple(),
129 VersionTuple()}; // TODO: add versions to extensions.
130 }
131 return {false, {}, {}, VersionTuple(), VersionTuple()};
132 }
133
setBaseInfo(const Module & M)134 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
135 MAI.MaxID = 0;
136 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
137 MAI.MS[i].clear();
138 MAI.RegisterAliasTable.clear();
139 MAI.InstrsToDelete.clear();
140 MAI.FuncMap.clear();
141 MAI.GlobalVarList.clear();
142 MAI.ExtInstSetMap.clear();
143 MAI.Reqs.clear();
144 MAI.Reqs.initAvailableCapabilities(*ST);
145
146 // TODO: determine memory model and source language from the configuratoin.
147 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
148 auto MemMD = MemModel->getOperand(0);
149 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
150 getMetadataUInt(MemMD, 0));
151 MAI.Mem =
152 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
153 } else {
154 // TODO: Add support for VulkanMemoryModel.
155 MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
156 : SPIRV::MemoryModel::OpenCL;
157 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
158 unsigned PtrSize = ST->getPointerSize();
159 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
160 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
161 : SPIRV::AddressingModel::Logical;
162 } else {
163 // TODO: Add support for PhysicalStorageBufferAddress.
164 MAI.Addr = SPIRV::AddressingModel::Logical;
165 }
166 }
167 // Get the OpenCL version number from metadata.
168 // TODO: support other source languages.
169 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
170 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
171 // Construct version literal in accordance with SPIRV-LLVM-Translator.
172 // TODO: support multiple OCL version metadata.
173 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
174 auto VersionMD = VerNode->getOperand(0);
175 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
176 unsigned MinorNum = getMetadataUInt(VersionMD, 1);
177 unsigned RevNum = getMetadataUInt(VersionMD, 2);
178 // Prevent Major part of OpenCL version to be 0
179 MAI.SrcLangVersion =
180 (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
181 } else {
182 // If there is no information about OpenCL version we are forced to generate
183 // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
184 // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
185 // Translator avoids potential issues with run-times in a similar manner.
186 if (!ST->isShader()) {
187 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
188 MAI.SrcLangVersion = 100000;
189 } else {
190 MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
191 MAI.SrcLangVersion = 0;
192 }
193 }
194
195 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
196 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
197 MDNode *MD = ExtNode->getOperand(I);
198 if (!MD || MD->getNumOperands() == 0)
199 continue;
200 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
201 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
202 }
203 }
204
205 // Update required capabilities for this memory model, addressing model and
206 // source language.
207 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
208 MAI.Mem, *ST);
209 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
210 MAI.SrcLang, *ST);
211 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
212 MAI.Addr, *ST);
213
214 if (!ST->isShader()) {
215 // TODO: check if it's required by default.
216 MAI.ExtInstSetMap[static_cast<unsigned>(
217 SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
218 }
219 }
220
221 // Appends the signature of the decoration instructions that decorate R to
222 // Signature.
appendDecorationsForReg(const MachineRegisterInfo & MRI,Register R,InstrSignature & Signature)223 static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
224 InstrSignature &Signature) {
225 for (MachineInstr &UseMI : MRI.use_instructions(R)) {
226 // We don't handle OpDecorateId because getting the register alias for the
227 // ID can cause problems, and we do not need it for now.
228 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
229 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
230 continue;
231
232 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
233 const MachineOperand &MO = UseMI.getOperand(I);
234 if (MO.isReg())
235 continue;
236 Signature.push_back(hash_value(MO));
237 }
238 }
239 }
240
241 // Returns a representation of an instruction as a vector of MachineOperand
242 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
243 // This creates a signature of the instruction with the same content
244 // that MachineOperand::isIdenticalTo uses for comparison.
instrToSignature(const MachineInstr & MI,SPIRV::ModuleAnalysisInfo & MAI,bool UseDefReg)245 static InstrSignature instrToSignature(const MachineInstr &MI,
246 SPIRV::ModuleAnalysisInfo &MAI,
247 bool UseDefReg) {
248 Register DefReg;
249 InstrSignature Signature{MI.getOpcode()};
250 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
251 const MachineOperand &MO = MI.getOperand(i);
252 size_t h;
253 if (MO.isReg()) {
254 if (!UseDefReg && MO.isDef()) {
255 assert(!DefReg.isValid() && "Multiple def registers.");
256 DefReg = MO.getReg();
257 continue;
258 }
259 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
260 if (!RegAlias.isValid()) {
261 LLVM_DEBUG({
262 dbgs() << "Unexpectedly, no global id found for the operand ";
263 MO.print(dbgs());
264 dbgs() << "\nInstruction: ";
265 MI.print(dbgs());
266 dbgs() << "\n";
267 });
268 report_fatal_error("All v-regs must have been mapped to global id's");
269 }
270 // mimic llvm::hash_value(const MachineOperand &MO)
271 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
272 MO.isDef());
273 } else {
274 h = hash_value(MO);
275 }
276 Signature.push_back(h);
277 }
278
279 if (DefReg.isValid()) {
280 // Decorations change the semantics of the current instruction. So two
281 // identical instruction with different decorations cannot be merged. That
282 // is why we add the decorations to the signature.
283 appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
284 }
285 return Signature;
286 }
287
isDeclSection(const MachineRegisterInfo & MRI,const MachineInstr & MI)288 bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
289 const MachineInstr &MI) {
290 unsigned Opcode = MI.getOpcode();
291 switch (Opcode) {
292 case SPIRV::OpTypeForwardPointer:
293 // omit now, collect later
294 return false;
295 case SPIRV::OpVariable:
296 return static_cast<SPIRV::StorageClass::StorageClass>(
297 MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
298 case SPIRV::OpFunction:
299 case SPIRV::OpFunctionParameter:
300 return true;
301 }
302 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
303 Register DefReg = MI.getOperand(0).getReg();
304 for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
305 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
306 continue;
307 // it's a dummy definition, FP constant refers to a function,
308 // and this is resolved in another way; let's skip this definition
309 assert(UseMI.getOperand(2).isReg() &&
310 UseMI.getOperand(2).getReg() == DefReg);
311 MAI.setSkipEmission(&MI);
312 return false;
313 }
314 }
315 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
316 TII->isInlineAsmDefInstr(MI);
317 }
318
319 // This is a special case of a function pointer refering to a possibly
320 // forward function declaration. The operand is a dummy OpUndef that
321 // requires a special treatment.
visitFunPtrUse(Register OpReg,InstrGRegsMap & SignatureToGReg,std::map<const Value *,unsigned> & GlobalToGReg,const MachineFunction * MF,const MachineInstr & MI)322 void SPIRVModuleAnalysis::visitFunPtrUse(
323 Register OpReg, InstrGRegsMap &SignatureToGReg,
324 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
325 const MachineInstr &MI) {
326 const MachineOperand *OpFunDef =
327 GR->getFunctionDefinitionByUse(&MI.getOperand(2));
328 assert(OpFunDef && OpFunDef->isReg());
329 // find the actual function definition and number it globally in advance
330 const MachineInstr *OpDefMI = OpFunDef->getParent();
331 assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
332 const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
333 const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
334 do {
335 visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
336 OpDefMI = OpDefMI->getNextNode();
337 } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
338 OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
339 // associate the function pointer with the newly assigned global number
340 MCRegister GlobalFunDefReg =
341 MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
342 assert(GlobalFunDefReg.isValid() &&
343 "Function definition must refer to a global register");
344 MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
345 }
346
347 // Depth first recursive traversal of dependencies. Repeated visits are guarded
348 // by MAI.hasRegisterAlias().
visitDecl(const MachineRegisterInfo & MRI,InstrGRegsMap & SignatureToGReg,std::map<const Value *,unsigned> & GlobalToGReg,const MachineFunction * MF,const MachineInstr & MI)349 void SPIRVModuleAnalysis::visitDecl(
350 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
351 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
352 const MachineInstr &MI) {
353 unsigned Opcode = MI.getOpcode();
354
355 // Process each operand of the instruction to resolve dependencies
356 for (const MachineOperand &MO : MI.operands()) {
357 if (!MO.isReg() || MO.isDef())
358 continue;
359 Register OpReg = MO.getReg();
360 // Handle function pointers special case
361 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
362 MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
363 visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
364 continue;
365 }
366 // Skip already processed instructions
367 if (MAI.hasRegisterAlias(MF, MO.getReg()))
368 continue;
369 // Recursively visit dependencies
370 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
371 if (isDeclSection(MRI, *OpDefMI))
372 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
373 continue;
374 }
375 // Handle the unexpected case of no unique definition for the SPIR-V
376 // instruction
377 LLVM_DEBUG({
378 dbgs() << "Unexpectedly, no unique definition for the operand ";
379 MO.print(dbgs());
380 dbgs() << "\nInstruction: ";
381 MI.print(dbgs());
382 dbgs() << "\n";
383 });
384 report_fatal_error(
385 "No unique definition is found for the virtual register");
386 }
387
388 MCRegister GReg;
389 bool IsFunDef = false;
390 if (TII->isSpecConstantInstr(MI)) {
391 GReg = MAI.getNextIDRegister();
392 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
393 } else if (Opcode == SPIRV::OpFunction ||
394 Opcode == SPIRV::OpFunctionParameter) {
395 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
396 } else if (Opcode == SPIRV::OpTypeStruct ||
397 Opcode == SPIRV::OpConstantComposite) {
398 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
399 const MachineInstr *NextInstr = MI.getNextNode();
400 while (NextInstr &&
401 ((Opcode == SPIRV::OpTypeStruct &&
402 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
403 (Opcode == SPIRV::OpConstantComposite &&
404 NextInstr->getOpcode() ==
405 SPIRV::OpConstantCompositeContinuedINTEL))) {
406 MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
407 MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
408 MAI.setSkipEmission(NextInstr);
409 NextInstr = NextInstr->getNextNode();
410 }
411 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
412 TII->isInlineAsmDefInstr(MI)) {
413 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
414 } else if (Opcode == SPIRV::OpVariable) {
415 GReg = handleVariable(MF, MI, GlobalToGReg);
416 } else {
417 LLVM_DEBUG({
418 dbgs() << "\nInstruction: ";
419 MI.print(dbgs());
420 dbgs() << "\n";
421 });
422 llvm_unreachable("Unexpected instruction is visited");
423 }
424 MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
425 if (!IsFunDef)
426 MAI.setSkipEmission(&MI);
427 }
428
handleFunctionOrParameter(const MachineFunction * MF,const MachineInstr & MI,std::map<const Value *,unsigned> & GlobalToGReg,bool & IsFunDef)429 MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
430 const MachineFunction *MF, const MachineInstr &MI,
431 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
432 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
433 assert(GObj && "Unregistered global definition");
434 const Function *F = dyn_cast<Function>(GObj);
435 if (!F)
436 F = dyn_cast<Argument>(GObj)->getParent();
437 assert(F && "Expected a reference to a function or an argument");
438 IsFunDef = !F->isDeclaration();
439 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
440 if (!Inserted)
441 return It->second;
442 MCRegister GReg = MAI.getNextIDRegister();
443 It->second = GReg;
444 if (!IsFunDef)
445 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
446 return GReg;
447 }
448
449 MCRegister
handleTypeDeclOrConstant(const MachineInstr & MI,InstrGRegsMap & SignatureToGReg)450 SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
451 InstrGRegsMap &SignatureToGReg) {
452 InstrSignature MISign = instrToSignature(MI, MAI, false);
453 auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
454 if (!Inserted)
455 return It->second;
456 MCRegister GReg = MAI.getNextIDRegister();
457 It->second = GReg;
458 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
459 return GReg;
460 }
461
handleVariable(const MachineFunction * MF,const MachineInstr & MI,std::map<const Value *,unsigned> & GlobalToGReg)462 MCRegister SPIRVModuleAnalysis::handleVariable(
463 const MachineFunction *MF, const MachineInstr &MI,
464 std::map<const Value *, unsigned> &GlobalToGReg) {
465 MAI.GlobalVarList.push_back(&MI);
466 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
467 assert(GObj && "Unregistered global definition");
468 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
469 if (!Inserted)
470 return It->second;
471 MCRegister GReg = MAI.getNextIDRegister();
472 It->second = GReg;
473 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
474 return GReg;
475 }
476
collectDeclarations(const Module & M)477 void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
478 InstrGRegsMap SignatureToGReg;
479 std::map<const Value *, unsigned> GlobalToGReg;
480 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
481 MachineFunction *MF = MMI->getMachineFunction(*F);
482 if (!MF)
483 continue;
484 const MachineRegisterInfo &MRI = MF->getRegInfo();
485 unsigned PastHeader = 0;
486 for (MachineBasicBlock &MBB : *MF) {
487 for (MachineInstr &MI : MBB) {
488 if (MI.getNumOperands() == 0)
489 continue;
490 unsigned Opcode = MI.getOpcode();
491 if (Opcode == SPIRV::OpFunction) {
492 if (PastHeader == 0) {
493 PastHeader = 1;
494 continue;
495 }
496 } else if (Opcode == SPIRV::OpFunctionParameter) {
497 if (PastHeader < 2)
498 continue;
499 } else if (PastHeader > 0) {
500 PastHeader = 2;
501 }
502
503 const MachineOperand &DefMO = MI.getOperand(0);
504 switch (Opcode) {
505 case SPIRV::OpExtension:
506 MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
507 MAI.setSkipEmission(&MI);
508 break;
509 case SPIRV::OpCapability:
510 MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
511 MAI.setSkipEmission(&MI);
512 if (PastHeader > 0)
513 PastHeader = 2;
514 break;
515 default:
516 if (DefMO.isReg() && isDeclSection(MRI, MI) &&
517 !MAI.hasRegisterAlias(MF, DefMO.getReg()))
518 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
519 }
520 }
521 }
522 }
523 }
524
525 // Look for IDs declared with Import linkage, and map the corresponding function
526 // to the register defining that variable (which will usually be the result of
527 // an OpFunction). This lets us call externally imported functions using
528 // the correct ID registers.
collectFuncNames(MachineInstr & MI,const Function * F)529 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
530 const Function *F) {
531 if (MI.getOpcode() == SPIRV::OpDecorate) {
532 // If it's got Import linkage.
533 auto Dec = MI.getOperand(1).getImm();
534 if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
535 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
536 if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
537 // Map imported function name to function ID register.
538 const Function *ImportedFunc =
539 F->getParent()->getFunction(getStringImm(MI, 2));
540 Register Target = MI.getOperand(0).getReg();
541 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
542 }
543 }
544 } else if (MI.getOpcode() == SPIRV::OpFunction) {
545 // Record all internal OpFunction declarations.
546 Register Reg = MI.defs().begin()->getReg();
547 MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
548 assert(GlobalReg.isValid());
549 MAI.FuncMap[F] = GlobalReg;
550 }
551 }
552
553 // Collect the given instruction in the specified MS. We assume global register
554 // numbering has already occurred by this point. We can directly compare reg
555 // arguments when detecting duplicates.
collectOtherInstr(MachineInstr & MI,SPIRV::ModuleAnalysisInfo & MAI,SPIRV::ModuleSectionType MSType,InstrTraces & IS,bool Append=true)556 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
557 SPIRV::ModuleSectionType MSType, InstrTraces &IS,
558 bool Append = true) {
559 MAI.setSkipEmission(&MI);
560 InstrSignature MISign = instrToSignature(MI, MAI, true);
561 auto FoundMI = IS.insert(MISign);
562 if (!FoundMI.second)
563 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
564 // No duplicates, so add it.
565 if (Append)
566 MAI.MS[MSType].push_back(&MI);
567 else
568 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
569 }
570
571 // Some global instructions make reference to function-local ID regs, so cannot
572 // be correctly collected until these registers are globally numbered.
processOtherInstrs(const Module & M)573 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
574 InstrTraces IS;
575 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
576 if ((*F).isDeclaration())
577 continue;
578 MachineFunction *MF = MMI->getMachineFunction(*F);
579 assert(MF);
580
581 for (MachineBasicBlock &MBB : *MF)
582 for (MachineInstr &MI : MBB) {
583 if (MAI.getSkipEmission(&MI))
584 continue;
585 const unsigned OpCode = MI.getOpcode();
586 if (OpCode == SPIRV::OpString) {
587 collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
588 } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
589 MI.getOperand(2).getImm() ==
590 SPIRV::InstructionSet::
591 NonSemantic_Shader_DebugInfo_100) {
592 MachineOperand Ins = MI.getOperand(3);
593 namespace NS = SPIRV::NonSemanticExtInst;
594 static constexpr int64_t GlobalNonSemanticDITy[] = {
595 NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
596 NS::DebugTypeBasic, NS::DebugTypePointer};
597 bool IsGlobalDI = false;
598 for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
599 IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
600 if (IsGlobalDI)
601 collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
602 } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
603 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
604 } else if (OpCode == SPIRV::OpEntryPoint) {
605 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
606 } else if (TII->isAliasingInstr(MI)) {
607 collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
608 } else if (TII->isDecorationInstr(MI)) {
609 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
610 collectFuncNames(MI, &*F);
611 } else if (TII->isConstantInstr(MI)) {
612 // Now OpSpecConstant*s are not in DT,
613 // but they need to be collected anyway.
614 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
615 } else if (OpCode == SPIRV::OpFunction) {
616 collectFuncNames(MI, &*F);
617 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
618 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
619 }
620 }
621 }
622 }
623
624 // Number registers in all functions globally from 0 onwards and store
625 // the result in global register alias table. Some registers are already
626 // numbered.
numberRegistersGlobally(const Module & M)627 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
628 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
629 if ((*F).isDeclaration())
630 continue;
631 MachineFunction *MF = MMI->getMachineFunction(*F);
632 assert(MF);
633 for (MachineBasicBlock &MBB : *MF) {
634 for (MachineInstr &MI : MBB) {
635 for (MachineOperand &Op : MI.operands()) {
636 if (!Op.isReg())
637 continue;
638 Register Reg = Op.getReg();
639 if (MAI.hasRegisterAlias(MF, Reg))
640 continue;
641 MCRegister NewReg = MAI.getNextIDRegister();
642 MAI.setRegisterAlias(MF, Reg, NewReg);
643 }
644 if (MI.getOpcode() != SPIRV::OpExtInst)
645 continue;
646 auto Set = MI.getOperand(2).getImm();
647 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
648 if (Inserted)
649 It->second = MAI.getNextIDRegister();
650 }
651 }
652 }
653 }
654
655 // RequirementHandler implementations.
getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category,uint32_t i,const SPIRVSubtarget & ST)656 void SPIRV::RequirementHandler::getAndAddRequirements(
657 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
658 const SPIRVSubtarget &ST) {
659 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
660 }
661
recursiveAddCapabilities(const CapabilityList & ToPrune)662 void SPIRV::RequirementHandler::recursiveAddCapabilities(
663 const CapabilityList &ToPrune) {
664 for (const auto &Cap : ToPrune) {
665 AllCaps.insert(Cap);
666 CapabilityList ImplicitDecls =
667 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
668 recursiveAddCapabilities(ImplicitDecls);
669 }
670 }
671
addCapabilities(const CapabilityList & ToAdd)672 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
673 for (const auto &Cap : ToAdd) {
674 bool IsNewlyInserted = AllCaps.insert(Cap).second;
675 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
676 continue;
677 CapabilityList ImplicitDecls =
678 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
679 recursiveAddCapabilities(ImplicitDecls);
680 MinimalCaps.push_back(Cap);
681 }
682 }
683
addRequirements(const SPIRV::Requirements & Req)684 void SPIRV::RequirementHandler::addRequirements(
685 const SPIRV::Requirements &Req) {
686 if (!Req.IsSatisfiable)
687 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
688
689 if (Req.Cap.has_value())
690 addCapabilities({Req.Cap.value()});
691
692 addExtensions(Req.Exts);
693
694 if (!Req.MinVer.empty()) {
695 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
696 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
697 << " and <= " << MaxVersion << "\n");
698 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
699 }
700
701 if (MinVersion.empty() || Req.MinVer > MinVersion)
702 MinVersion = Req.MinVer;
703 }
704
705 if (!Req.MaxVer.empty()) {
706 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
707 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
708 << " and >= " << MinVersion << "\n");
709 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
710 }
711
712 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
713 MaxVersion = Req.MaxVer;
714 }
715 }
716
checkSatisfiable(const SPIRVSubtarget & ST) const717 void SPIRV::RequirementHandler::checkSatisfiable(
718 const SPIRVSubtarget &ST) const {
719 // Report as many errors as possible before aborting the compilation.
720 bool IsSatisfiable = true;
721 auto TargetVer = ST.getSPIRVVersion();
722
723 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
724 LLVM_DEBUG(
725 dbgs() << "Target SPIR-V version too high for required features\n"
726 << "Required max version: " << MaxVersion << " target version "
727 << TargetVer << "\n");
728 IsSatisfiable = false;
729 }
730
731 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
732 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
733 << "Required min version: " << MinVersion
734 << " target version " << TargetVer << "\n");
735 IsSatisfiable = false;
736 }
737
738 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
739 LLVM_DEBUG(
740 dbgs()
741 << "Version is too low for some features and too high for others.\n"
742 << "Required SPIR-V min version: " << MinVersion
743 << " required SPIR-V max version " << MaxVersion << "\n");
744 IsSatisfiable = false;
745 }
746
747 for (auto Cap : MinimalCaps) {
748 if (AvailableCaps.contains(Cap))
749 continue;
750 LLVM_DEBUG(dbgs() << "Capability not supported: "
751 << getSymbolicOperandMnemonic(
752 OperandCategory::CapabilityOperand, Cap)
753 << "\n");
754 IsSatisfiable = false;
755 }
756
757 for (auto Ext : AllExtensions) {
758 if (ST.canUseExtension(Ext))
759 continue;
760 LLVM_DEBUG(dbgs() << "Extension not supported: "
761 << getSymbolicOperandMnemonic(
762 OperandCategory::ExtensionOperand, Ext)
763 << "\n");
764 IsSatisfiable = false;
765 }
766
767 if (!IsSatisfiable)
768 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
769 }
770
771 // Add the given capabilities and all their implicitly defined capabilities too.
addAvailableCaps(const CapabilityList & ToAdd)772 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
773 for (const auto Cap : ToAdd)
774 if (AvailableCaps.insert(Cap).second)
775 addAvailableCaps(getSymbolicOperandCapabilities(
776 SPIRV::OperandCategory::CapabilityOperand, Cap));
777 }
778
removeCapabilityIf(const Capability::Capability ToRemove,const Capability::Capability IfPresent)779 void SPIRV::RequirementHandler::removeCapabilityIf(
780 const Capability::Capability ToRemove,
781 const Capability::Capability IfPresent) {
782 if (AllCaps.contains(IfPresent))
783 AllCaps.erase(ToRemove);
784 }
785
786 namespace llvm {
787 namespace SPIRV {
initAvailableCapabilities(const SPIRVSubtarget & ST)788 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
789 // Provided by both all supported Vulkan versions and OpenCl.
790 addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
791 Capability::Int16});
792
793 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
794 addAvailableCaps({Capability::GroupNonUniform,
795 Capability::GroupNonUniformVote,
796 Capability::GroupNonUniformArithmetic,
797 Capability::GroupNonUniformBallot,
798 Capability::GroupNonUniformClustered,
799 Capability::GroupNonUniformShuffle,
800 Capability::GroupNonUniformShuffleRelative});
801
802 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
803 addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
804 Capability::DotProductInput4x8Bit,
805 Capability::DotProductInput4x8BitPacked,
806 Capability::DemoteToHelperInvocation});
807
808 // Add capabilities enabled by extensions.
809 for (auto Extension : ST.getAllAvailableExtensions()) {
810 CapabilityList EnabledCapabilities =
811 getCapabilitiesEnabledByExtension(Extension);
812 addAvailableCaps(EnabledCapabilities);
813 }
814
815 if (!ST.isShader()) {
816 initAvailableCapabilitiesForOpenCL(ST);
817 return;
818 }
819
820 if (ST.isShader()) {
821 initAvailableCapabilitiesForVulkan(ST);
822 return;
823 }
824
825 report_fatal_error("Unimplemented environment for SPIR-V generation.");
826 }
827
initAvailableCapabilitiesForOpenCL(const SPIRVSubtarget & ST)828 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
829 const SPIRVSubtarget &ST) {
830 // Add the min requirements for different OpenCL and SPIR-V versions.
831 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
832 Capability::Kernel, Capability::Vector16,
833 Capability::Groups, Capability::GenericPointer,
834 Capability::StorageImageWriteWithoutFormat,
835 Capability::StorageImageReadWithoutFormat});
836 if (ST.hasOpenCLFullProfile())
837 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
838 if (ST.hasOpenCLImageSupport()) {
839 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
840 Capability::Image1D, Capability::SampledBuffer,
841 Capability::ImageBuffer});
842 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
843 addAvailableCaps({Capability::ImageReadWrite});
844 }
845 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
846 ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
847 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
848 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
849 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
850 Capability::SignedZeroInfNanPreserve,
851 Capability::RoundingModeRTE,
852 Capability::RoundingModeRTZ});
853 // TODO: verify if this needs some checks.
854 addAvailableCaps({Capability::Float16, Capability::Float64});
855
856 // TODO: add OpenCL extensions.
857 }
858
initAvailableCapabilitiesForVulkan(const SPIRVSubtarget & ST)859 void RequirementHandler::initAvailableCapabilitiesForVulkan(
860 const SPIRVSubtarget &ST) {
861
862 // Core in Vulkan 1.1 and earlier.
863 addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
864 Capability::GroupNonUniform, Capability::Image1D,
865 Capability::SampledBuffer, Capability::ImageBuffer,
866 Capability::UniformBufferArrayDynamicIndexing,
867 Capability::SampledImageArrayDynamicIndexing,
868 Capability::StorageBufferArrayDynamicIndexing,
869 Capability::StorageImageArrayDynamicIndexing});
870
871 // Became core in Vulkan 1.2
872 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
873 addAvailableCaps(
874 {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
875 Capability::InputAttachmentArrayDynamicIndexingEXT,
876 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
877 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
878 Capability::UniformBufferArrayNonUniformIndexingEXT,
879 Capability::SampledImageArrayNonUniformIndexingEXT,
880 Capability::StorageBufferArrayNonUniformIndexingEXT,
881 Capability::StorageImageArrayNonUniformIndexingEXT,
882 Capability::InputAttachmentArrayNonUniformIndexingEXT,
883 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
884 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
885 }
886
887 // Became core in Vulkan 1.3
888 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
889 addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
890 Capability::StorageImageReadWithoutFormat});
891 }
892
893 } // namespace SPIRV
894 } // namespace llvm
895
896 // Add the required capabilities from a decoration instruction (including
897 // BuiltIns).
addOpDecorateReqs(const MachineInstr & MI,unsigned DecIndex,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)898 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
899 SPIRV::RequirementHandler &Reqs,
900 const SPIRVSubtarget &ST) {
901 int64_t DecOp = MI.getOperand(DecIndex).getImm();
902 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
903 Reqs.addRequirements(getSymbolicOperandRequirements(
904 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
905
906 if (Dec == SPIRV::Decoration::BuiltIn) {
907 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
908 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
909 Reqs.addRequirements(getSymbolicOperandRequirements(
910 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
911 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
912 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
913 SPIRV::LinkageType::LinkageType LnkType =
914 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
915 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
916 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
917 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
918 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
919 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
920 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
921 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
922 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
923 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
924 Reqs.addExtension(
925 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
926 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
927 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
928 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
929 Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
930 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
931 }
932 }
933
934 // Add requirements for image handling.
addOpTypeImageReqs(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)935 static void addOpTypeImageReqs(const MachineInstr &MI,
936 SPIRV::RequirementHandler &Reqs,
937 const SPIRVSubtarget &ST) {
938 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
939 // The operand indices used here are based on the OpTypeImage layout, which
940 // the MachineInstr follows as well.
941 int64_t ImgFormatOp = MI.getOperand(7).getImm();
942 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
943 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
944 ImgFormat, ST);
945
946 bool IsArrayed = MI.getOperand(4).getImm() == 1;
947 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
948 bool NoSampler = MI.getOperand(6).getImm() == 2;
949 // Add dimension requirements.
950 assert(MI.getOperand(2).isImm());
951 switch (MI.getOperand(2).getImm()) {
952 case SPIRV::Dim::DIM_1D:
953 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
954 : SPIRV::Capability::Sampled1D);
955 break;
956 case SPIRV::Dim::DIM_2D:
957 if (IsMultisampled && NoSampler)
958 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
959 break;
960 case SPIRV::Dim::DIM_Cube:
961 Reqs.addRequirements(SPIRV::Capability::Shader);
962 if (IsArrayed)
963 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
964 : SPIRV::Capability::SampledCubeArray);
965 break;
966 case SPIRV::Dim::DIM_Rect:
967 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
968 : SPIRV::Capability::SampledRect);
969 break;
970 case SPIRV::Dim::DIM_Buffer:
971 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
972 : SPIRV::Capability::SampledBuffer);
973 break;
974 case SPIRV::Dim::DIM_SubpassData:
975 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
976 break;
977 }
978
979 // Has optional access qualifier.
980 if (!ST.isShader()) {
981 if (MI.getNumOperands() > 8 &&
982 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
983 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
984 else
985 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
986 }
987 }
988
989 // Add requirements for handling atomic float instructions
990 #define ATOM_FLT_REQ_EXT_MSG(ExtName) \
991 "The atomic float instruction requires the following SPIR-V " \
992 "extension: SPV_EXT_shader_atomic_float" ExtName
AddAtomicFloatRequirements(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)993 static void AddAtomicFloatRequirements(const MachineInstr &MI,
994 SPIRV::RequirementHandler &Reqs,
995 const SPIRVSubtarget &ST) {
996 assert(MI.getOperand(1).isReg() &&
997 "Expect register operand in atomic float instruction");
998 Register TypeReg = MI.getOperand(1).getReg();
999 SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1000 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1001 report_fatal_error("Result type of an atomic float instruction must be a "
1002 "floating-point type scalar");
1003
1004 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1005 unsigned Op = MI.getOpcode();
1006 if (Op == SPIRV::OpAtomicFAddEXT) {
1007 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1008 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
1009 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1010 switch (BitWidth) {
1011 case 16:
1012 if (!ST.canUseExtension(
1013 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1014 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1015 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1016 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1017 break;
1018 case 32:
1019 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1020 break;
1021 case 64:
1022 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1023 break;
1024 default:
1025 report_fatal_error(
1026 "Unexpected floating-point type width in atomic float instruction");
1027 }
1028 } else {
1029 if (!ST.canUseExtension(
1030 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1031 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1032 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1033 switch (BitWidth) {
1034 case 16:
1035 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1036 break;
1037 case 32:
1038 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1039 break;
1040 case 64:
1041 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1042 break;
1043 default:
1044 report_fatal_error(
1045 "Unexpected floating-point type width in atomic float instruction");
1046 }
1047 }
1048 }
1049
isUniformTexelBuffer(MachineInstr * ImageInst)1050 bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1051 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1052 return false;
1053 uint32_t Dim = ImageInst->getOperand(2).getImm();
1054 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1055 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1056 }
1057
isStorageTexelBuffer(MachineInstr * ImageInst)1058 bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1059 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1060 return false;
1061 uint32_t Dim = ImageInst->getOperand(2).getImm();
1062 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1063 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1064 }
1065
isSampledImage(MachineInstr * ImageInst)1066 bool isSampledImage(MachineInstr *ImageInst) {
1067 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1068 return false;
1069 uint32_t Dim = ImageInst->getOperand(2).getImm();
1070 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1071 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1072 }
1073
isInputAttachment(MachineInstr * ImageInst)1074 bool isInputAttachment(MachineInstr *ImageInst) {
1075 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1076 return false;
1077 uint32_t Dim = ImageInst->getOperand(2).getImm();
1078 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1079 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1080 }
1081
isStorageImage(MachineInstr * ImageInst)1082 bool isStorageImage(MachineInstr *ImageInst) {
1083 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1084 return false;
1085 uint32_t Dim = ImageInst->getOperand(2).getImm();
1086 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1087 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1088 }
1089
isCombinedImageSampler(MachineInstr * SampledImageInst)1090 bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1091 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1092 return false;
1093
1094 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1095 Register ImageReg = SampledImageInst->getOperand(1).getReg();
1096 auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1097 return isSampledImage(ImageInst);
1098 }
1099
hasNonUniformDecoration(Register Reg,const MachineRegisterInfo & MRI)1100 bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1101 for (const auto &MI : MRI.reg_instructions(Reg)) {
1102 if (MI.getOpcode() != SPIRV::OpDecorate)
1103 continue;
1104
1105 uint32_t Dec = MI.getOperand(1).getImm();
1106 if (Dec == SPIRV::Decoration::NonUniformEXT)
1107 return true;
1108 }
1109 return false;
1110 }
1111
addOpAccessChainReqs(const MachineInstr & Instr,SPIRV::RequirementHandler & Handler,const SPIRVSubtarget & Subtarget)1112 void addOpAccessChainReqs(const MachineInstr &Instr,
1113 SPIRV::RequirementHandler &Handler,
1114 const SPIRVSubtarget &Subtarget) {
1115 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1116 // Get the result type. If it is an image type, then the shader uses
1117 // descriptor indexing. The appropriate capabilities will be added based
1118 // on the specifics of the image.
1119 Register ResTypeReg = Instr.getOperand(1).getReg();
1120 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1121
1122 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1123 uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1124 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1125 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1126 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1127 return;
1128 }
1129
1130 Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1131 MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1132 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1133 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1134 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1135 return;
1136 }
1137
1138 bool IsNonUniform =
1139 hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1140 if (isUniformTexelBuffer(PointeeType)) {
1141 if (IsNonUniform)
1142 Handler.addRequirements(
1143 SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1144 else
1145 Handler.addRequirements(
1146 SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1147 } else if (isInputAttachment(PointeeType)) {
1148 if (IsNonUniform)
1149 Handler.addRequirements(
1150 SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1151 else
1152 Handler.addRequirements(
1153 SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1154 } else if (isStorageTexelBuffer(PointeeType)) {
1155 if (IsNonUniform)
1156 Handler.addRequirements(
1157 SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1158 else
1159 Handler.addRequirements(
1160 SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1161 } else if (isSampledImage(PointeeType) ||
1162 isCombinedImageSampler(PointeeType) ||
1163 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1164 if (IsNonUniform)
1165 Handler.addRequirements(
1166 SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1167 else
1168 Handler.addRequirements(
1169 SPIRV::Capability::SampledImageArrayDynamicIndexing);
1170 } else if (isStorageImage(PointeeType)) {
1171 if (IsNonUniform)
1172 Handler.addRequirements(
1173 SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1174 else
1175 Handler.addRequirements(
1176 SPIRV::Capability::StorageImageArrayDynamicIndexing);
1177 }
1178 }
1179
isImageTypeWithUnknownFormat(SPIRVType * TypeInst)1180 static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1181 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1182 return false;
1183 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1184 return TypeInst->getOperand(7).getImm() == 0;
1185 }
1186
AddDotProductRequirements(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)1187 static void AddDotProductRequirements(const MachineInstr &MI,
1188 SPIRV::RequirementHandler &Reqs,
1189 const SPIRVSubtarget &ST) {
1190 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1191 Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1192 Reqs.addCapability(SPIRV::Capability::DotProduct);
1193
1194 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1195 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1196 // We do not consider what the previous instruction is. This is just used
1197 // to get the input register and to check the type.
1198 const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1199 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1200 Register InputReg = Input->getOperand(1).getReg();
1201
1202 SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
1203 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1204 assert(TypeDef->getOperand(1).getImm() == 32);
1205 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1206 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1207 SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1208 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1209 if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1210 assert(TypeDef->getOperand(2).getImm() == 4 &&
1211 "Dot operand of 8-bit integer type requires 4 components");
1212 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1213 } else {
1214 Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1215 }
1216 }
1217 }
1218
addInstrRequirements(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)1219 void addInstrRequirements(const MachineInstr &MI,
1220 SPIRV::RequirementHandler &Reqs,
1221 const SPIRVSubtarget &ST) {
1222 switch (MI.getOpcode()) {
1223 case SPIRV::OpMemoryModel: {
1224 int64_t Addr = MI.getOperand(0).getImm();
1225 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1226 Addr, ST);
1227 int64_t Mem = MI.getOperand(1).getImm();
1228 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1229 ST);
1230 break;
1231 }
1232 case SPIRV::OpEntryPoint: {
1233 int64_t Exe = MI.getOperand(0).getImm();
1234 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1235 Exe, ST);
1236 break;
1237 }
1238 case SPIRV::OpExecutionMode:
1239 case SPIRV::OpExecutionModeId: {
1240 int64_t Exe = MI.getOperand(1).getImm();
1241 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1242 Exe, ST);
1243 break;
1244 }
1245 case SPIRV::OpTypeMatrix:
1246 Reqs.addCapability(SPIRV::Capability::Matrix);
1247 break;
1248 case SPIRV::OpTypeInt: {
1249 unsigned BitWidth = MI.getOperand(1).getImm();
1250 if (BitWidth == 64)
1251 Reqs.addCapability(SPIRV::Capability::Int64);
1252 else if (BitWidth == 16)
1253 Reqs.addCapability(SPIRV::Capability::Int16);
1254 else if (BitWidth == 8)
1255 Reqs.addCapability(SPIRV::Capability::Int8);
1256 break;
1257 }
1258 case SPIRV::OpTypeFloat: {
1259 unsigned BitWidth = MI.getOperand(1).getImm();
1260 if (BitWidth == 64)
1261 Reqs.addCapability(SPIRV::Capability::Float64);
1262 else if (BitWidth == 16)
1263 Reqs.addCapability(SPIRV::Capability::Float16);
1264 break;
1265 }
1266 case SPIRV::OpTypeVector: {
1267 unsigned NumComponents = MI.getOperand(2).getImm();
1268 if (NumComponents == 8 || NumComponents == 16)
1269 Reqs.addCapability(SPIRV::Capability::Vector16);
1270 break;
1271 }
1272 case SPIRV::OpTypePointer: {
1273 auto SC = MI.getOperand(1).getImm();
1274 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1275 ST);
1276 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1277 // capability.
1278 if (ST.isShader())
1279 break;
1280 assert(MI.getOperand(2).isReg());
1281 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1282 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1283 if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1284 TypeDef->getOperand(1).getImm() == 16)
1285 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1286 break;
1287 }
1288 case SPIRV::OpExtInst: {
1289 if (MI.getOperand(2).getImm() ==
1290 static_cast<int64_t>(
1291 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1292 Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1293 }
1294 break;
1295 }
1296 case SPIRV::OpAliasDomainDeclINTEL:
1297 case SPIRV::OpAliasScopeDeclINTEL:
1298 case SPIRV::OpAliasScopeListDeclINTEL: {
1299 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1300 Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1301 break;
1302 }
1303 case SPIRV::OpBitReverse:
1304 case SPIRV::OpBitFieldInsert:
1305 case SPIRV::OpBitFieldSExtract:
1306 case SPIRV::OpBitFieldUExtract:
1307 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1308 Reqs.addCapability(SPIRV::Capability::Shader);
1309 break;
1310 }
1311 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1312 Reqs.addCapability(SPIRV::Capability::BitInstructions);
1313 break;
1314 case SPIRV::OpTypeRuntimeArray:
1315 Reqs.addCapability(SPIRV::Capability::Shader);
1316 break;
1317 case SPIRV::OpTypeOpaque:
1318 case SPIRV::OpTypeEvent:
1319 Reqs.addCapability(SPIRV::Capability::Kernel);
1320 break;
1321 case SPIRV::OpTypePipe:
1322 case SPIRV::OpTypeReserveId:
1323 Reqs.addCapability(SPIRV::Capability::Pipes);
1324 break;
1325 case SPIRV::OpTypeDeviceEvent:
1326 case SPIRV::OpTypeQueue:
1327 case SPIRV::OpBuildNDRange:
1328 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1329 break;
1330 case SPIRV::OpDecorate:
1331 case SPIRV::OpDecorateId:
1332 case SPIRV::OpDecorateString:
1333 addOpDecorateReqs(MI, 1, Reqs, ST);
1334 break;
1335 case SPIRV::OpMemberDecorate:
1336 case SPIRV::OpMemberDecorateString:
1337 addOpDecorateReqs(MI, 2, Reqs, ST);
1338 break;
1339 case SPIRV::OpInBoundsPtrAccessChain:
1340 Reqs.addCapability(SPIRV::Capability::Addresses);
1341 break;
1342 case SPIRV::OpConstantSampler:
1343 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1344 break;
1345 case SPIRV::OpInBoundsAccessChain:
1346 case SPIRV::OpAccessChain:
1347 addOpAccessChainReqs(MI, Reqs, ST);
1348 break;
1349 case SPIRV::OpTypeImage:
1350 addOpTypeImageReqs(MI, Reqs, ST);
1351 break;
1352 case SPIRV::OpTypeSampler:
1353 if (!ST.isShader()) {
1354 Reqs.addCapability(SPIRV::Capability::ImageBasic);
1355 }
1356 break;
1357 case SPIRV::OpTypeForwardPointer:
1358 // TODO: check if it's OpenCL's kernel.
1359 Reqs.addCapability(SPIRV::Capability::Addresses);
1360 break;
1361 case SPIRV::OpAtomicFlagTestAndSet:
1362 case SPIRV::OpAtomicLoad:
1363 case SPIRV::OpAtomicStore:
1364 case SPIRV::OpAtomicExchange:
1365 case SPIRV::OpAtomicCompareExchange:
1366 case SPIRV::OpAtomicIIncrement:
1367 case SPIRV::OpAtomicIDecrement:
1368 case SPIRV::OpAtomicIAdd:
1369 case SPIRV::OpAtomicISub:
1370 case SPIRV::OpAtomicUMin:
1371 case SPIRV::OpAtomicUMax:
1372 case SPIRV::OpAtomicSMin:
1373 case SPIRV::OpAtomicSMax:
1374 case SPIRV::OpAtomicAnd:
1375 case SPIRV::OpAtomicOr:
1376 case SPIRV::OpAtomicXor: {
1377 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1378 const MachineInstr *InstrPtr = &MI;
1379 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1380 assert(MI.getOperand(3).isReg());
1381 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1382 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1383 }
1384 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1385 Register TypeReg = InstrPtr->getOperand(1).getReg();
1386 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1387 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1388 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1389 if (BitWidth == 64)
1390 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1391 }
1392 break;
1393 }
1394 case SPIRV::OpGroupNonUniformIAdd:
1395 case SPIRV::OpGroupNonUniformFAdd:
1396 case SPIRV::OpGroupNonUniformIMul:
1397 case SPIRV::OpGroupNonUniformFMul:
1398 case SPIRV::OpGroupNonUniformSMin:
1399 case SPIRV::OpGroupNonUniformUMin:
1400 case SPIRV::OpGroupNonUniformFMin:
1401 case SPIRV::OpGroupNonUniformSMax:
1402 case SPIRV::OpGroupNonUniformUMax:
1403 case SPIRV::OpGroupNonUniformFMax:
1404 case SPIRV::OpGroupNonUniformBitwiseAnd:
1405 case SPIRV::OpGroupNonUniformBitwiseOr:
1406 case SPIRV::OpGroupNonUniformBitwiseXor:
1407 case SPIRV::OpGroupNonUniformLogicalAnd:
1408 case SPIRV::OpGroupNonUniformLogicalOr:
1409 case SPIRV::OpGroupNonUniformLogicalXor: {
1410 assert(MI.getOperand(3).isImm());
1411 int64_t GroupOp = MI.getOperand(3).getImm();
1412 switch (GroupOp) {
1413 case SPIRV::GroupOperation::Reduce:
1414 case SPIRV::GroupOperation::InclusiveScan:
1415 case SPIRV::GroupOperation::ExclusiveScan:
1416 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1417 break;
1418 case SPIRV::GroupOperation::ClusteredReduce:
1419 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1420 break;
1421 case SPIRV::GroupOperation::PartitionedReduceNV:
1422 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1423 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1424 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1425 break;
1426 }
1427 break;
1428 }
1429 case SPIRV::OpGroupNonUniformShuffle:
1430 case SPIRV::OpGroupNonUniformShuffleXor:
1431 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1432 break;
1433 case SPIRV::OpGroupNonUniformShuffleUp:
1434 case SPIRV::OpGroupNonUniformShuffleDown:
1435 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1436 break;
1437 case SPIRV::OpGroupAll:
1438 case SPIRV::OpGroupAny:
1439 case SPIRV::OpGroupBroadcast:
1440 case SPIRV::OpGroupIAdd:
1441 case SPIRV::OpGroupFAdd:
1442 case SPIRV::OpGroupFMin:
1443 case SPIRV::OpGroupUMin:
1444 case SPIRV::OpGroupSMin:
1445 case SPIRV::OpGroupFMax:
1446 case SPIRV::OpGroupUMax:
1447 case SPIRV::OpGroupSMax:
1448 Reqs.addCapability(SPIRV::Capability::Groups);
1449 break;
1450 case SPIRV::OpGroupNonUniformElect:
1451 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1452 break;
1453 case SPIRV::OpGroupNonUniformAll:
1454 case SPIRV::OpGroupNonUniformAny:
1455 case SPIRV::OpGroupNonUniformAllEqual:
1456 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1457 break;
1458 case SPIRV::OpGroupNonUniformBroadcast:
1459 case SPIRV::OpGroupNonUniformBroadcastFirst:
1460 case SPIRV::OpGroupNonUniformBallot:
1461 case SPIRV::OpGroupNonUniformInverseBallot:
1462 case SPIRV::OpGroupNonUniformBallotBitExtract:
1463 case SPIRV::OpGroupNonUniformBallotBitCount:
1464 case SPIRV::OpGroupNonUniformBallotFindLSB:
1465 case SPIRV::OpGroupNonUniformBallotFindMSB:
1466 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1467 break;
1468 case SPIRV::OpSubgroupShuffleINTEL:
1469 case SPIRV::OpSubgroupShuffleDownINTEL:
1470 case SPIRV::OpSubgroupShuffleUpINTEL:
1471 case SPIRV::OpSubgroupShuffleXorINTEL:
1472 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1473 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1474 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1475 }
1476 break;
1477 case SPIRV::OpSubgroupBlockReadINTEL:
1478 case SPIRV::OpSubgroupBlockWriteINTEL:
1479 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1480 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1481 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1482 }
1483 break;
1484 case SPIRV::OpSubgroupImageBlockReadINTEL:
1485 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1486 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1487 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1488 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1489 }
1490 break;
1491 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1492 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1493 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1494 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1495 Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1496 }
1497 break;
1498 case SPIRV::OpAssumeTrueKHR:
1499 case SPIRV::OpExpectKHR:
1500 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1501 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1502 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1503 }
1504 break;
1505 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1506 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1507 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1508 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1509 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1510 }
1511 break;
1512 case SPIRV::OpConstantFunctionPointerINTEL:
1513 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1514 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1515 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1516 }
1517 break;
1518 case SPIRV::OpGroupNonUniformRotateKHR:
1519 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1520 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1521 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1522 false);
1523 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1524 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1525 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1526 break;
1527 case SPIRV::OpGroupIMulKHR:
1528 case SPIRV::OpGroupFMulKHR:
1529 case SPIRV::OpGroupBitwiseAndKHR:
1530 case SPIRV::OpGroupBitwiseOrKHR:
1531 case SPIRV::OpGroupBitwiseXorKHR:
1532 case SPIRV::OpGroupLogicalAndKHR:
1533 case SPIRV::OpGroupLogicalOrKHR:
1534 case SPIRV::OpGroupLogicalXorKHR:
1535 if (ST.canUseExtension(
1536 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1537 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1538 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1539 }
1540 break;
1541 case SPIRV::OpReadClockKHR:
1542 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1543 report_fatal_error("OpReadClockKHR instruction requires the "
1544 "following SPIR-V extension: SPV_KHR_shader_clock",
1545 false);
1546 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1547 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1548 break;
1549 case SPIRV::OpFunctionPointerCallINTEL:
1550 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1551 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1552 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1553 }
1554 break;
1555 case SPIRV::OpAtomicFAddEXT:
1556 case SPIRV::OpAtomicFMinEXT:
1557 case SPIRV::OpAtomicFMaxEXT:
1558 AddAtomicFloatRequirements(MI, Reqs, ST);
1559 break;
1560 case SPIRV::OpConvertBF16ToFINTEL:
1561 case SPIRV::OpConvertFToBF16INTEL:
1562 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1563 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1564 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1565 }
1566 break;
1567 case SPIRV::OpVariableLengthArrayINTEL:
1568 case SPIRV::OpSaveMemoryINTEL:
1569 case SPIRV::OpRestoreMemoryINTEL:
1570 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1571 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1572 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1573 }
1574 break;
1575 case SPIRV::OpAsmTargetINTEL:
1576 case SPIRV::OpAsmINTEL:
1577 case SPIRV::OpAsmCallINTEL:
1578 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1579 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1580 Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1581 }
1582 break;
1583 case SPIRV::OpTypeCooperativeMatrixKHR:
1584 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1585 report_fatal_error(
1586 "OpTypeCooperativeMatrixKHR type requires the "
1587 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1588 false);
1589 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1590 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1591 break;
1592 case SPIRV::OpArithmeticFenceEXT:
1593 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1594 report_fatal_error("OpArithmeticFenceEXT requires the "
1595 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1596 false);
1597 Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1598 Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1599 break;
1600 case SPIRV::OpControlBarrierArriveINTEL:
1601 case SPIRV::OpControlBarrierWaitINTEL:
1602 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1603 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1604 Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1605 }
1606 break;
1607 case SPIRV::OpCooperativeMatrixMulAddKHR: {
1608 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1609 report_fatal_error("Cooperative matrix instructions require the "
1610 "following SPIR-V extension: "
1611 "SPV_KHR_cooperative_matrix",
1612 false);
1613 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1614 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1615 constexpr unsigned MulAddMaxSize = 6;
1616 if (MI.getNumOperands() != MulAddMaxSize)
1617 break;
1618 const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1619 if (CoopOperands &
1620 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1621 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1622 report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1623 "require the following SPIR-V extension: "
1624 "SPV_INTEL_joint_matrix",
1625 false);
1626 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1627 Reqs.addCapability(
1628 SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1629 }
1630 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1631 MatrixAAndBBFloat16ComponentsINTEL ||
1632 CoopOperands &
1633 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1634 CoopOperands & SPIRV::CooperativeMatrixOperands::
1635 MatrixResultBFloat16ComponentsINTEL) {
1636 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1637 report_fatal_error("***BF16ComponentsINTEL type interpretations "
1638 "require the following SPIR-V extension: "
1639 "SPV_INTEL_joint_matrix",
1640 false);
1641 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1642 Reqs.addCapability(
1643 SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1644 }
1645 break;
1646 }
1647 case SPIRV::OpCooperativeMatrixLoadKHR:
1648 case SPIRV::OpCooperativeMatrixStoreKHR:
1649 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1650 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1651 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1652 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1653 report_fatal_error("Cooperative matrix instructions require the "
1654 "following SPIR-V extension: "
1655 "SPV_KHR_cooperative_matrix",
1656 false);
1657 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1658 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1659
1660 // Check Layout operand in case if it's not a standard one and add the
1661 // appropriate capability.
1662 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1663 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1664 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1665 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1666 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1667 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1668
1669 const auto OpCode = MI.getOpcode();
1670 const unsigned LayoutNum = LayoutToInstMap[OpCode];
1671 Register RegLayout = MI.getOperand(LayoutNum).getReg();
1672 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1673 MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1674 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1675 const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1676 if (LayoutVal ==
1677 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1678 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1679 report_fatal_error("PackedINTEL layout require the following SPIR-V "
1680 "extension: SPV_INTEL_joint_matrix",
1681 false);
1682 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1683 Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1684 }
1685 }
1686
1687 // Nothing to do.
1688 if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1689 OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1690 break;
1691
1692 std::string InstName;
1693 switch (OpCode) {
1694 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1695 InstName = "OpCooperativeMatrixPrefetchINTEL";
1696 break;
1697 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1698 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1699 break;
1700 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1701 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1702 break;
1703 }
1704
1705 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1706 const std::string ErrorMsg =
1707 InstName + " instruction requires the "
1708 "following SPIR-V extension: SPV_INTEL_joint_matrix";
1709 report_fatal_error(ErrorMsg.c_str(), false);
1710 }
1711 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1712 if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1713 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1714 break;
1715 }
1716 Reqs.addCapability(
1717 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1718 break;
1719 }
1720 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1721 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1722 report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1723 "instructions require the following SPIR-V extension: "
1724 "SPV_INTEL_joint_matrix",
1725 false);
1726 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1727 Reqs.addCapability(
1728 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1729 break;
1730 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1731 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1732 report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1733 "following SPIR-V extension: SPV_INTEL_joint_matrix",
1734 false);
1735 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1736 Reqs.addCapability(
1737 SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1738 break;
1739 case SPIRV::OpConvertHandleToImageINTEL:
1740 case SPIRV::OpConvertHandleToSamplerINTEL:
1741 case SPIRV::OpConvertHandleToSampledImageINTEL:
1742 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
1743 report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
1744 "instructions require the following SPIR-V extension: "
1745 "SPV_INTEL_bindless_images",
1746 false);
1747 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
1748 Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
1749 break;
1750 case SPIRV::OpSubgroup2DBlockLoadINTEL:
1751 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
1752 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
1753 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
1754 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
1755 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
1756 report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
1757 "Prefetch/Store]INTEL instructions require the "
1758 "following SPIR-V extension: SPV_INTEL_2d_block_io",
1759 false);
1760 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
1761 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
1762
1763 const auto OpCode = MI.getOpcode();
1764 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
1765 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
1766 break;
1767 }
1768 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
1769 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
1770 break;
1771 }
1772 break;
1773 }
1774 case SPIRV::OpKill: {
1775 Reqs.addCapability(SPIRV::Capability::Shader);
1776 } break;
1777 case SPIRV::OpDemoteToHelperInvocation:
1778 Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
1779
1780 if (ST.canUseExtension(
1781 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
1782 if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
1783 Reqs.addExtension(
1784 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
1785 }
1786 break;
1787 case SPIRV::OpSDot:
1788 case SPIRV::OpUDot:
1789 case SPIRV::OpSUDot:
1790 case SPIRV::OpSDotAccSat:
1791 case SPIRV::OpUDotAccSat:
1792 case SPIRV::OpSUDotAccSat:
1793 AddDotProductRequirements(MI, Reqs, ST);
1794 break;
1795 case SPIRV::OpImageRead: {
1796 Register ImageReg = MI.getOperand(2).getReg();
1797 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1798 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1799 // OpImageRead and OpImageWrite can use Unknown Image Formats
1800 // when the Kernel capability is declared. In the OpenCL environment we are
1801 // not allowed to produce
1802 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
1803 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
1804
1805 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
1806 Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
1807 break;
1808 }
1809 case SPIRV::OpImageWrite: {
1810 Register ImageReg = MI.getOperand(0).getReg();
1811 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
1812 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
1813 // OpImageRead and OpImageWrite can use Unknown Image Formats
1814 // when the Kernel capability is declared. In the OpenCL environment we are
1815 // not allowed to produce
1816 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
1817 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
1818
1819 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
1820 Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
1821 break;
1822 }
1823 case SPIRV::OpTypeStructContinuedINTEL:
1824 case SPIRV::OpConstantCompositeContinuedINTEL:
1825 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
1826 case SPIRV::OpCompositeConstructContinuedINTEL: {
1827 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
1828 report_fatal_error(
1829 "Continued instructions require the "
1830 "following SPIR-V extension: SPV_INTEL_long_composites",
1831 false);
1832 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
1833 Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
1834 break;
1835 }
1836 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
1837 if (!ST.canUseExtension(
1838 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
1839 report_fatal_error(
1840 "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
1841 "following SPIR-V "
1842 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
1843 false);
1844 Reqs.addExtension(
1845 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
1846 Reqs.addCapability(
1847 SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
1848 break;
1849 }
1850 case SPIRV::OpBitwiseFunctionINTEL: {
1851 if (!ST.canUseExtension(
1852 SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
1853 report_fatal_error(
1854 "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
1855 "extension: SPV_INTEL_ternary_bitwise_function",
1856 false);
1857 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
1858 Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
1859 break;
1860 }
1861
1862 default:
1863 break;
1864 }
1865
1866 // If we require capability Shader, then we can remove the requirement for
1867 // the BitInstructions capability, since Shader is a superset capability
1868 // of BitInstructions.
1869 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
1870 SPIRV::Capability::Shader);
1871 }
1872
collectReqs(const Module & M,SPIRV::ModuleAnalysisInfo & MAI,MachineModuleInfo * MMI,const SPIRVSubtarget & ST)1873 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1874 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1875 // Collect requirements for existing instructions.
1876 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1877 MachineFunction *MF = MMI->getMachineFunction(*F);
1878 if (!MF)
1879 continue;
1880 for (const MachineBasicBlock &MBB : *MF)
1881 for (const MachineInstr &MI : MBB)
1882 addInstrRequirements(MI, MAI.Reqs, ST);
1883 }
1884 // Collect requirements for OpExecutionMode instructions.
1885 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1886 if (Node) {
1887 bool RequireFloatControls = false, RequireFloatControls2 = false,
1888 VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
1889 bool HasFloatControls2 =
1890 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
1891 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1892 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
1893 const MDOperand &MDOp = MDN->getOperand(1);
1894 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1895 Constant *C = CMeta->getValue();
1896 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1897 auto EM = Const->getZExtValue();
1898 // SPV_KHR_float_controls is not available until v1.4:
1899 // add SPV_KHR_float_controls if the version is too low
1900 switch (EM) {
1901 case SPIRV::ExecutionMode::DenormPreserve:
1902 case SPIRV::ExecutionMode::DenormFlushToZero:
1903 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1904 case SPIRV::ExecutionMode::RoundingModeRTE:
1905 case SPIRV::ExecutionMode::RoundingModeRTZ:
1906 RequireFloatControls = VerLower14;
1907 MAI.Reqs.getAndAddRequirements(
1908 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1909 break;
1910 case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
1911 case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
1912 case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
1913 case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
1914 if (HasFloatControls2) {
1915 RequireFloatControls2 = true;
1916 MAI.Reqs.getAndAddRequirements(
1917 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1918 }
1919 break;
1920 default:
1921 MAI.Reqs.getAndAddRequirements(
1922 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1923 }
1924 }
1925 }
1926 }
1927 if (RequireFloatControls &&
1928 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1929 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1930 if (RequireFloatControls2)
1931 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
1932 }
1933 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1934 const Function &F = *FI;
1935 if (F.isDeclaration())
1936 continue;
1937 if (F.getMetadata("reqd_work_group_size"))
1938 MAI.Reqs.getAndAddRequirements(
1939 SPIRV::OperandCategory::ExecutionModeOperand,
1940 SPIRV::ExecutionMode::LocalSize, ST);
1941 if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1942 MAI.Reqs.getAndAddRequirements(
1943 SPIRV::OperandCategory::ExecutionModeOperand,
1944 SPIRV::ExecutionMode::LocalSize, ST);
1945 }
1946 if (F.getMetadata("work_group_size_hint"))
1947 MAI.Reqs.getAndAddRequirements(
1948 SPIRV::OperandCategory::ExecutionModeOperand,
1949 SPIRV::ExecutionMode::LocalSizeHint, ST);
1950 if (F.getMetadata("intel_reqd_sub_group_size"))
1951 MAI.Reqs.getAndAddRequirements(
1952 SPIRV::OperandCategory::ExecutionModeOperand,
1953 SPIRV::ExecutionMode::SubgroupSize, ST);
1954 if (F.getMetadata("vec_type_hint"))
1955 MAI.Reqs.getAndAddRequirements(
1956 SPIRV::OperandCategory::ExecutionModeOperand,
1957 SPIRV::ExecutionMode::VecTypeHint, ST);
1958
1959 if (F.hasOptNone()) {
1960 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1961 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1962 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1963 } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
1964 MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
1965 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
1966 }
1967 }
1968 }
1969 }
1970
getFastMathFlags(const MachineInstr & I)1971 static unsigned getFastMathFlags(const MachineInstr &I) {
1972 unsigned Flags = SPIRV::FPFastMathMode::None;
1973 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1974 Flags |= SPIRV::FPFastMathMode::NotNaN;
1975 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1976 Flags |= SPIRV::FPFastMathMode::NotInf;
1977 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1978 Flags |= SPIRV::FPFastMathMode::NSZ;
1979 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1980 Flags |= SPIRV::FPFastMathMode::AllowRecip;
1981 if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1982 Flags |= SPIRV::FPFastMathMode::Fast;
1983 return Flags;
1984 }
1985
isFastMathMathModeAvailable(const SPIRVSubtarget & ST)1986 static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
1987 if (ST.isKernel())
1988 return true;
1989 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
1990 return false;
1991 return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1992 }
1993
handleMIFlagDecoration(MachineInstr & I,const SPIRVSubtarget & ST,const SPIRVInstrInfo & TII,SPIRV::RequirementHandler & Reqs)1994 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1995 const SPIRVInstrInfo &TII,
1996 SPIRV::RequirementHandler &Reqs) {
1997 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1998 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1999 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2000 .IsSatisfiable) {
2001 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2002 SPIRV::Decoration::NoSignedWrap, {});
2003 }
2004 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2005 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2006 SPIRV::Decoration::NoUnsignedWrap, ST,
2007 Reqs)
2008 .IsSatisfiable) {
2009 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2010 SPIRV::Decoration::NoUnsignedWrap, {});
2011 }
2012 if (!TII.canUseFastMathFlags(I))
2013 return;
2014 unsigned FMFlags = getFastMathFlags(I);
2015 if (FMFlags == SPIRV::FPFastMathMode::None)
2016 return;
2017
2018 if (isFastMathMathModeAvailable(ST)) {
2019 Register DstReg = I.getOperand(0).getReg();
2020 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2021 {FMFlags});
2022 }
2023 }
2024
2025 // Walk all functions and add decorations related to MI flags.
addDecorations(const Module & M,const SPIRVInstrInfo & TII,MachineModuleInfo * MMI,const SPIRVSubtarget & ST,SPIRV::ModuleAnalysisInfo & MAI)2026 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2027 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2028 SPIRV::ModuleAnalysisInfo &MAI) {
2029 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2030 MachineFunction *MF = MMI->getMachineFunction(*F);
2031 if (!MF)
2032 continue;
2033 for (auto &MBB : *MF)
2034 for (auto &MI : MBB)
2035 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
2036 }
2037 }
2038
addMBBNames(const Module & M,const SPIRVInstrInfo & TII,MachineModuleInfo * MMI,const SPIRVSubtarget & ST,SPIRV::ModuleAnalysisInfo & MAI)2039 static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2040 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2041 SPIRV::ModuleAnalysisInfo &MAI) {
2042 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2043 MachineFunction *MF = MMI->getMachineFunction(*F);
2044 if (!MF)
2045 continue;
2046 MachineRegisterInfo &MRI = MF->getRegInfo();
2047 for (auto &MBB : *MF) {
2048 if (!MBB.hasName() || MBB.empty())
2049 continue;
2050 // Emit basic block names.
2051 Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
2052 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2053 buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
2054 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2055 MAI.setRegisterAlias(MF, Reg, GlobalReg);
2056 }
2057 }
2058 }
2059
2060 // patching Instruction::PHI to SPIRV::OpPhi
patchPhis(const Module & M,SPIRVGlobalRegistry * GR,const SPIRVInstrInfo & TII,MachineModuleInfo * MMI)2061 static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2062 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2063 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2064 MachineFunction *MF = MMI->getMachineFunction(*F);
2065 if (!MF)
2066 continue;
2067 for (auto &MBB : *MF) {
2068 for (MachineInstr &MI : MBB.phis()) {
2069 MI.setDesc(TII.get(SPIRV::OpPhi));
2070 Register ResTypeReg = GR->getSPIRVTypeID(
2071 GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2072 MI.insert(MI.operands_begin() + 1,
2073 {MachineOperand::CreateReg(ResTypeReg, false)});
2074 }
2075 }
2076
2077 MF->getProperties().setNoPHIs();
2078 }
2079 }
2080
// Out-of-class definition of the SPIRVModuleAnalysis::MAI data member: the
// ModuleAnalysisInfo object that runOnModule and its helpers fill in.
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
2082
/// Declares the passes this analysis requires: TargetPassConfig and
/// MachineModuleInfoWrapperPass. Both are queried with getAnalysis<>() in
/// runOnModule.
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
2087
/// Pass entry point. Caches the subtarget, global registry, instruction info
/// and MachineModuleInfo in the pass members, then runs the module-level
/// analysis steps below in a fixed order. Always returns false.
bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  // Resolve the SPIR-V target machine and the per-target helpers used by all
  // subsequent steps.
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  // Turn generic PHIs into SPIRV::OpPhi before instructions are inspected.
  patchPhis(M, GR, *TII, MMI);

  // Add block-name and MI-flag decoration instructions before requirements
  // are collected.
  addMBBNames(M, *TII, MMI, *ST, MAI);
  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  // NOTE(review): collectReqs is invoked a second time here with identical
  // arguments and no intervening step. This looks redundant; confirm that
  // RequirementHandler updates are idempotent before removing either call.
  collectReqs(M, MAI, MMI, *ST);
  collectDeclarations(M);

  // Number rest of registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  // Set maximum ID used.
  GR->setBound(MAI.MaxID);

  return false;
}
2126