//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  unsigned TargetVer = ST.getSPIRVVersion();
  bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
  bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, 0, 0};
    }
  } else if (MinVerOK && MaxVerOK) {
    for (auto Cap : ReqCaps) { // Only need 1 of the capabilities to work.
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, 0, 0};
}
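
// NOTE: the version numbers compared above are assumed to be encoded as
// Major * 10 + Minor (e.g. 13 for SPIR-V 1.3), matching the
// isAtLeastSPIRVVer() checks further down, so plain integer comparison is
// sufficient.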

void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    MAI.Mem = SPIRV::MemoryModel::OpenCL;
    unsigned PtrSize = ST->getPointerSize();
    MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
               : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                               : SPIRV::AddressingModel::Logical;
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
  } else {
    MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
    MAI.SrcLangVersion = 0;
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  // TODO: check if it's required by default.
  MAI.ExtInstSetMap[static_cast<unsigned>(SPIRV::InstructionSet::OpenCL_std)] =
      Register::index2VirtReg(MAI.getNextID());
}
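
// Illustrative (assumed) IR metadata consumed by setBaseInfo() above:
//   !spirv.MemoryModel = !{!0}
//   !0 = !{i32 2, i32 2}  ; addressing model Physical64, memory model OpenCL
//   !opencl.ocl.version = !{!1}
//   !1 = !{i32 2, i32 0}  ; OpenCL C 2.0 -> SrcLangVersion = 200000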

// Collect MI which defines the register in the given machine function.
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer the recursive approach over the iterative one
    // because we don't expect dependency chains long enough to cause a stack
    // overflow.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing the dependency graph in post-order lets us avoid a separate
      // register-alias preprocessing step, but pre-order is required to
      // correctly process function declarations and their arguments.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}
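
// For example (illustrative): an OpConstant entry depends on its OpTypeInt
// entry, so the post-order walk above numbers and collects the type before
// the constant that uses it, matching SPIR-V's declaration-before-use rule.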

// The function initializes the global register alias table for types, consts,
// global vars and func decls, and collects these instructions for output at
// the module level. It also collects explicit OpExtension/OpCapability
// instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// True if there is an instruction in the MS list with all the same operands as
// the given instruction has (after the given starting index).
// TODO: maybe it needs to check Opcodes too.
static bool findSameInstrInMS(const MachineInstr &A,
                              SPIRV::ModuleSectionType MSType,
                              SPIRV::ModuleAnalysisInfo &MAI,
                              unsigned StartOpIndex = 0) {
  for (const auto *B : MAI.MS[MSType]) {
    const unsigned NumAOps = A.getNumOperands();
    if (NumAOps != B->getNumOperands() || A.getNumDefs() != B->getNumDefs())
      continue;
    bool AllOpsMatch = true;
    for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) {
      if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) {
        Register RegA = A.getOperand(i).getReg();
        Register RegB = B->getOperand(i).getReg();
        AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) ==
                      MAI.getRegisterAlias(B->getMF(), RegB);
      } else {
        AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i));
      }
    }
    if (AllOpsMatch)
      return true;
  }
  return false;
}

// Look for IDs declared with Import linkage, and map the corresponding
// function to the register defining that ID (which will usually be the result
// of an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map imported function name to function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point, so we can directly compare
// register arguments when detecting duplicates.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  if (findSameInstrInMS(MI, MSType, MAI))
    return; // Found a duplicate, so don't add it.
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(&MI);
  else
    MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
}

// Some global instructions make reference to function-local ID regs, so they
// cannot be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // Currently, OpSpecConstant* instructions are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, false);
        }
      }
  }
}

// Number registers in all functions globally from 0 onwards and store
// the result in the global register alias table. Some registers are already
// numbered in collectGlobalEntities.
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        auto Set = MI.getOperand(2).getImm();
        if (MAI.ExtInstSetMap.find(Set) == MAI.ExtInstSetMap.end())
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

void SPIRV::RequirementHandler::pruneCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    auto FoundIndex = std::find(MinimalCaps.begin(), MinimalCaps.end(), Cap);
    if (FoundIndex != MinimalCaps.end())
      MinimalCaps.erase(FoundIndex);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    pruneCapabilities(ImplicitDecls);
  }
}
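
// Note on the interplay of addCapabilities() and pruneCapabilities(): when a
// capability is added, the capabilities it implicitly declares are removed
// from MinimalCaps, since they will be declared anyway. For example (per the
// SPIR-V capability dependency graph), adding Capability::Int64Atomics prunes
// Capability::Int64 from MinimalCaps, while AllCaps still records both.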

void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    pruneCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (Req.MinVer) {
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && TargetVer && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}

// Add the given capabilities and all their implicitly defined capabilities
// too.
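// For example (per the SPIR-V capability dependency graph), making
// Capability::Shader available also makes Capability::Matrix available,
// since Shader implicitly declares Matrix.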
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

namespace llvm {
namespace SPIRV {
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  // TODO: implement this for targets other than OpenCL.
  if (!ST.isOpenCLEnv())
    return;
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Int16, Capability::Int8, Capability::Kernel,
                    Capability::Linkage, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::Shader});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(20))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(13))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});
  if (ST.isAtLeastSPIRVVer(14))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}
} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  int64_t DecOp = MI.getOperand(DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(getSymbolicOperandRequirements(
      SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(getSymbolicOperandRequirements(
        SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
  }
}

// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
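  // For reference, the operand indices assumed below follow the OpTypeImage
  // word order: 0 - result id, 1 - sampled type, 2 - Dim, 3 - Depth,
  // 4 - Arrayed, 5 - MS, 6 - Sampled, 7 - Image Format,
  // 8 (optional) - Access Qualifier.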
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}

void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a pointer to a float16 type, add the Float16Buffer capability.
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    Reqs.addCapability(SPIRV::Capability::ImageBasic);
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  default:
    break;
  }
}

static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI.Reqs, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Node->getOperand(i));
      const MDOperand &MDOp = MDN->getOperand(1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
          auto EM = Const->getZExtValue();
          MAI.Reqs.getAndAddRequirements(
              SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
        }
      }
    }
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    if (F.isDeclaration())
      continue;
    if (F.getMetadata("reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getMetadata("work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata("intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata("vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::VecTypeHint, ST);
  }
}

static unsigned getFastMathFlags(const MachineInstr &I) {
  unsigned Flags = SPIRV::FPFastMathMode::None;
  if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
    Flags |= SPIRV::FPFastMathMode::NotNaN;
  if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
    Flags |= SPIRV::FPFastMathMode::NotInf;
  if (I.getFlag(MachineInstr::MIFlag::FmNsz))
    Flags |= SPIRV::FPFastMathMode::NSZ;
  if (I.getFlag(MachineInstr::MIFlag::FmArcp))
    Flags |= SPIRV::FPFastMathMode::AllowRecip;
  if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
    Flags |= SPIRV::FPFastMathMode::Fast;
  return Flags;
}
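
// For example, an instruction carrying the nnan and ninf MI flags yields
// NotNaN | NotInf (0x1 | 0x2 = 0x3, assuming the standard SPIR-V
// FPFastMathMode literal values), which handleMIFlagDecoration() below emits
// as the literal of an FPFastMathMode decoration.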

static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
                                   const SPIRVInstrInfo &TII,
                                   SPIRV::RequirementHandler &Reqs) {
  if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoSignedWrap, {});
  }
  if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoUnsignedWrap, {});
  }
  if (!TII.canUseFastMathFlags(I))
    return;
  unsigned FMFlags = getFastMathFlags(I);
  if (FMFlags == SPIRV::FPFastMathMode::None)
    return;
  Register DstReg = I.getOperand(0).getReg();
  buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
}

// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
                           MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                           SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF)
      for (auto &MI : MBB)
        handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
  }
}

struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;

void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}

bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number the rest of the registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  return false;
}