1 //===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains helper objects and APIs for working with DXIL 10 /// Root Signatures. 11 /// 12 //===----------------------------------------------------------------------===// 13 #include "DXILRootSignature.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/StringSwitch.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Analysis/DXILMetadataAnalysis.h" 18 #include "llvm/BinaryFormat/DXContainer.h" 19 #include "llvm/Frontend/HLSL/RootSignatureValidations.h" 20 #include "llvm/IR/Constants.h" 21 #include "llvm/IR/DiagnosticInfo.h" 22 #include "llvm/IR/Function.h" 23 #include "llvm/IR/LLVMContext.h" 24 #include "llvm/IR/Metadata.h" 25 #include "llvm/IR/Module.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Pass.h" 28 #include "llvm/Support/Error.h" 29 #include "llvm/Support/ErrorHandling.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <cstdint> 32 #include <optional> 33 #include <utility> 34 35 using namespace llvm; 36 using namespace llvm::dxil; 37 38 static bool reportError(LLVMContext *Ctx, Twine Message, 39 DiagnosticSeverity Severity = DS_Error) { 40 Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); 41 return true; 42 } 43 44 static bool reportValueError(LLVMContext *Ctx, Twine ParamName, 45 uint32_t Value) { 46 Ctx->diagnose(DiagnosticInfoGeneric( 47 "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error)); 48 return true; 49 } 50 51 static std::optional<uint32_t> extractMdIntValue(MDNode *Node, 52 unsigned int OpId) { 53 if (auto *CI = 54 mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get())) 55 return CI->getZExtValue(); 56 return std::nullopt; 57 } 58 59 static std::optional<float> extractMdFloatValue(MDNode *Node, 60 unsigned int OpId) { 61 if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get())) 62 return CI->getValueAPF().convertToFloat(); 63 return std::nullopt; 64 } 65 66 static std::optional<StringRef> extractMdStringValue(MDNode *Node, 67 unsigned int OpId) { 68 MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId)); 69 if (NodeText == nullptr) 70 return std::nullopt; 71 return NodeText->getString(); 72 } 73 74 static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, 75 MDNode *RootFlagNode) { 76 77 if (RootFlagNode->getNumOperands() != 2) 78 return reportError(Ctx, "Invalid format for RootFlag Element"); 79 80 if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) 81 RSD.Flags = *Val; 82 else 83 return reportError(Ctx, "Invalid value for RootFlag"); 84 85 return false; 86 } 87 88 static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, 89 MDNode *RootConstantNode) { 90 91 if (RootConstantNode->getNumOperands() != 5) 92 return reportError(Ctx, "Invalid format for RootConstants Element"); 93 94 dxbc::RTS0::v1::RootParameterHeader Header; 95 // The parameter offset doesn't matter here - we recalculate it during 96 // serialization Header.ParameterOffset = 0; 97 Header.ParameterType = 98 llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); 99 100 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) 101 Header.ShaderVisibility = *Val; 102 else 103 return reportError(Ctx, "Invalid value for ShaderVisibility"); 104 105 dxbc::RTS0::v1::RootConstants Constants; 106 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) 107 Constants.ShaderRegister = *Val; 108 else 109 return reportError(Ctx, "Invalid value for ShaderRegister"); 110 111 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) 112 Constants.RegisterSpace = *Val; 113 else 114 return reportError(Ctx, "Invalid value for RegisterSpace"); 115 116 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) 117 Constants.Num32BitValues = *Val; 118 else 119 return reportError(Ctx, "Invalid value for Num32BitValues"); 120 121 RSD.ParametersContainer.addParameter(Header, Constants); 122 123 return false; 124 } 125 126 static bool parseRootDescriptors(LLVMContext *Ctx, 127 mcdxbc::RootSignatureDesc &RSD, 128 MDNode *RootDescriptorNode, 129 RootSignatureElementKind ElementKind) { 130 assert(ElementKind == RootSignatureElementKind::SRV || 131 ElementKind == RootSignatureElementKind::UAV || 132 ElementKind == RootSignatureElementKind::CBV && 133 "parseRootDescriptors should only be called with RootDescriptor " 134 "element kind."); 135 if (RootDescriptorNode->getNumOperands() != 5) 136 return reportError(Ctx, "Invalid format for Root Descriptor Element"); 137 138 dxbc::RTS0::v1::RootParameterHeader Header; 139 switch (ElementKind) { 140 case RootSignatureElementKind::SRV: 141 Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV); 142 break; 143 case RootSignatureElementKind::UAV: 144 Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV); 145 break; 146 case RootSignatureElementKind::CBV: 147 Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV); 148 break; 149 default: 150 llvm_unreachable("invalid Root Descriptor kind"); 151 break; 152 } 153 154 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) 155 Header.ShaderVisibility = *Val; 156 else 157 return reportError(Ctx, "Invalid value for ShaderVisibility"); 158 159 dxbc::RTS0::v2::RootDescriptor Descriptor; 160 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) 161 Descriptor.ShaderRegister = *Val; 162 else 163 return reportError(Ctx, "Invalid value for ShaderRegister"); 164 165 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) 166 Descriptor.RegisterSpace = *Val; 167 else 168 return reportError(Ctx, "Invalid value for RegisterSpace"); 169 170 if (RSD.Version == 1) { 171 RSD.ParametersContainer.addParameter(Header, Descriptor); 172 return false; 173 } 174 assert(RSD.Version > 1); 175 176 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) 177 Descriptor.Flags = *Val; 178 else 179 return reportError(Ctx, "Invalid value for Root Descriptor Flags"); 180 181 RSD.ParametersContainer.addParameter(Header, Descriptor); 182 return false; 183 } 184 185 static bool parseDescriptorRange(LLVMContext *Ctx, 186 mcdxbc::DescriptorTable &Table, 187 MDNode *RangeDescriptorNode) { 188 189 if (RangeDescriptorNode->getNumOperands() != 6) 190 return reportError(Ctx, "Invalid format for Descriptor Range"); 191 192 dxbc::RTS0::v2::DescriptorRange Range; 193 194 std::optional<StringRef> ElementText = 195 extractMdStringValue(RangeDescriptorNode, 0); 196 197 if (!ElementText.has_value()) 198 return reportError(Ctx, "Descriptor Range, first element is not a string."); 199 200 Range.RangeType = 201 StringSwitch<uint32_t>(*ElementText) 202 .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV)) 203 .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV)) 204 .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV)) 205 .Case("Sampler", 206 llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) 207 .Default(~0U); 208 209 if (Range.RangeType == ~0U) 210 return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); 211 212 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) 213 Range.NumDescriptors = *Val; 214 else 215 return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); 216 217 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) 218 Range.BaseShaderRegister = *Val; 219 else 220 return reportError(Ctx, "Invalid value for BaseShaderRegister"); 221 222 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) 223 Range.RegisterSpace = *Val; 224 else 225 return reportError(Ctx, "Invalid value for RegisterSpace"); 226 227 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) 228 Range.OffsetInDescriptorsFromTableStart = *Val; 229 else 230 return reportError(Ctx, 231 "Invalid value for OffsetInDescriptorsFromTableStart"); 232 233 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) 234 Range.Flags = *Val; 235 else 236 return reportError(Ctx, "Invalid value for Descriptor Range Flags"); 237 238 Table.Ranges.push_back(Range); 239 return false; 240 } 241 242 static bool parseDescriptorTable(LLVMContext *Ctx, 243 mcdxbc::RootSignatureDesc &RSD, 244 MDNode *DescriptorTableNode) { 245 const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); 246 if (NumOperands < 2) 247 return reportError(Ctx, "Invalid format for Descriptor Table"); 248 249 dxbc::RTS0::v1::RootParameterHeader Header; 250 if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1)) 251 Header.ShaderVisibility = *Val; 252 else 253 return reportError(Ctx, "Invalid value for ShaderVisibility"); 254 255 mcdxbc::DescriptorTable Table; 256 Header.ParameterType = 257 llvm::to_underlying(dxbc::RootParameterType::DescriptorTable); 258 259 for (unsigned int I = 2; I < NumOperands; I++) { 260 MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); 261 if (Element == nullptr) 262 return reportError(Ctx, "Missing Root Element Metadata Node."); 263 264 if (parseDescriptorRange(Ctx, Table, Element)) 265 return true; 266 } 267 268 RSD.ParametersContainer.addParameter(Header, Table); 269 return false; 270 } 271 272 static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, 273 MDNode *StaticSamplerNode) { 274 if (StaticSamplerNode->getNumOperands() != 14) 275 return reportError(Ctx, "Invalid format for Static Sampler"); 276 277 dxbc::RTS0::v1::StaticSampler Sampler; 278 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) 279 Sampler.Filter = *Val; 280 else 281 return reportError(Ctx, "Invalid value for Filter"); 282 283 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) 284 Sampler.AddressU = *Val; 285 else 286 return reportError(Ctx, "Invalid value for AddressU"); 287 288 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) 289 Sampler.AddressV = *Val; 290 else 291 return reportError(Ctx, "Invalid value for AddressV"); 292 293 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) 294 Sampler.AddressW = *Val; 295 else 296 return reportError(Ctx, "Invalid value for AddressW"); 297 298 if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) 299 Sampler.MipLODBias = *Val; 300 else 301 return reportError(Ctx, "Invalid value for MipLODBias"); 302 303 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) 304 Sampler.MaxAnisotropy = *Val; 305 else 306 return reportError(Ctx, "Invalid value for MaxAnisotropy"); 307 308 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) 309 Sampler.ComparisonFunc = *Val; 310 else 311 return reportError(Ctx, "Invalid value for ComparisonFunc "); 312 313 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) 314 Sampler.BorderColor = *Val; 315 else 316 return reportError(Ctx, "Invalid value for ComparisonFunc "); 317 318 if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) 319 Sampler.MinLOD = *Val; 320 else 321 return reportError(Ctx, "Invalid value for MinLOD"); 322 323 if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10)) 324 Sampler.MaxLOD = *Val; 325 else 326 return reportError(Ctx, "Invalid value for MaxLOD"); 327 328 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) 329 Sampler.ShaderRegister = *Val; 330 else 331 return reportError(Ctx, "Invalid value for ShaderRegister"); 332 333 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) 334 Sampler.RegisterSpace = *Val; 335 else 336 return reportError(Ctx, "Invalid value for RegisterSpace"); 337 338 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) 339 Sampler.ShaderVisibility = *Val; 340 else 341 return reportError(Ctx, "Invalid value for ShaderVisibility"); 342 343 RSD.StaticSamplers.push_back(Sampler); 344 return false; 345 } 346 347 static bool parseRootSignatureElement(LLVMContext *Ctx, 348 mcdxbc::RootSignatureDesc &RSD, 349 MDNode *Element) { 350 std::optional<StringRef> ElementText = extractMdStringValue(Element, 0); 351 if (!ElementText.has_value()) 352 return reportError(Ctx, "Invalid format for Root Element"); 353 354 RootSignatureElementKind ElementKind = 355 StringSwitch<RootSignatureElementKind>(*ElementText) 356 .Case("RootFlags", RootSignatureElementKind::RootFlags) 357 .Case("RootConstants", RootSignatureElementKind::RootConstants) 358 .Case("RootCBV", RootSignatureElementKind::CBV) 359 .Case("RootSRV", RootSignatureElementKind::SRV) 360 .Case("RootUAV", RootSignatureElementKind::UAV) 361 .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable) 362 .Case("StaticSampler", RootSignatureElementKind::StaticSamplers) 363 .Default(RootSignatureElementKind::Error); 364 365 switch (ElementKind) { 366 367 case RootSignatureElementKind::RootFlags: 368 return parseRootFlags(Ctx, RSD, Element); 369 case RootSignatureElementKind::RootConstants: 370 return parseRootConstants(Ctx, RSD, Element); 371 case RootSignatureElementKind::CBV: 372 case RootSignatureElementKind::SRV: 373 case RootSignatureElementKind::UAV: 374 return parseRootDescriptors(Ctx, RSD, Element, ElementKind); 375 case RootSignatureElementKind::DescriptorTable: 376 return parseDescriptorTable(Ctx, RSD, Element); 377 case RootSignatureElementKind::StaticSamplers: 378 return parseStaticSampler(Ctx, RSD, Element); 379 case RootSignatureElementKind::Error: 380 return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText); 381 } 382 383 llvm_unreachable("Unhandled RootSignatureElementKind enum."); 384 } 385 386 static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, 387 MDNode *Node) { 388 bool HasError = false; 389 390 // Loop through the Root Elements of the root signature. 391 for (const auto &Operand : Node->operands()) { 392 MDNode *Element = dyn_cast<MDNode>(Operand); 393 if (Element == nullptr) 394 return reportError(Ctx, "Missing Root Element Metadata Node."); 395 396 HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element); 397 } 398 399 return HasError; 400 } 401 402 static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { 403 404 if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { 405 return reportValueError(Ctx, "Version", RSD.Version); 406 } 407 408 if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { 409 return reportValueError(Ctx, "RootFlags", RSD.Flags); 410 } 411 412 for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { 413 if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) 414 return reportValueError(Ctx, "ShaderVisibility", 415 Info.Header.ShaderVisibility); 416 417 assert(dxbc::isValidParameterType(Info.Header.ParameterType) && 418 "Invalid value for ParameterType"); 419 420 switch (Info.Header.ParameterType) { 421 422 case llvm::to_underlying(dxbc::RootParameterType::CBV): 423 case llvm::to_underlying(dxbc::RootParameterType::UAV): 424 case llvm::to_underlying(dxbc::RootParameterType::SRV): { 425 const dxbc::RTS0::v2::RootDescriptor &Descriptor = 426 RSD.ParametersContainer.getRootDescriptor(Info.Location); 427 if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) 428 return reportValueError(Ctx, "ShaderRegister", 429 Descriptor.ShaderRegister); 430 431 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) 432 return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); 433 434 if (RSD.Version > 1) { 435 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, 436 Descriptor.Flags)) 437 return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags); 438 } 439 break; 440 } 441 case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { 442 const mcdxbc::DescriptorTable &Table = 443 RSD.ParametersContainer.getDescriptorTable(Info.Location); 444 for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { 445 if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) 446 return reportValueError(Ctx, "RangeType", Range.RangeType); 447 448 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) 449 return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); 450 451 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) 452 return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors); 453 454 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( 455 RSD.Version, Range.RangeType, Range.Flags)) 456 return reportValueError(Ctx, "DescriptorFlag", Range.Flags); 457 } 458 break; 459 } 460 } 461 } 462 463 for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { 464 if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) 465 return reportValueError(Ctx, "Filter", Sampler.Filter); 466 467 if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) 468 return reportValueError(Ctx, "AddressU", Sampler.AddressU); 469 470 if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) 471 return reportValueError(Ctx, "AddressV", Sampler.AddressV); 472 473 if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) 474 return reportValueError(Ctx, "AddressW", Sampler.AddressW); 475 476 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) 477 return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); 478 479 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) 480 return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); 481 482 if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) 483 return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); 484 485 if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) 486 return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); 487 488 if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) 489 return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); 490 491 if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) 492 return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); 493 494 if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) 495 return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); 496 497 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) 498 return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); 499 500 if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) 501 return reportValueError(Ctx, "ShaderVisibility", 502 Sampler.ShaderVisibility); 503 } 504 505 return false; 506 } 507 508 static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> 509 analyzeModule(Module &M) { 510 511 /** Root Signature are specified as following in the metadata: 512 513 !dx.rootsignatures = !{!2} ; list of function/root signature pairs 514 !2 = !{ ptr @main, !3 } ; function, root signature 515 !3 = !{ !4, !5, !6, !7 } ; list of root signature elements 516 517 So for each MDNode inside dx.rootsignatures NamedMDNode 518 (the Root parameter of this function), the parsing process needs 519 to loop through each of its operands and process the function, 520 signature pair. 521 */ 522 523 LLVMContext *Ctx = &M.getContext(); 524 525 SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> RSDMap; 526 527 NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures"); 528 if (RootSignatureNode == nullptr) 529 return RSDMap; 530 531 for (const auto &RSDefNode : RootSignatureNode->operands()) { 532 if (RSDefNode->getNumOperands() != 3) { 533 reportError(Ctx, "Invalid Root Signature metadata - expected function, " 534 "signature, and version."); 535 continue; 536 } 537 538 // Function was pruned during compilation. 539 const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0); 540 if (FunctionPointerMdNode == nullptr) { 541 reportError( 542 Ctx, "Function associated with Root Signature definition is null."); 543 continue; 544 } 545 546 ValueAsMetadata *VAM = 547 llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get()); 548 if (VAM == nullptr) { 549 reportError(Ctx, "First element of root signature is not a Value"); 550 continue; 551 } 552 553 Function *F = dyn_cast<Function>(VAM->getValue()); 554 if (F == nullptr) { 555 reportError(Ctx, "First element of root signature is not a Function"); 556 continue; 557 } 558 559 Metadata *RootElementListOperand = RSDefNode->getOperand(1).get(); 560 561 if (RootElementListOperand == nullptr) { 562 reportError(Ctx, "Root Element mdnode is null."); 563 continue; 564 } 565 566 MDNode *RootElementListNode = dyn_cast<MDNode>(RootElementListOperand); 567 if (RootElementListNode == nullptr) { 568 reportError(Ctx, "Root Element is not a metadata node."); 569 continue; 570 } 571 mcdxbc::RootSignatureDesc RSD; 572 if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2)) 573 RSD.Version = *Version; 574 else { 575 reportError(Ctx, "Invalid RSDefNode value, expected constant int"); 576 continue; 577 } 578 579 // Clang emits the root signature data in dxcontainer following a specific 580 // sequence. First the header, then the root parameters. So the header 581 // offset will always equal to the header size. 582 RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader); 583 584 // static sampler offset is calculated when writting dxcontainer. 585 RSD.StaticSamplersOffset = 0u; 586 587 if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) { 588 return RSDMap; 589 } 590 591 RSDMap.insert(std::make_pair(F, RSD)); 592 } 593 594 return RSDMap; 595 } 596 597 AnalysisKey RootSignatureAnalysis::Key; 598 599 RootSignatureAnalysis::Result 600 RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) { 601 return RootSignatureBindingInfo(analyzeModule(M)); 602 } 603 604 //===----------------------------------------------------------------------===// 605 606 PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, 607 ModuleAnalysisManager &AM) { 608 609 RootSignatureBindingInfo &RSDMap = AM.getResult<RootSignatureAnalysis>(M); 610 611 OS << "Root Signature Definitions" 612 << "\n"; 613 for (const Function &F : M) { 614 auto It = RSDMap.find(&F); 615 if (It == RSDMap.end()) 616 continue; 617 const auto &RS = It->second; 618 OS << "Definition for '" << F.getName() << "':\n"; 619 // start root signature header 620 OS << "Flags: " << format_hex(RS.Flags, 8) << "\n" 621 << "Version: " << RS.Version << "\n" 622 << "RootParametersOffset: " << RS.RootParameterOffset << "\n" 623 << "NumParameters: " << RS.ParametersContainer.size() << "\n"; 624 for (size_t I = 0; I < RS.ParametersContainer.size(); I++) { 625 const auto &[Type, Loc] = 626 RS.ParametersContainer.getTypeAndLocForParameter(I); 627 const dxbc::RTS0::v1::RootParameterHeader Header = 628 RS.ParametersContainer.getHeader(I); 629 630 OS << "- Parameter Type: " << Type << "\n" 631 << " Shader Visibility: " << Header.ShaderVisibility << "\n"; 632 633 switch (Type) { 634 case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { 635 const dxbc::RTS0::v1::RootConstants &Constants = 636 RS.ParametersContainer.getConstant(Loc); 637 OS << " Register Space: " << Constants.RegisterSpace << "\n" 638 << " Shader Register: " << Constants.ShaderRegister << "\n" 639 << " Num 32 Bit Values: " << Constants.Num32BitValues << "\n"; 640 break; 641 } 642 case llvm::to_underlying(dxbc::RootParameterType::CBV): 643 case llvm::to_underlying(dxbc::RootParameterType::UAV): 644 case llvm::to_underlying(dxbc::RootParameterType::SRV): { 645 const dxbc::RTS0::v2::RootDescriptor &Descriptor = 646 RS.ParametersContainer.getRootDescriptor(Loc); 647 OS << " Register Space: " << Descriptor.RegisterSpace << "\n" 648 << " Shader Register: " << Descriptor.ShaderRegister << "\n"; 649 if (RS.Version > 1) 650 OS << " Flags: " << Descriptor.Flags << "\n"; 651 break; 652 } 653 case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { 654 const mcdxbc::DescriptorTable &Table = 655 RS.ParametersContainer.getDescriptorTable(Loc); 656 OS << " NumRanges: " << Table.Ranges.size() << "\n"; 657 658 for (const dxbc::RTS0::v2::DescriptorRange Range : Table) { 659 OS << " - Range Type: " << Range.RangeType << "\n" 660 << " Register Space: " << Range.RegisterSpace << "\n" 661 << " Base Shader Register: " << Range.BaseShaderRegister << "\n" 662 << " Num Descriptors: " << Range.NumDescriptors << "\n" 663 << " Offset In Descriptors From Table Start: " 664 << Range.OffsetInDescriptorsFromTableStart << "\n"; 665 if (RS.Version > 1) 666 OS << " Flags: " << Range.Flags << "\n"; 667 } 668 break; 669 } 670 } 671 } 672 OS << "NumStaticSamplers: " << 0 << "\n"; 673 OS << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n"; 674 } 675 return PreservedAnalyses::all(); 676 } 677 678 //===----------------------------------------------------------------------===// 679 bool RootSignatureAnalysisWrapper::runOnModule(Module &M) { 680 FuncToRsMap = std::make_unique<RootSignatureBindingInfo>( 681 RootSignatureBindingInfo(analyzeModule(M))); 682 return false; 683 } 684 685 void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { 686 AU.setPreservesAll(); 687 AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); 688 } 689 690 char RootSignatureAnalysisWrapper::ID = 0; 691 692 INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper, 693 "dxil-root-signature-analysis", 694 "DXIL Root Signature Analysis", true, true) 695 INITIALIZE_PASS_END(RootSignatureAnalysisWrapper, 696 "dxil-root-signature-analysis", 697 "DXIL Root Signature Analysis", true, true) 698