1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The analysis collects instructions that should be output at the module level 10 // and performs the global register numbering. 11 // 12 // The results of this analysis are used in AsmPrinter to rename registers 13 // globally and to output required instructions at the module level. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVModuleAnalysis.h" 18 #include "MCTargetDesc/SPIRVBaseInfo.h" 19 #include "MCTargetDesc/SPIRVMCTargetDesc.h" 20 #include "SPIRV.h" 21 #include "SPIRVSubtarget.h" 22 #include "SPIRVTargetMachine.h" 23 #include "SPIRVUtils.h" 24 #include "TargetInfo/SPIRVTargetInfo.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/CodeGen/MachineModuleInfo.h" 27 #include "llvm/CodeGen/TargetPassConfig.h" 28 29 using namespace llvm; 30 31 #define DEBUG_TYPE "spirv-module-analysis" 32 33 static cl::opt<bool> 34 SPVDumpDeps("spv-dump-deps", 35 cl::desc("Dump MIR with SPIR-V dependencies info"), 36 cl::Optional, cl::init(false)); 37 38 static cl::list<SPIRV::Capability::Capability> 39 AvoidCapabilities("avoid-spirv-capabilities", 40 cl::desc("SPIR-V capabilities to avoid if there are " 41 "other options enabling a feature"), 42 cl::ZeroOrMore, cl::Hidden, 43 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader", 44 "SPIR-V Shader capability"))); 45 // Use sets instead of cl::list to check "if contains" condition 46 struct AvoidCapabilitiesSet { 47 SmallSet<SPIRV::Capability::Capability, 4> S; 48 AvoidCapabilitiesSet() { 49 for (auto Cap : AvoidCapabilities) 50 S.insert(Cap); 51 } 52 }; 53 54 char llvm::SPIRVModuleAnalysis::ID = 0; 55 56 namespace llvm { 57 void initializeSPIRVModuleAnalysisPass(PassRegistry &); 58 } // namespace llvm 59 60 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true, 61 true) 62 63 // Retrieve an unsigned from an MDNode with a list of them as operands. 
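// For example, given module-level metadata of the form (illustrative IR)
//   !opencl.ocl.version = !{!0}
//   !0 = !{i32 2, i32 0}
// getMetadataUInt(!0, 0) yields 2 and getMetadataUInt(!0, 1) yields 0; any node
// whose operands are ConstantInt values is handled the same way.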
64 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex, 65 unsigned DefaultVal = 0) { 66 if (MdNode && OpIndex < MdNode->getNumOperands()) { 67 const auto &Op = MdNode->getOperand(OpIndex); 68 return mdconst::extract<ConstantInt>(Op)->getZExtValue(); 69 } 70 return DefaultVal; 71 } 72 73 static SPIRV::Requirements 74 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category, 75 unsigned i, const SPIRVSubtarget &ST, 76 SPIRV::RequirementHandler &Reqs) { 77 static AvoidCapabilitiesSet 78 AvoidCaps; // contains capabilities to avoid if there is another option 79 80 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i); 81 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i); 82 VersionTuple SPIRVVersion = ST.getSPIRVVersion(); 83 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer; 84 bool MaxVerOK = 85 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer; 86 CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i); 87 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i); 88 if (ReqCaps.empty()) { 89 if (ReqExts.empty()) { 90 if (MinVerOK && MaxVerOK) 91 return {true, {}, {}, ReqMinVer, ReqMaxVer}; 92 return {false, {}, {}, VersionTuple(), VersionTuple()}; 93 } 94 } else if (MinVerOK && MaxVerOK) { 95 if (ReqCaps.size() == 1) { 96 auto Cap = ReqCaps[0]; 97 if (Reqs.isCapabilityAvailable(Cap)) 98 return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer}; 99 } else { 100 // By SPIR-V specification: "If an instruction, enumerant, or other 101 // feature specifies multiple enabling capabilities, only one such 102 // capability needs to be declared to use the feature." However, one 103 // capability may be preferred over another. We use command line 104 // argument(s) and AvoidCapabilities to avoid selection of certain 105 // capabilities if there are other options. 106 CapabilityList UseCaps; 107 for (auto Cap : ReqCaps) 108 if (Reqs.isCapabilityAvailable(Cap)) 109 UseCaps.push_back(Cap); 110 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) { 111 auto Cap = UseCaps[i]; 112 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) 113 return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer}; 114 } 115 } 116 } 117 // If there are no capabilities, or we can't satisfy the version or 118 // capability requirements, use the list of extensions (if the subtarget 119 // can handle them all). 120 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) { 121 return ST.canUseExtension(Ext); 122 })) { 123 return {true, 124 {}, 125 ReqExts, 126 VersionTuple(), 127 VersionTuple()}; // TODO: add versions to extensions. 128 } 129 return {false, {}, {}, VersionTuple(), VersionTuple()}; 130 } 131 132 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) { 133 MAI.MaxID = 0; 134 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++) 135 MAI.MS[i].clear(); 136 MAI.RegisterAliasTable.clear(); 137 MAI.InstrsToDelete.clear(); 138 MAI.FuncMap.clear(); 139 MAI.GlobalVarList.clear(); 140 MAI.ExtInstSetMap.clear(); 141 MAI.Reqs.clear(); 142 MAI.Reqs.initAvailableCapabilities(*ST); 143 144 // TODO: determine memory model and source language from the configuration.
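// A sketch of the named metadata consumed below (illustrative; the operand
// values are raw SPIR-V enum values):
//   !spirv.MemoryModel = !{!0}
//   !0 = !{i32 2, i32 2} ; AddressingModel::Physical64, MemoryModel::OpenCL
// Operand 0 selects the addressing model and operand 1 the memory model.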
145 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) { 146 auto MemMD = MemModel->getOperand(0); 147 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>( 148 getMetadataUInt(MemMD, 0)); 149 MAI.Mem = 150 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1)); 151 } else { 152 // TODO: Add support for VulkanMemoryModel. 153 MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL 154 : SPIRV::MemoryModel::GLSL450; 155 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) { 156 unsigned PtrSize = ST->getPointerSize(); 157 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32 158 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64 159 : SPIRV::AddressingModel::Logical; 160 } else { 161 // TODO: Add support for PhysicalStorageBufferAddress. 162 MAI.Addr = SPIRV::AddressingModel::Logical; 163 } 164 } 165 // Get the OpenCL version number from metadata. 166 // TODO: support other source languages. 167 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) { 168 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C; 169 // Construct version literal in accordance with SPIRV-LLVM-Translator. 170 // TODO: support multiple OCL version metadata. 171 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR"); 172 auto VersionMD = VerNode->getOperand(0); 173 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2); 174 unsigned MinorNum = getMetadataUInt(VersionMD, 1); 175 unsigned RevNum = getMetadataUInt(VersionMD, 2); 176 // Prevent Major part of OpenCL version to be 0 177 MAI.SrcLangVersion = 178 (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum; 179 } else { 180 // If there is no information about OpenCL version we are forced to generate 181 // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling 182 // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV 183 // Translator avoids potential issues with run-times in a similar manner. 184 if (ST->isOpenCLEnv()) { 185 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP; 186 MAI.SrcLangVersion = 100000; 187 } else { 188 MAI.SrcLang = SPIRV::SourceLanguage::Unknown; 189 MAI.SrcLangVersion = 0; 190 } 191 } 192 193 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) { 194 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) { 195 MDNode *MD = ExtNode->getOperand(I); 196 if (!MD || MD->getNumOperands() == 0) 197 continue; 198 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J) 199 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString()); 200 } 201 } 202 203 // Update required capabilities for this memory model, addressing model and 204 // source language. 205 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, 206 MAI.Mem, *ST); 207 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand, 208 MAI.SrcLang, *ST); 209 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand, 210 MAI.Addr, *ST); 211 212 if (ST->isOpenCLEnv()) { 213 // TODO: check if it's required by default. 214 MAI.ExtInstSetMap[static_cast<unsigned>( 215 SPIRV::InstructionSet::OpenCL_std)] = 216 Register::index2VirtReg(MAI.getNextID()); 217 } 218 } 219 220 // Collect MI which defines the register in the given machine function. 
221 static void collectDefInstr(Register Reg, const MachineFunction *MF, 222 SPIRV::ModuleAnalysisInfo *MAI, 223 SPIRV::ModuleSectionType MSType, 224 bool DoInsert = true) { 225 assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias"); 226 MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg); 227 assert(MI && "There should be an instruction that defines the register"); 228 MAI->setSkipEmission(MI); 229 if (DoInsert) 230 MAI->MS[MSType].push_back(MI); 231 } 232 233 void SPIRVModuleAnalysis::collectGlobalEntities( 234 const std::vector<SPIRV::DTSortableEntry *> &DepsGraph, 235 SPIRV::ModuleSectionType MSType, 236 std::function<bool(const SPIRV::DTSortableEntry *)> Pred, 237 bool UsePreOrder = false) { 238 DenseSet<const SPIRV::DTSortableEntry *> Visited; 239 for (const auto *E : DepsGraph) { 240 std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil; 241 // NOTE: here we prefer the recursive approach over an iterative one because 242 // we don't expect dependency chains long enough to cause a stack overflow. 243 RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred, 244 &RecHoistUtil](const SPIRV::DTSortableEntry *E) { 245 if (Visited.count(E) || !Pred(E)) 246 return; 247 Visited.insert(E); 248 249 // Traversing the deps graph in post-order allows us to get rid of a 250 // separate register-alias preprocessing step. 251 // But pre-order is required for correct processing of function 252 // declarations and their arguments. 253 if (!UsePreOrder) 254 for (auto *S : E->getDeps()) 255 RecHoistUtil(S); 256 257 Register GlobalReg = Register::index2VirtReg(MAI.getNextID()); 258 bool IsFirst = true; 259 for (auto &U : *E) { 260 const MachineFunction *MF = U.first; 261 Register Reg = U.second; 262 MAI.setRegisterAlias(MF, Reg, GlobalReg); 263 if (!MF->getRegInfo().getUniqueVRegDef(Reg)) 264 continue; 265 collectDefInstr(Reg, MF, &MAI, MSType, IsFirst); 266 IsFirst = false; 267 if (E->getIsGV()) 268 MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg)); 269 } 270 271 if (UsePreOrder) 272 for (auto *S : E->getDeps()) 273 RecHoistUtil(S); 274 }; 275 RecHoistUtil(E); 276 } 277 } 278 279 // The function initializes the global register alias table for types, consts, 280 // global vars and func decls, and collects these instructions for output 281 // at the module level. It also collects explicit OpExtension/OpCapability 282 // instructions. 283 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) { 284 std::vector<SPIRV::DTSortableEntry *> DepsGraph; 285 286 GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr); 287 288 collectGlobalEntities( 289 DepsGraph, SPIRV::MB_TypeConstVars, 290 [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); }); 291 292 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 293 MachineFunction *MF = MMI->getMachineFunction(*F); 294 if (!MF) 295 continue; 296 // Iterate through and collect OpExtension/OpCapability instructions. 297 for (MachineBasicBlock &MBB : *MF) { 298 for (MachineInstr &MI : MBB) { 299 if (MI.getOpcode() == SPIRV::OpExtension) { 300 // Here, OpExtension just has a single enum operand, not a string.
301 auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm()); 302 MAI.Reqs.addExtension(Ext); 303 MAI.setSkipEmission(&MI); 304 } else if (MI.getOpcode() == SPIRV::OpCapability) { 305 auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm()); 306 MAI.Reqs.addCapability(Cap); 307 MAI.setSkipEmission(&MI); 308 } 309 } 310 } 311 } 312 313 collectGlobalEntities( 314 DepsGraph, SPIRV::MB_ExtFuncDecls, 315 [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true); 316 } 317 318 // Look for IDs declared with Import linkage, and map the corresponding function 319 // to the register defining that variable (which will usually be the result of 320 // an OpFunction). This lets us call externally imported functions using 321 // the correct ID registers. 322 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, 323 const Function *F) { 324 if (MI.getOpcode() == SPIRV::OpDecorate) { 325 // If it has Import linkage. 326 auto Dec = MI.getOperand(1).getImm(); 327 if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) { 328 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm(); 329 if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) { 330 // Map imported function name to function ID register. 331 const Function *ImportedFunc = 332 F->getParent()->getFunction(getStringImm(MI, 2)); 333 Register Target = MI.getOperand(0).getReg(); 334 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target); 335 } 336 } 337 } else if (MI.getOpcode() == SPIRV::OpFunction) { 338 // Record all internal OpFunction declarations. 339 Register Reg = MI.defs().begin()->getReg(); 340 Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg); 341 assert(GlobalReg.isValid()); 342 MAI.FuncMap[F] = GlobalReg; 343 } 344 } 345 346 // References to a function via function pointers generate virtual 347 // registers without a definition. We can resolve such a 348 // reference using Global Register info into an OpFunction instruction 349 // and replace dummy operands with the corresponding global register references. 350 void SPIRVModuleAnalysis::collectFuncPtrs() { 351 for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars]) 352 if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL) 353 collectFuncPtrs(MI); 354 } 355 356 void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) { 357 const MachineOperand *FunUse = &MI->getOperand(2); 358 if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) { 359 const MachineInstr *FunDefMI = FunDef->getParent(); 360 assert(FunDefMI->getOpcode() == SPIRV::OpFunction && 361 "Constant function pointer must refer to function definition"); 362 Register FunDefReg = FunDef->getReg(); 363 Register GlobalFunDefReg = 364 MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg); 365 assert(GlobalFunDefReg.isValid() && 366 "Function definition must refer to a global register"); 367 Register FunPtrReg = FunUse->getReg(); 368 MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg); 369 } 370 } 371 372 using InstrSignature = SmallVector<size_t>; 373 using InstrTraces = std::set<InstrSignature>; 374 375 // Returns a representation of an instruction as a vector of MachineOperand 376 // hash values, see llvm::hash_value(const MachineOperand &MO) for details. 377 // This creates a signature of the instruction with the same content 378 // that MachineOperand::isIdenticalTo uses for comparison.
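// Hashing the global register alias (rather than the function-local vreg) means
// that, e.g., identical OpDecorate instructions emitted in different functions
// for the same global entity produce equal signatures and are therefore emitted
// only once at the module level.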
379 static InstrSignature instrToSignature(MachineInstr &MI, 380 SPIRV::ModuleAnalysisInfo &MAI) { 381 InstrSignature Signature; 382 for (unsigned i = 0; i < MI.getNumOperands(); ++i) { 383 const MachineOperand &MO = MI.getOperand(i); 384 size_t h; 385 if (MO.isReg()) { 386 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); 387 // mimic llvm::hash_value(const MachineOperand &MO) 388 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), 389 MO.isDef()); 390 } else { 391 h = hash_value(MO); 392 } 393 Signature.push_back(h); 394 } 395 return Signature; 396 } 397 398 // Collect the given instruction in the specified MS. We assume global register 399 // numbering has already occurred by this point. We can directly compare reg 400 // arguments when detecting duplicates. 401 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, 402 SPIRV::ModuleSectionType MSType, InstrTraces &IS, 403 bool Append = true) { 404 MAI.setSkipEmission(&MI); 405 InstrSignature MISign = instrToSignature(MI, MAI); 406 auto FoundMI = IS.insert(MISign); 407 if (!FoundMI.second) 408 return; // insert failed, so we found a duplicate; don't add it to MAI.MS 409 // No duplicates, so add it. 410 if (Append) 411 MAI.MS[MSType].push_back(&MI); 412 else 413 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI); 414 } 415 416 // Some global instructions make reference to function-local ID regs, so cannot 417 // be correctly collected until these registers are globally numbered. 418 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { 419 InstrTraces IS; 420 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 421 if ((*F).isDeclaration()) 422 continue; 423 MachineFunction *MF = MMI->getMachineFunction(*F); 424 assert(MF); 425 for (MachineBasicBlock &MBB : *MF) 426 for (MachineInstr &MI : MBB) { 427 if (MAI.getSkipEmission(&MI)) 428 continue; 429 const unsigned OpCode = MI.getOpcode(); 430 if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) { 431 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS); 432 } else if (OpCode == SPIRV::OpEntryPoint) { 433 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS); 434 } else if (TII->isDecorationInstr(MI)) { 435 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS); 436 collectFuncNames(MI, &*F); 437 } else if (TII->isConstantInstr(MI)) { 438 // Now OpSpecConstant*s are not in DT, 439 // but they need to be collected anyway. 440 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS); 441 } else if (OpCode == SPIRV::OpFunction) { 442 collectFuncNames(MI, &*F); 443 } else if (OpCode == SPIRV::OpTypeForwardPointer) { 444 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false); 445 } 446 } 447 } 448 } 449 450 // Number registers in all functions globally from 0 onwards and store 451 // the result in global register alias table. Some registers are already 452 // numbered in collectGlobalEntities. 
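// For example, a local vreg %5 in one function and a local vreg %5 in another
// receive two distinct global IDs here unless collectGlobalEntities already
// mapped both of them to the same shared global register.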
453 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) { 454 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 455 if ((*F).isDeclaration()) 456 continue; 457 MachineFunction *MF = MMI->getMachineFunction(*F); 458 assert(MF); 459 for (MachineBasicBlock &MBB : *MF) { 460 for (MachineInstr &MI : MBB) { 461 for (MachineOperand &Op : MI.operands()) { 462 if (!Op.isReg()) 463 continue; 464 Register Reg = Op.getReg(); 465 if (MAI.hasRegisterAlias(MF, Reg)) 466 continue; 467 Register NewReg = Register::index2VirtReg(MAI.getNextID()); 468 MAI.setRegisterAlias(MF, Reg, NewReg); 469 } 470 if (MI.getOpcode() != SPIRV::OpExtInst) 471 continue; 472 auto Set = MI.getOperand(2).getImm(); 473 if (!MAI.ExtInstSetMap.contains(Set)) 474 MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID()); 475 } 476 } 477 } 478 } 479 480 // RequirementHandler implementations. 481 void SPIRV::RequirementHandler::getAndAddRequirements( 482 SPIRV::OperandCategory::OperandCategory Category, uint32_t i, 483 const SPIRVSubtarget &ST) { 484 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this)); 485 } 486 487 void SPIRV::RequirementHandler::recursiveAddCapabilities( 488 const CapabilityList &ToPrune) { 489 for (const auto &Cap : ToPrune) { 490 AllCaps.insert(Cap); 491 CapabilityList ImplicitDecls = 492 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap); 493 recursiveAddCapabilities(ImplicitDecls); 494 } 495 } 496 497 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) { 498 for (const auto &Cap : ToAdd) { 499 bool IsNewlyInserted = AllCaps.insert(Cap).second; 500 if (!IsNewlyInserted) // Don't re-add if it's already been declared. 501 continue; 502 CapabilityList ImplicitDecls = 503 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap); 504 recursiveAddCapabilities(ImplicitDecls); 505 MinimalCaps.push_back(Cap); 506 } 507 } 508 509 void SPIRV::RequirementHandler::addRequirements( 510 const SPIRV::Requirements &Req) { 511 if (!Req.IsSatisfiable) 512 report_fatal_error("Adding SPIR-V requirements this target can't satisfy."); 513 514 if (Req.Cap.has_value()) 515 addCapabilities({Req.Cap.value()}); 516 517 addExtensions(Req.Exts); 518 519 if (!Req.MinVer.empty()) { 520 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) { 521 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer 522 << " and <= " << MaxVersion << "\n"); 523 report_fatal_error("Adding SPIR-V requirements that can't be satisfied."); 524 } 525 526 if (MinVersion.empty() || Req.MinVer > MinVersion) 527 MinVersion = Req.MinVer; 528 } 529 530 if (!Req.MaxVer.empty()) { 531 if (!MinVersion.empty() && Req.MaxVer < MinVersion) { 532 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer 533 << " and >= " << MinVersion << "\n"); 534 report_fatal_error("Adding SPIR-V requirements that can't be satisfied."); 535 } 536 537 if (MaxVersion.empty() || Req.MaxVer < MaxVersion) 538 MaxVersion = Req.MaxVer; 539 } 540 } 541 542 void SPIRV::RequirementHandler::checkSatisfiable( 543 const SPIRVSubtarget &ST) const { 544 // Report as many errors as possible before aborting the compilation. 
545 bool IsSatisfiable = true; 546 auto TargetVer = ST.getSPIRVVersion(); 547 548 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) { 549 LLVM_DEBUG( 550 dbgs() << "Target SPIR-V version too high for required features\n" 551 << "Required max version: " << MaxVersion << " target version " 552 << TargetVer << "\n"); 553 IsSatisfiable = false; 554 } 555 556 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) { 557 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n" 558 << "Required min version: " << MinVersion 559 << " target version " << TargetVer << "\n"); 560 IsSatisfiable = false; 561 } 562 563 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) { 564 LLVM_DEBUG( 565 dbgs() 566 << "Version is too low for some features and too high for others.\n" 567 << "Required SPIR-V min version: " << MinVersion 568 << " required SPIR-V max version " << MaxVersion << "\n"); 569 IsSatisfiable = false; 570 } 571 572 for (auto Cap : MinimalCaps) { 573 if (AvailableCaps.contains(Cap)) 574 continue; 575 LLVM_DEBUG(dbgs() << "Capability not supported: " 576 << getSymbolicOperandMnemonic( 577 OperandCategory::CapabilityOperand, Cap) 578 << "\n"); 579 IsSatisfiable = false; 580 } 581 582 for (auto Ext : AllExtensions) { 583 if (ST.canUseExtension(Ext)) 584 continue; 585 LLVM_DEBUG(dbgs() << "Extension not supported: " 586 << getSymbolicOperandMnemonic( 587 OperandCategory::ExtensionOperand, Ext) 588 << "\n"); 589 IsSatisfiable = false; 590 } 591 592 if (!IsSatisfiable) 593 report_fatal_error("Unable to meet SPIR-V requirements for this target."); 594 } 595 596 // Add the given capabilities and all their implicitly defined capabilities too. 597 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) { 598 for (const auto Cap : ToAdd) 599 if (AvailableCaps.insert(Cap).second) 600 addAvailableCaps(getSymbolicOperandCapabilities( 601 SPIRV::OperandCategory::CapabilityOperand, Cap)); 602 } 603 604 void SPIRV::RequirementHandler::removeCapabilityIf( 605 const Capability::Capability ToRemove, 606 const Capability::Capability IfPresent) { 607 if (AllCaps.contains(IfPresent)) 608 AllCaps.erase(ToRemove); 609 } 610 611 namespace llvm { 612 namespace SPIRV { 613 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) { 614 if (ST.isOpenCLEnv()) { 615 initAvailableCapabilitiesForOpenCL(ST); 616 return; 617 } 618 619 if (ST.isVulkanEnv()) { 620 initAvailableCapabilitiesForVulkan(ST); 621 return; 622 } 623 624 report_fatal_error("Unimplemented environment for SPIR-V generation."); 625 } 626 627 void RequirementHandler::initAvailableCapabilitiesForOpenCL( 628 const SPIRVSubtarget &ST) { 629 // Add the min requirements for different OpenCL and SPIR-V versions. 
630 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer, 631 Capability::Int16, Capability::Int8, Capability::Kernel, 632 Capability::Linkage, Capability::Vector16, 633 Capability::Groups, Capability::GenericPointer, 634 Capability::Shader}); 635 if (ST.hasOpenCLFullProfile()) 636 addAvailableCaps({Capability::Int64, Capability::Int64Atomics}); 637 if (ST.hasOpenCLImageSupport()) { 638 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler, 639 Capability::Image1D, Capability::SampledBuffer, 640 Capability::ImageBuffer}); 641 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0))) 642 addAvailableCaps({Capability::ImageReadWrite}); 643 } 644 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) && 645 ST.isAtLeastOpenCLVer(VersionTuple(2, 2))) 646 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage}); 647 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3))) 648 addAvailableCaps({Capability::GroupNonUniform, 649 Capability::GroupNonUniformVote, 650 Capability::GroupNonUniformArithmetic, 651 Capability::GroupNonUniformBallot, 652 Capability::GroupNonUniformClustered, 653 Capability::GroupNonUniformShuffle, 654 Capability::GroupNonUniformShuffleRelative}); 655 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4))) 656 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero, 657 Capability::SignedZeroInfNanPreserve, 658 Capability::RoundingModeRTE, 659 Capability::RoundingModeRTZ}); 660 // TODO: verify if this needs some checks. 661 addAvailableCaps({Capability::Float16, Capability::Float64}); 662 663 // Add capabilities enabled by extensions. 664 for (auto Extension : ST.getAllAvailableExtensions()) { 665 CapabilityList EnabledCapabilities = 666 getCapabilitiesEnabledByExtension(Extension); 667 addAvailableCaps(EnabledCapabilities); 668 } 669 670 // TODO: add OpenCL extensions. 671 } 672 673 void RequirementHandler::initAvailableCapabilitiesForVulkan( 674 const SPIRVSubtarget &ST) { 675 addAvailableCaps({Capability::Shader, Capability::Linkage}); 676 677 // Provided by all supported Vulkan versions. 678 addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16, 679 Capability::Float64, Capability::GroupNonUniform}); 680 } 681 682 } // namespace SPIRV 683 } // namespace llvm 684 685 // Add the required capabilities from a decoration instruction (including 686 // BuiltIns). 
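// For instance (SPIR-V assembly, illustrative):
//   OpDecorate %id BuiltIn GlobalInvocationId
//     -> adds the requirements of both the Decoration and the BuiltIn operand;
//   OpDecorate %f LinkageAttributes "foo" LinkOnceODR
//     -> additionally requires the SPV_KHR_linkonce_odr extension.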
687 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex, 688 SPIRV::RequirementHandler &Reqs, 689 const SPIRVSubtarget &ST) { 690 int64_t DecOp = MI.getOperand(DecIndex).getImm(); 691 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp); 692 Reqs.addRequirements(getSymbolicOperandRequirements( 693 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs)); 694 695 if (Dec == SPIRV::Decoration::BuiltIn) { 696 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm(); 697 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp); 698 Reqs.addRequirements(getSymbolicOperandRequirements( 699 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs)); 700 } else if (Dec == SPIRV::Decoration::LinkageAttributes) { 701 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm(); 702 SPIRV::LinkageType::LinkageType LnkType = 703 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp); 704 if (LnkType == SPIRV::LinkageType::LinkOnceODR) 705 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr); 706 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL || 707 Dec == SPIRV::Decoration::CacheControlStoreINTEL) { 708 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls); 709 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) { 710 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access); 711 } else if (Dec == SPIRV::Decoration::InitModeINTEL || 712 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) { 713 Reqs.addExtension( 714 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations); 715 } 716 } 717 718 // Add requirements for image handling. 719 static void addOpTypeImageReqs(const MachineInstr &MI, 720 SPIRV::RequirementHandler &Reqs, 721 const SPIRVSubtarget &ST) { 722 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage"); 723 // The operand indices used here are based on the OpTypeImage layout, which 724 // the MachineInstr follows as well. 725 int64_t ImgFormatOp = MI.getOperand(7).getImm(); 726 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp); 727 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand, 728 ImgFormat, ST); 729 730 bool IsArrayed = MI.getOperand(4).getImm() == 1; 731 bool IsMultisampled = MI.getOperand(5).getImm() == 1; 732 bool NoSampler = MI.getOperand(6).getImm() == 2; 733 // Add dimension requirements. 734 assert(MI.getOperand(2).isImm()); 735 switch (MI.getOperand(2).getImm()) { 736 case SPIRV::Dim::DIM_1D: 737 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D 738 : SPIRV::Capability::Sampled1D); 739 break; 740 case SPIRV::Dim::DIM_2D: 741 if (IsMultisampled && NoSampler) 742 Reqs.addRequirements(SPIRV::Capability::ImageMSArray); 743 break; 744 case SPIRV::Dim::DIM_Cube: 745 Reqs.addRequirements(SPIRV::Capability::Shader); 746 if (IsArrayed) 747 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray 748 : SPIRV::Capability::SampledCubeArray); 749 break; 750 case SPIRV::Dim::DIM_Rect: 751 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect 752 : SPIRV::Capability::SampledRect); 753 break; 754 case SPIRV::Dim::DIM_Buffer: 755 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer 756 : SPIRV::Capability::SampledBuffer); 757 break; 758 case SPIRV::Dim::DIM_SubpassData: 759 Reqs.addRequirements(SPIRV::Capability::InputAttachment); 760 break; 761 } 762 763 // Has optional access qualifier. 764 // TODO: check if it's OpenCL's kernel. 
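// A kernel image argument declared with __read_write access (OpenCL 2.0+)
// carries an explicit AccessQualifier operand and needs the ImageReadWrite
// capability; read-only and write-only images only need ImageBasic.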
765 if (MI.getNumOperands() > 8 && 766 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite) 767 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite); 768 else 769 Reqs.addRequirements(SPIRV::Capability::ImageBasic); 770 } 771 772 // Add requirements for handling atomic float instructions 773 #define ATOM_FLT_REQ_EXT_MSG(ExtName) \ 774 "The atomic float instruction requires the following SPIR-V " \ 775 "extension: SPV_EXT_shader_atomic_float" ExtName 776 static void AddAtomicFloatRequirements(const MachineInstr &MI, 777 SPIRV::RequirementHandler &Reqs, 778 const SPIRVSubtarget &ST) { 779 assert(MI.getOperand(1).isReg() && 780 "Expect register operand in atomic float instruction"); 781 Register TypeReg = MI.getOperand(1).getReg(); 782 SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg); 783 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat) 784 report_fatal_error("Result type of an atomic float instruction must be a " 785 "floating-point type scalar"); 786 787 unsigned BitWidth = TypeDef->getOperand(1).getImm(); 788 unsigned Op = MI.getOpcode(); 789 if (Op == SPIRV::OpAtomicFAddEXT) { 790 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add)) 791 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false); 792 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add); 793 switch (BitWidth) { 794 case 16: 795 if (!ST.canUseExtension( 796 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add)) 797 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false); 798 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); 799 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); 800 break; 801 case 32: 802 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT); 803 break; 804 case 64: 805 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT); 806 break; 807 default: 808 report_fatal_error( 809 "Unexpected floating-point type width in atomic float instruction"); 810 } 811 } else { 812 if (!ST.canUseExtension( 813 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max)) 814 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false); 815 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max); 816 switch (BitWidth) { 817 case 16: 818 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT); 819 break; 820 case 32: 821 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT); 822 break; 823 case 64: 824 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT); 825 break; 826 default: 827 report_fatal_error( 828 "Unexpected floating-point type width in atomic float instruction"); 829 } 830 } 831 } 832 833 void addInstrRequirements(const MachineInstr &MI, 834 SPIRV::RequirementHandler &Reqs, 835 const SPIRVSubtarget &ST) { 836 switch (MI.getOpcode()) { 837 case SPIRV::OpMemoryModel: { 838 int64_t Addr = MI.getOperand(0).getImm(); 839 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand, 840 Addr, ST); 841 int64_t Mem = MI.getOperand(1).getImm(); 842 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem, 843 ST); 844 break; 845 } 846 case SPIRV::OpEntryPoint: { 847 int64_t Exe = MI.getOperand(0).getImm(); 848 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand, 849 Exe, ST); 850 break; 851 } 852 case SPIRV::OpExecutionMode: 853 case SPIRV::OpExecutionModeId: { 854 int64_t Exe = MI.getOperand(1).getImm(); 855 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand, 856 Exe, ST); 857 break; 858 } 859 case 
SPIRV::OpTypeMatrix: 860 Reqs.addCapability(SPIRV::Capability::Matrix); 861 break; 862 case SPIRV::OpTypeInt: { 863 unsigned BitWidth = MI.getOperand(1).getImm(); 864 if (BitWidth == 64) 865 Reqs.addCapability(SPIRV::Capability::Int64); 866 else if (BitWidth == 16) 867 Reqs.addCapability(SPIRV::Capability::Int16); 868 else if (BitWidth == 8) 869 Reqs.addCapability(SPIRV::Capability::Int8); 870 break; 871 } 872 case SPIRV::OpTypeFloat: { 873 unsigned BitWidth = MI.getOperand(1).getImm(); 874 if (BitWidth == 64) 875 Reqs.addCapability(SPIRV::Capability::Float64); 876 else if (BitWidth == 16) 877 Reqs.addCapability(SPIRV::Capability::Float16); 878 break; 879 } 880 case SPIRV::OpTypeVector: { 881 unsigned NumComponents = MI.getOperand(2).getImm(); 882 if (NumComponents == 8 || NumComponents == 16) 883 Reqs.addCapability(SPIRV::Capability::Vector16); 884 break; 885 } 886 case SPIRV::OpTypePointer: { 887 auto SC = MI.getOperand(1).getImm(); 888 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC, 889 ST); 890 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer 891 // capability. 892 if (!ST.isOpenCLEnv()) 893 break; 894 assert(MI.getOperand(2).isReg()); 895 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 896 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg()); 897 if (TypeDef->getOpcode() == SPIRV::OpTypeFloat && 898 TypeDef->getOperand(1).getImm() == 16) 899 Reqs.addCapability(SPIRV::Capability::Float16Buffer); 900 break; 901 } 902 case SPIRV::OpBitReverse: 903 case SPIRV::OpBitFieldInsert: 904 case SPIRV::OpBitFieldSExtract: 905 case SPIRV::OpBitFieldUExtract: 906 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) { 907 Reqs.addCapability(SPIRV::Capability::Shader); 908 break; 909 } 910 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions); 911 Reqs.addCapability(SPIRV::Capability::BitInstructions); 912 break; 913 case SPIRV::OpTypeRuntimeArray: 914 Reqs.addCapability(SPIRV::Capability::Shader); 915 break; 916 case SPIRV::OpTypeOpaque: 917 case SPIRV::OpTypeEvent: 918 Reqs.addCapability(SPIRV::Capability::Kernel); 919 break; 920 case SPIRV::OpTypePipe: 921 case SPIRV::OpTypeReserveId: 922 Reqs.addCapability(SPIRV::Capability::Pipes); 923 break; 924 case SPIRV::OpTypeDeviceEvent: 925 case SPIRV::OpTypeQueue: 926 case SPIRV::OpBuildNDRange: 927 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue); 928 break; 929 case SPIRV::OpDecorate: 930 case SPIRV::OpDecorateId: 931 case SPIRV::OpDecorateString: 932 addOpDecorateReqs(MI, 1, Reqs, ST); 933 break; 934 case SPIRV::OpMemberDecorate: 935 case SPIRV::OpMemberDecorateString: 936 addOpDecorateReqs(MI, 2, Reqs, ST); 937 break; 938 case SPIRV::OpInBoundsPtrAccessChain: 939 Reqs.addCapability(SPIRV::Capability::Addresses); 940 break; 941 case SPIRV::OpConstantSampler: 942 Reqs.addCapability(SPIRV::Capability::LiteralSampler); 943 break; 944 case SPIRV::OpTypeImage: 945 addOpTypeImageReqs(MI, Reqs, ST); 946 break; 947 case SPIRV::OpTypeSampler: 948 Reqs.addCapability(SPIRV::Capability::ImageBasic); 949 break; 950 case SPIRV::OpTypeForwardPointer: 951 // TODO: check if it's OpenCL's kernel. 
952 Reqs.addCapability(SPIRV::Capability::Addresses); 953 break; 954 case SPIRV::OpAtomicFlagTestAndSet: 955 case SPIRV::OpAtomicLoad: 956 case SPIRV::OpAtomicStore: 957 case SPIRV::OpAtomicExchange: 958 case SPIRV::OpAtomicCompareExchange: 959 case SPIRV::OpAtomicIIncrement: 960 case SPIRV::OpAtomicIDecrement: 961 case SPIRV::OpAtomicIAdd: 962 case SPIRV::OpAtomicISub: 963 case SPIRV::OpAtomicUMin: 964 case SPIRV::OpAtomicUMax: 965 case SPIRV::OpAtomicSMin: 966 case SPIRV::OpAtomicSMax: 967 case SPIRV::OpAtomicAnd: 968 case SPIRV::OpAtomicOr: 969 case SPIRV::OpAtomicXor: { 970 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 971 const MachineInstr *InstrPtr = &MI; 972 if (MI.getOpcode() == SPIRV::OpAtomicStore) { 973 assert(MI.getOperand(3).isReg()); 974 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg()); 975 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore"); 976 } 977 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic"); 978 Register TypeReg = InstrPtr->getOperand(1).getReg(); 979 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg); 980 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) { 981 unsigned BitWidth = TypeDef->getOperand(1).getImm(); 982 if (BitWidth == 64) 983 Reqs.addCapability(SPIRV::Capability::Int64Atomics); 984 } 985 break; 986 } 987 case SPIRV::OpGroupNonUniformIAdd: 988 case SPIRV::OpGroupNonUniformFAdd: 989 case SPIRV::OpGroupNonUniformIMul: 990 case SPIRV::OpGroupNonUniformFMul: 991 case SPIRV::OpGroupNonUniformSMin: 992 case SPIRV::OpGroupNonUniformUMin: 993 case SPIRV::OpGroupNonUniformFMin: 994 case SPIRV::OpGroupNonUniformSMax: 995 case SPIRV::OpGroupNonUniformUMax: 996 case SPIRV::OpGroupNonUniformFMax: 997 case SPIRV::OpGroupNonUniformBitwiseAnd: 998 case SPIRV::OpGroupNonUniformBitwiseOr: 999 case SPIRV::OpGroupNonUniformBitwiseXor: 1000 case SPIRV::OpGroupNonUniformLogicalAnd: 1001 case SPIRV::OpGroupNonUniformLogicalOr: 1002 case SPIRV::OpGroupNonUniformLogicalXor: { 1003 assert(MI.getOperand(3).isImm()); 1004 int64_t GroupOp = MI.getOperand(3).getImm(); 1005 switch (GroupOp) { 1006 case SPIRV::GroupOperation::Reduce: 1007 case SPIRV::GroupOperation::InclusiveScan: 1008 case SPIRV::GroupOperation::ExclusiveScan: 1009 Reqs.addCapability(SPIRV::Capability::Kernel); 1010 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic); 1011 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot); 1012 break; 1013 case SPIRV::GroupOperation::ClusteredReduce: 1014 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered); 1015 break; 1016 case SPIRV::GroupOperation::PartitionedReduceNV: 1017 case SPIRV::GroupOperation::PartitionedInclusiveScanNV: 1018 case SPIRV::GroupOperation::PartitionedExclusiveScanNV: 1019 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV); 1020 break; 1021 } 1022 break; 1023 } 1024 case SPIRV::OpGroupNonUniformShuffle: 1025 case SPIRV::OpGroupNonUniformShuffleXor: 1026 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle); 1027 break; 1028 case SPIRV::OpGroupNonUniformShuffleUp: 1029 case SPIRV::OpGroupNonUniformShuffleDown: 1030 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative); 1031 break; 1032 case SPIRV::OpGroupAll: 1033 case SPIRV::OpGroupAny: 1034 case SPIRV::OpGroupBroadcast: 1035 case SPIRV::OpGroupIAdd: 1036 case SPIRV::OpGroupFAdd: 1037 case SPIRV::OpGroupFMin: 1038 case SPIRV::OpGroupUMin: 1039 case SPIRV::OpGroupSMin: 1040 case SPIRV::OpGroupFMax: 1041 case SPIRV::OpGroupUMax: 1042 case SPIRV::OpGroupSMax: 1043 
Reqs.addCapability(SPIRV::Capability::Groups); 1044 break; 1045 case SPIRV::OpGroupNonUniformElect: 1046 Reqs.addCapability(SPIRV::Capability::GroupNonUniform); 1047 break; 1048 case SPIRV::OpGroupNonUniformAll: 1049 case SPIRV::OpGroupNonUniformAny: 1050 case SPIRV::OpGroupNonUniformAllEqual: 1051 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote); 1052 break; 1053 case SPIRV::OpGroupNonUniformBroadcast: 1054 case SPIRV::OpGroupNonUniformBroadcastFirst: 1055 case SPIRV::OpGroupNonUniformBallot: 1056 case SPIRV::OpGroupNonUniformInverseBallot: 1057 case SPIRV::OpGroupNonUniformBallotBitExtract: 1058 case SPIRV::OpGroupNonUniformBallotBitCount: 1059 case SPIRV::OpGroupNonUniformBallotFindLSB: 1060 case SPIRV::OpGroupNonUniformBallotFindMSB: 1061 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot); 1062 break; 1063 case SPIRV::OpSubgroupShuffleINTEL: 1064 case SPIRV::OpSubgroupShuffleDownINTEL: 1065 case SPIRV::OpSubgroupShuffleUpINTEL: 1066 case SPIRV::OpSubgroupShuffleXorINTEL: 1067 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) { 1068 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups); 1069 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL); 1070 } 1071 break; 1072 case SPIRV::OpSubgroupBlockReadINTEL: 1073 case SPIRV::OpSubgroupBlockWriteINTEL: 1074 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) { 1075 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups); 1076 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL); 1077 } 1078 break; 1079 case SPIRV::OpSubgroupImageBlockReadINTEL: 1080 case SPIRV::OpSubgroupImageBlockWriteINTEL: 1081 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) { 1082 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups); 1083 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL); 1084 } 1085 break; 1086 case SPIRV::OpAssumeTrueKHR: 1087 case SPIRV::OpExpectKHR: 1088 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) { 1089 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume); 1090 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR); 1091 } 1092 break; 1093 case SPIRV::OpPtrCastToCrossWorkgroupINTEL: 1094 case SPIRV::OpCrossWorkgroupCastToPtrINTEL: 1095 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) { 1096 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes); 1097 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL); 1098 } 1099 break; 1100 case SPIRV::OpConstantFunctionPointerINTEL: 1101 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) { 1102 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers); 1103 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL); 1104 } 1105 break; 1106 case SPIRV::OpGroupNonUniformRotateKHR: 1107 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate)) 1108 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the " 1109 "following SPIR-V extension: SPV_KHR_subgroup_rotate", 1110 false); 1111 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate); 1112 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR); 1113 Reqs.addCapability(SPIRV::Capability::GroupNonUniform); 1114 break; 1115 case SPIRV::OpGroupIMulKHR: 1116 case SPIRV::OpGroupFMulKHR: 1117 case SPIRV::OpGroupBitwiseAndKHR: 1118 case SPIRV::OpGroupBitwiseOrKHR: 1119 case SPIRV::OpGroupBitwiseXorKHR: 1120 case SPIRV::OpGroupLogicalAndKHR: 1121 case SPIRV::OpGroupLogicalOrKHR: 1122 case SPIRV::OpGroupLogicalXorKHR: 1123 if 
(ST.canUseExtension( 1124 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) { 1125 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions); 1126 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR); 1127 } 1128 break; 1129 case SPIRV::OpReadClockKHR: 1130 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock)) 1131 report_fatal_error("OpReadClockKHR instruction requires the " 1132 "following SPIR-V extension: SPV_KHR_shader_clock", 1133 false); 1134 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock); 1135 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR); 1136 break; 1137 case SPIRV::OpFunctionPointerCallINTEL: 1138 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) { 1139 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers); 1140 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL); 1141 } 1142 break; 1143 case SPIRV::OpAtomicFAddEXT: 1144 case SPIRV::OpAtomicFMinEXT: 1145 case SPIRV::OpAtomicFMaxEXT: 1146 AddAtomicFloatRequirements(MI, Reqs, ST); 1147 break; 1148 case SPIRV::OpConvertBF16ToFINTEL: 1149 case SPIRV::OpConvertFToBF16INTEL: 1150 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) { 1151 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion); 1152 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL); 1153 } 1154 break; 1155 case SPIRV::OpVariableLengthArrayINTEL: 1156 case SPIRV::OpSaveMemoryINTEL: 1157 case SPIRV::OpRestoreMemoryINTEL: 1158 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) { 1159 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array); 1160 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL); 1161 } 1162 break; 1163 case SPIRV::OpAsmTargetINTEL: 1164 case SPIRV::OpAsmINTEL: 1165 case SPIRV::OpAsmCallINTEL: 1166 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) { 1167 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly); 1168 Reqs.addCapability(SPIRV::Capability::AsmINTEL); 1169 } 1170 break; 1171 case SPIRV::OpTypeCooperativeMatrixKHR: 1172 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix)) 1173 report_fatal_error( 1174 "OpTypeCooperativeMatrixKHR type requires the " 1175 "following SPIR-V extension: SPV_KHR_cooperative_matrix", 1176 false); 1177 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix); 1178 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR); 1179 break; 1180 default: 1181 break; 1182 } 1183 1184 // If we require capability Shader, then we can remove the requirement for 1185 // the BitInstructions capability, since Shader is a superset capability 1186 // of BitInstructions. 1187 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions, 1188 SPIRV::Capability::Shader); 1189 } 1190 1191 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, 1192 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) { 1193 // Collect requirements for existing instructions. 1194 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 1195 MachineFunction *MF = MMI->getMachineFunction(*F); 1196 if (!MF) 1197 continue; 1198 for (const MachineBasicBlock &MBB : *MF) 1199 for (const MachineInstr &MI : MBB) 1200 addInstrRequirements(MI, MAI.Reqs, ST); 1201 } 1202 // Collect requirements for OpExecutionMode instructions. 
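// Each "spirv.ExecutionMode" metadata operand is expected to look roughly like
//   !{ptr @kernel, i32 <ExecutionMode value>, ...}
// (illustrative); only operand 1, the raw ExecutionMode value, is inspected here
// and checked against the float-controls modes below.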
1203 auto Node = M.getNamedMetadata("spirv.ExecutionMode"); 1204 if (Node) { 1205 // SPV_KHR_float_controls is not available until v1.4 1206 bool RequireFloatControls = false, 1207 VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4)); 1208 for (unsigned i = 0; i < Node->getNumOperands(); i++) { 1209 MDNode *MDN = cast<MDNode>(Node->getOperand(i)); 1210 const MDOperand &MDOp = MDN->getOperand(1); 1211 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) { 1212 Constant *C = CMeta->getValue(); 1213 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) { 1214 auto EM = Const->getZExtValue(); 1215 MAI.Reqs.getAndAddRequirements( 1216 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); 1217 // add SPV_KHR_float_controls if the version is too low 1218 switch (EM) { 1219 case SPIRV::ExecutionMode::DenormPreserve: 1220 case SPIRV::ExecutionMode::DenormFlushToZero: 1221 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: 1222 case SPIRV::ExecutionMode::RoundingModeRTE: 1223 case SPIRV::ExecutionMode::RoundingModeRTZ: 1224 RequireFloatControls = VerLower14; 1225 break; 1226 } 1227 } 1228 } 1229 } 1230 if (RequireFloatControls && 1231 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls)) 1232 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls); 1233 } 1234 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { 1235 const Function &F = *FI; 1236 if (F.isDeclaration()) 1237 continue; 1238 if (F.getMetadata("reqd_work_group_size")) 1239 MAI.Reqs.getAndAddRequirements( 1240 SPIRV::OperandCategory::ExecutionModeOperand, 1241 SPIRV::ExecutionMode::LocalSize, ST); 1242 if (F.getFnAttribute("hlsl.numthreads").isValid()) { 1243 MAI.Reqs.getAndAddRequirements( 1244 SPIRV::OperandCategory::ExecutionModeOperand, 1245 SPIRV::ExecutionMode::LocalSize, ST); 1246 } 1247 if (F.getMetadata("work_group_size_hint")) 1248 MAI.Reqs.getAndAddRequirements( 1249 SPIRV::OperandCategory::ExecutionModeOperand, 1250 SPIRV::ExecutionMode::LocalSizeHint, ST); 1251 if (F.getMetadata("intel_reqd_sub_group_size")) 1252 MAI.Reqs.getAndAddRequirements( 1253 SPIRV::OperandCategory::ExecutionModeOperand, 1254 SPIRV::ExecutionMode::SubgroupSize, ST); 1255 if (F.getMetadata("vec_type_hint")) 1256 MAI.Reqs.getAndAddRequirements( 1257 SPIRV::OperandCategory::ExecutionModeOperand, 1258 SPIRV::ExecutionMode::VecTypeHint, ST); 1259 1260 if (F.hasOptNone() && 1261 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) { 1262 // Output OpCapability OptNoneINTEL. 
1263 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone); 1264 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL); 1265 } 1266 } 1267 } 1268 1269 static unsigned getFastMathFlags(const MachineInstr &I) { 1270 unsigned Flags = SPIRV::FPFastMathMode::None; 1271 if (I.getFlag(MachineInstr::MIFlag::FmNoNans)) 1272 Flags |= SPIRV::FPFastMathMode::NotNaN; 1273 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs)) 1274 Flags |= SPIRV::FPFastMathMode::NotInf; 1275 if (I.getFlag(MachineInstr::MIFlag::FmNsz)) 1276 Flags |= SPIRV::FPFastMathMode::NSZ; 1277 if (I.getFlag(MachineInstr::MIFlag::FmArcp)) 1278 Flags |= SPIRV::FPFastMathMode::AllowRecip; 1279 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) 1280 Flags |= SPIRV::FPFastMathMode::Fast; 1281 return Flags; 1282 } 1283 1284 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, 1285 const SPIRVInstrInfo &TII, 1286 SPIRV::RequirementHandler &Reqs) { 1287 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) && 1288 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand, 1289 SPIRV::Decoration::NoSignedWrap, ST, Reqs) 1290 .IsSatisfiable) { 1291 buildOpDecorate(I.getOperand(0).getReg(), I, TII, 1292 SPIRV::Decoration::NoSignedWrap, {}); 1293 } 1294 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) && 1295 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand, 1296 SPIRV::Decoration::NoUnsignedWrap, ST, 1297 Reqs) 1298 .IsSatisfiable) { 1299 buildOpDecorate(I.getOperand(0).getReg(), I, TII, 1300 SPIRV::Decoration::NoUnsignedWrap, {}); 1301 } 1302 if (!TII.canUseFastMathFlags(I)) 1303 return; 1304 unsigned FMFlags = getFastMathFlags(I); 1305 if (FMFlags == SPIRV::FPFastMathMode::None) 1306 return; 1307 Register DstReg = I.getOperand(0).getReg(); 1308 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags}); 1309 } 1310 1311 // Walk all functions and add decorations related to MI flags. 1312 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII, 1313 MachineModuleInfo *MMI, const SPIRVSubtarget &ST, 1314 SPIRV::ModuleAnalysisInfo &MAI) { 1315 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 1316 MachineFunction *MF = MMI->getMachineFunction(*F); 1317 if (!MF) 1318 continue; 1319 for (auto &MBB : *MF) 1320 for (auto &MI : MBB) 1321 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs); 1322 } 1323 } 1324 1325 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI; 1326 1327 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { 1328 AU.addRequired<TargetPassConfig>(); 1329 AU.addRequired<MachineModuleInfoWrapperPass>(); 1330 } 1331 1332 bool SPIRVModuleAnalysis::runOnModule(Module &M) { 1333 SPIRVTargetMachine &TM = 1334 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>(); 1335 ST = TM.getSubtargetImpl(); 1336 GR = ST->getSPIRVGlobalRegistry(); 1337 TII = ST->getInstrInfo(); 1338 1339 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI(); 1340 1341 setBaseInfo(M); 1342 1343 addDecorations(M, *TII, MMI, *ST, MAI); 1344 1345 collectReqs(M, MAI, MMI, *ST); 1346 1347 // Process type/const/global var/func decl instructions, number their 1348 // destination registers from 0 to N, collect Extensions and Capabilities. 1349 processDefInstrs(M); 1350 1351 // Number rest of registers from N+1 onwards. 
1352 numberRegistersGlobally(M); 1353 1354 // Update references to OpFunction instructions to use Global Registers 1355 if (GR->hasConstFunPtr()) 1356 collectFuncPtrs(); 1357 1358 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. 1359 processOtherInstrs(M); 1360 1361 // If there are no entry points, we need the Linkage capability. 1362 if (MAI.MS[SPIRV::MB_EntryPoints].empty()) 1363 MAI.Reqs.addCapability(SPIRV::Capability::Linkage); 1364 1365 // Set maximum ID used. 1366 GR->setBound(MAI.MaxID); 1367 1368 return false; 1369 } 1370