1 //===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects and APIs for working with DXIL
10 /// Root Signatures.
11 ///
12 //===----------------------------------------------------------------------===//
13 #include "DXILRootSignature.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Analysis/DXILMetadataAnalysis.h"
18 #include "llvm/BinaryFormat/DXContainer.h"
19 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/DiagnosticInfo.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/LLVMContext.h"
24 #include "llvm/IR/Metadata.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/Error.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <cstdint>
32 #include <optional>
33 #include <utility>
34
35 using namespace llvm;
36 using namespace llvm::dxil;
37
reportError(LLVMContext * Ctx,Twine Message,DiagnosticSeverity Severity=DS_Error)38 static bool reportError(LLVMContext *Ctx, Twine Message,
39 DiagnosticSeverity Severity = DS_Error) {
40 Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
41 return true;
42 }
43
reportValueError(LLVMContext * Ctx,Twine ParamName,uint32_t Value)44 static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
45 uint32_t Value) {
46 Ctx->diagnose(DiagnosticInfoGeneric(
47 "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
48 return true;
49 }
50
extractMdIntValue(MDNode * Node,unsigned int OpId)51 static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
52 unsigned int OpId) {
53 if (auto *CI =
54 mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
55 return CI->getZExtValue();
56 return std::nullopt;
57 }
58
extractMdFloatValue(MDNode * Node,unsigned int OpId)59 static std::optional<float> extractMdFloatValue(MDNode *Node,
60 unsigned int OpId) {
61 if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
62 return CI->getValueAPF().convertToFloat();
63 return std::nullopt;
64 }
65
extractMdStringValue(MDNode * Node,unsigned int OpId)66 static std::optional<StringRef> extractMdStringValue(MDNode *Node,
67 unsigned int OpId) {
68 MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
69 if (NodeText == nullptr)
70 return std::nullopt;
71 return NodeText->getString();
72 }
73
parseRootFlags(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * RootFlagNode)74 static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
75 MDNode *RootFlagNode) {
76
77 if (RootFlagNode->getNumOperands() != 2)
78 return reportError(Ctx, "Invalid format for RootFlag Element");
79
80 if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
81 RSD.Flags = *Val;
82 else
83 return reportError(Ctx, "Invalid value for RootFlag");
84
85 return false;
86 }
87
parseRootConstants(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * RootConstantNode)88 static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
89 MDNode *RootConstantNode) {
90
91 if (RootConstantNode->getNumOperands() != 5)
92 return reportError(Ctx, "Invalid format for RootConstants Element");
93
94 dxbc::RTS0::v1::RootParameterHeader Header;
95 // The parameter offset doesn't matter here - we recalculate it during
96 // serialization Header.ParameterOffset = 0;
97 Header.ParameterType =
98 llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
99
100 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
101 Header.ShaderVisibility = *Val;
102 else
103 return reportError(Ctx, "Invalid value for ShaderVisibility");
104
105 dxbc::RTS0::v1::RootConstants Constants;
106 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
107 Constants.ShaderRegister = *Val;
108 else
109 return reportError(Ctx, "Invalid value for ShaderRegister");
110
111 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
112 Constants.RegisterSpace = *Val;
113 else
114 return reportError(Ctx, "Invalid value for RegisterSpace");
115
116 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
117 Constants.Num32BitValues = *Val;
118 else
119 return reportError(Ctx, "Invalid value for Num32BitValues");
120
121 RSD.ParametersContainer.addParameter(Header, Constants);
122
123 return false;
124 }
125
parseRootDescriptors(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * RootDescriptorNode,RootSignatureElementKind ElementKind)126 static bool parseRootDescriptors(LLVMContext *Ctx,
127 mcdxbc::RootSignatureDesc &RSD,
128 MDNode *RootDescriptorNode,
129 RootSignatureElementKind ElementKind) {
130 assert(ElementKind == RootSignatureElementKind::SRV ||
131 ElementKind == RootSignatureElementKind::UAV ||
132 ElementKind == RootSignatureElementKind::CBV &&
133 "parseRootDescriptors should only be called with RootDescriptor "
134 "element kind.");
135 if (RootDescriptorNode->getNumOperands() != 5)
136 return reportError(Ctx, "Invalid format for Root Descriptor Element");
137
138 dxbc::RTS0::v1::RootParameterHeader Header;
139 switch (ElementKind) {
140 case RootSignatureElementKind::SRV:
141 Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
142 break;
143 case RootSignatureElementKind::UAV:
144 Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
145 break;
146 case RootSignatureElementKind::CBV:
147 Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
148 break;
149 default:
150 llvm_unreachable("invalid Root Descriptor kind");
151 break;
152 }
153
154 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
155 Header.ShaderVisibility = *Val;
156 else
157 return reportError(Ctx, "Invalid value for ShaderVisibility");
158
159 dxbc::RTS0::v2::RootDescriptor Descriptor;
160 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
161 Descriptor.ShaderRegister = *Val;
162 else
163 return reportError(Ctx, "Invalid value for ShaderRegister");
164
165 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
166 Descriptor.RegisterSpace = *Val;
167 else
168 return reportError(Ctx, "Invalid value for RegisterSpace");
169
170 if (RSD.Version == 1) {
171 RSD.ParametersContainer.addParameter(Header, Descriptor);
172 return false;
173 }
174 assert(RSD.Version > 1);
175
176 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
177 Descriptor.Flags = *Val;
178 else
179 return reportError(Ctx, "Invalid value for Root Descriptor Flags");
180
181 RSD.ParametersContainer.addParameter(Header, Descriptor);
182 return false;
183 }
184
parseDescriptorRange(LLVMContext * Ctx,mcdxbc::DescriptorTable & Table,MDNode * RangeDescriptorNode)185 static bool parseDescriptorRange(LLVMContext *Ctx,
186 mcdxbc::DescriptorTable &Table,
187 MDNode *RangeDescriptorNode) {
188
189 if (RangeDescriptorNode->getNumOperands() != 6)
190 return reportError(Ctx, "Invalid format for Descriptor Range");
191
192 dxbc::RTS0::v2::DescriptorRange Range;
193
194 std::optional<StringRef> ElementText =
195 extractMdStringValue(RangeDescriptorNode, 0);
196
197 if (!ElementText.has_value())
198 return reportError(Ctx, "Descriptor Range, first element is not a string.");
199
200 Range.RangeType =
201 StringSwitch<uint32_t>(*ElementText)
202 .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
203 .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
204 .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
205 .Case("Sampler",
206 llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
207 .Default(~0U);
208
209 if (Range.RangeType == ~0U)
210 return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
211
212 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
213 Range.NumDescriptors = *Val;
214 else
215 return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
216
217 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
218 Range.BaseShaderRegister = *Val;
219 else
220 return reportError(Ctx, "Invalid value for BaseShaderRegister");
221
222 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
223 Range.RegisterSpace = *Val;
224 else
225 return reportError(Ctx, "Invalid value for RegisterSpace");
226
227 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
228 Range.OffsetInDescriptorsFromTableStart = *Val;
229 else
230 return reportError(Ctx,
231 "Invalid value for OffsetInDescriptorsFromTableStart");
232
233 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
234 Range.Flags = *Val;
235 else
236 return reportError(Ctx, "Invalid value for Descriptor Range Flags");
237
238 Table.Ranges.push_back(Range);
239 return false;
240 }
241
parseDescriptorTable(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * DescriptorTableNode)242 static bool parseDescriptorTable(LLVMContext *Ctx,
243 mcdxbc::RootSignatureDesc &RSD,
244 MDNode *DescriptorTableNode) {
245 const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
246 if (NumOperands < 2)
247 return reportError(Ctx, "Invalid format for Descriptor Table");
248
249 dxbc::RTS0::v1::RootParameterHeader Header;
250 if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
251 Header.ShaderVisibility = *Val;
252 else
253 return reportError(Ctx, "Invalid value for ShaderVisibility");
254
255 mcdxbc::DescriptorTable Table;
256 Header.ParameterType =
257 llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
258
259 for (unsigned int I = 2; I < NumOperands; I++) {
260 MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
261 if (Element == nullptr)
262 return reportError(Ctx, "Missing Root Element Metadata Node.");
263
264 if (parseDescriptorRange(Ctx, Table, Element))
265 return true;
266 }
267
268 RSD.ParametersContainer.addParameter(Header, Table);
269 return false;
270 }
271
parseStaticSampler(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * StaticSamplerNode)272 static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
273 MDNode *StaticSamplerNode) {
274 if (StaticSamplerNode->getNumOperands() != 14)
275 return reportError(Ctx, "Invalid format for Static Sampler");
276
277 dxbc::RTS0::v1::StaticSampler Sampler;
278 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
279 Sampler.Filter = *Val;
280 else
281 return reportError(Ctx, "Invalid value for Filter");
282
283 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
284 Sampler.AddressU = *Val;
285 else
286 return reportError(Ctx, "Invalid value for AddressU");
287
288 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
289 Sampler.AddressV = *Val;
290 else
291 return reportError(Ctx, "Invalid value for AddressV");
292
293 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
294 Sampler.AddressW = *Val;
295 else
296 return reportError(Ctx, "Invalid value for AddressW");
297
298 if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
299 Sampler.MipLODBias = *Val;
300 else
301 return reportError(Ctx, "Invalid value for MipLODBias");
302
303 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
304 Sampler.MaxAnisotropy = *Val;
305 else
306 return reportError(Ctx, "Invalid value for MaxAnisotropy");
307
308 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
309 Sampler.ComparisonFunc = *Val;
310 else
311 return reportError(Ctx, "Invalid value for ComparisonFunc ");
312
313 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
314 Sampler.BorderColor = *Val;
315 else
316 return reportError(Ctx, "Invalid value for ComparisonFunc ");
317
318 if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
319 Sampler.MinLOD = *Val;
320 else
321 return reportError(Ctx, "Invalid value for MinLOD");
322
323 if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
324 Sampler.MaxLOD = *Val;
325 else
326 return reportError(Ctx, "Invalid value for MaxLOD");
327
328 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
329 Sampler.ShaderRegister = *Val;
330 else
331 return reportError(Ctx, "Invalid value for ShaderRegister");
332
333 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
334 Sampler.RegisterSpace = *Val;
335 else
336 return reportError(Ctx, "Invalid value for RegisterSpace");
337
338 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
339 Sampler.ShaderVisibility = *Val;
340 else
341 return reportError(Ctx, "Invalid value for ShaderVisibility");
342
343 RSD.StaticSamplers.push_back(Sampler);
344 return false;
345 }
346
parseRootSignatureElement(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * Element)347 static bool parseRootSignatureElement(LLVMContext *Ctx,
348 mcdxbc::RootSignatureDesc &RSD,
349 MDNode *Element) {
350 std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
351 if (!ElementText.has_value())
352 return reportError(Ctx, "Invalid format for Root Element");
353
354 RootSignatureElementKind ElementKind =
355 StringSwitch<RootSignatureElementKind>(*ElementText)
356 .Case("RootFlags", RootSignatureElementKind::RootFlags)
357 .Case("RootConstants", RootSignatureElementKind::RootConstants)
358 .Case("RootCBV", RootSignatureElementKind::CBV)
359 .Case("RootSRV", RootSignatureElementKind::SRV)
360 .Case("RootUAV", RootSignatureElementKind::UAV)
361 .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
362 .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
363 .Default(RootSignatureElementKind::Error);
364
365 switch (ElementKind) {
366
367 case RootSignatureElementKind::RootFlags:
368 return parseRootFlags(Ctx, RSD, Element);
369 case RootSignatureElementKind::RootConstants:
370 return parseRootConstants(Ctx, RSD, Element);
371 case RootSignatureElementKind::CBV:
372 case RootSignatureElementKind::SRV:
373 case RootSignatureElementKind::UAV:
374 return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
375 case RootSignatureElementKind::DescriptorTable:
376 return parseDescriptorTable(Ctx, RSD, Element);
377 case RootSignatureElementKind::StaticSamplers:
378 return parseStaticSampler(Ctx, RSD, Element);
379 case RootSignatureElementKind::Error:
380 return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
381 }
382
383 llvm_unreachable("Unhandled RootSignatureElementKind enum.");
384 }
385
parse(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * Node)386 static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
387 MDNode *Node) {
388 bool HasError = false;
389
390 // Loop through the Root Elements of the root signature.
391 for (const auto &Operand : Node->operands()) {
392 MDNode *Element = dyn_cast<MDNode>(Operand);
393 if (Element == nullptr)
394 return reportError(Ctx, "Missing Root Element Metadata Node.");
395
396 HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
397 }
398
399 return HasError;
400 }
401
validate(LLVMContext * Ctx,const mcdxbc::RootSignatureDesc & RSD)402 static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
403
404 if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
405 return reportValueError(Ctx, "Version", RSD.Version);
406 }
407
408 if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
409 return reportValueError(Ctx, "RootFlags", RSD.Flags);
410 }
411
412 for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
413 if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
414 return reportValueError(Ctx, "ShaderVisibility",
415 Info.Header.ShaderVisibility);
416
417 assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
418 "Invalid value for ParameterType");
419
420 switch (Info.Header.ParameterType) {
421
422 case llvm::to_underlying(dxbc::RootParameterType::CBV):
423 case llvm::to_underlying(dxbc::RootParameterType::UAV):
424 case llvm::to_underlying(dxbc::RootParameterType::SRV): {
425 const dxbc::RTS0::v2::RootDescriptor &Descriptor =
426 RSD.ParametersContainer.getRootDescriptor(Info.Location);
427 if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
428 return reportValueError(Ctx, "ShaderRegister",
429 Descriptor.ShaderRegister);
430
431 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
432 return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
433
434 if (RSD.Version > 1) {
435 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
436 Descriptor.Flags))
437 return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
438 }
439 break;
440 }
441 case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
442 const mcdxbc::DescriptorTable &Table =
443 RSD.ParametersContainer.getDescriptorTable(Info.Location);
444 for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
445 if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
446 return reportValueError(Ctx, "RangeType", Range.RangeType);
447
448 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
449 return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
450
451 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
452 return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
453
454 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
455 RSD.Version, Range.RangeType, Range.Flags))
456 return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
457 }
458 break;
459 }
460 }
461 }
462
463 for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
464 if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
465 return reportValueError(Ctx, "Filter", Sampler.Filter);
466
467 if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
468 return reportValueError(Ctx, "AddressU", Sampler.AddressU);
469
470 if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
471 return reportValueError(Ctx, "AddressV", Sampler.AddressV);
472
473 if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
474 return reportValueError(Ctx, "AddressW", Sampler.AddressW);
475
476 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
477 return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
478
479 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
480 return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
481
482 if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
483 return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
484
485 if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
486 return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
487
488 if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
489 return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
490
491 if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
492 return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
493
494 if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
495 return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
496
497 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
498 return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
499
500 if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
501 return reportValueError(Ctx, "ShaderVisibility",
502 Sampler.ShaderVisibility);
503 }
504
505 return false;
506 }
507
508 static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
analyzeModule(Module & M)509 analyzeModule(Module &M) {
510
511 /** Root Signature are specified as following in the metadata:
512
513 !dx.rootsignatures = !{!2} ; list of function/root signature pairs
514 !2 = !{ ptr @main, !3 } ; function, root signature
515 !3 = !{ !4, !5, !6, !7 } ; list of root signature elements
516
517 So for each MDNode inside dx.rootsignatures NamedMDNode
518 (the Root parameter of this function), the parsing process needs
519 to loop through each of its operands and process the function,
520 signature pair.
521 */
522
523 LLVMContext *Ctx = &M.getContext();
524
525 SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> RSDMap;
526
527 NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
528 if (RootSignatureNode == nullptr)
529 return RSDMap;
530
531 for (const auto &RSDefNode : RootSignatureNode->operands()) {
532 if (RSDefNode->getNumOperands() != 3) {
533 reportError(Ctx, "Invalid Root Signature metadata - expected function, "
534 "signature, and version.");
535 continue;
536 }
537
538 // Function was pruned during compilation.
539 const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
540 if (FunctionPointerMdNode == nullptr) {
541 reportError(
542 Ctx, "Function associated with Root Signature definition is null.");
543 continue;
544 }
545
546 ValueAsMetadata *VAM =
547 llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
548 if (VAM == nullptr) {
549 reportError(Ctx, "First element of root signature is not a Value");
550 continue;
551 }
552
553 Function *F = dyn_cast<Function>(VAM->getValue());
554 if (F == nullptr) {
555 reportError(Ctx, "First element of root signature is not a Function");
556 continue;
557 }
558
559 Metadata *RootElementListOperand = RSDefNode->getOperand(1).get();
560
561 if (RootElementListOperand == nullptr) {
562 reportError(Ctx, "Root Element mdnode is null.");
563 continue;
564 }
565
566 MDNode *RootElementListNode = dyn_cast<MDNode>(RootElementListOperand);
567 if (RootElementListNode == nullptr) {
568 reportError(Ctx, "Root Element is not a metadata node.");
569 continue;
570 }
571 mcdxbc::RootSignatureDesc RSD;
572 if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2))
573 RSD.Version = *Version;
574 else {
575 reportError(Ctx, "Invalid RSDefNode value, expected constant int");
576 continue;
577 }
578
579 // Clang emits the root signature data in dxcontainer following a specific
580 // sequence. First the header, then the root parameters. So the header
581 // offset will always equal to the header size.
582 RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);
583
584 // static sampler offset is calculated when writting dxcontainer.
585 RSD.StaticSamplersOffset = 0u;
586
587 if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
588 return RSDMap;
589 }
590
591 RSDMap.insert(std::make_pair(F, RSD));
592 }
593
594 return RSDMap;
595 }
596
597 AnalysisKey RootSignatureAnalysis::Key;
598
599 RootSignatureAnalysis::Result
run(Module & M,ModuleAnalysisManager & AM)600 RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
601 return RootSignatureBindingInfo(analyzeModule(M));
602 }
603
604 //===----------------------------------------------------------------------===//
605
run(Module & M,ModuleAnalysisManager & AM)606 PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
607 ModuleAnalysisManager &AM) {
608
609 RootSignatureBindingInfo &RSDMap = AM.getResult<RootSignatureAnalysis>(M);
610
611 OS << "Root Signature Definitions"
612 << "\n";
613 for (const Function &F : M) {
614 auto It = RSDMap.find(&F);
615 if (It == RSDMap.end())
616 continue;
617 const auto &RS = It->second;
618 OS << "Definition for '" << F.getName() << "':\n";
619 // start root signature header
620 OS << "Flags: " << format_hex(RS.Flags, 8) << "\n"
621 << "Version: " << RS.Version << "\n"
622 << "RootParametersOffset: " << RS.RootParameterOffset << "\n"
623 << "NumParameters: " << RS.ParametersContainer.size() << "\n";
624 for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
625 const auto &[Type, Loc] =
626 RS.ParametersContainer.getTypeAndLocForParameter(I);
627 const dxbc::RTS0::v1::RootParameterHeader Header =
628 RS.ParametersContainer.getHeader(I);
629
630 OS << "- Parameter Type: " << Type << "\n"
631 << " Shader Visibility: " << Header.ShaderVisibility << "\n";
632
633 switch (Type) {
634 case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
635 const dxbc::RTS0::v1::RootConstants &Constants =
636 RS.ParametersContainer.getConstant(Loc);
637 OS << " Register Space: " << Constants.RegisterSpace << "\n"
638 << " Shader Register: " << Constants.ShaderRegister << "\n"
639 << " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
640 break;
641 }
642 case llvm::to_underlying(dxbc::RootParameterType::CBV):
643 case llvm::to_underlying(dxbc::RootParameterType::UAV):
644 case llvm::to_underlying(dxbc::RootParameterType::SRV): {
645 const dxbc::RTS0::v2::RootDescriptor &Descriptor =
646 RS.ParametersContainer.getRootDescriptor(Loc);
647 OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
648 << " Shader Register: " << Descriptor.ShaderRegister << "\n";
649 if (RS.Version > 1)
650 OS << " Flags: " << Descriptor.Flags << "\n";
651 break;
652 }
653 case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
654 const mcdxbc::DescriptorTable &Table =
655 RS.ParametersContainer.getDescriptorTable(Loc);
656 OS << " NumRanges: " << Table.Ranges.size() << "\n";
657
658 for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
659 OS << " - Range Type: " << Range.RangeType << "\n"
660 << " Register Space: " << Range.RegisterSpace << "\n"
661 << " Base Shader Register: " << Range.BaseShaderRegister << "\n"
662 << " Num Descriptors: " << Range.NumDescriptors << "\n"
663 << " Offset In Descriptors From Table Start: "
664 << Range.OffsetInDescriptorsFromTableStart << "\n";
665 if (RS.Version > 1)
666 OS << " Flags: " << Range.Flags << "\n";
667 }
668 break;
669 }
670 }
671 }
672 OS << "NumStaticSamplers: " << 0 << "\n";
673 OS << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n";
674 }
675 return PreservedAnalyses::all();
676 }
677
678 //===----------------------------------------------------------------------===//
runOnModule(Module & M)679 bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
680 FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(
681 RootSignatureBindingInfo(analyzeModule(M)));
682 return false;
683 }
684
getAnalysisUsage(AnalysisUsage & AU) const685 void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
686 AU.setPreservesAll();
687 AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
688 }
689
690 char RootSignatureAnalysisWrapper::ID = 0;
691
692 INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
693 "dxil-root-signature-analysis",
694 "DXIL Root Signature Analysis", true, true)
695 INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
696 "dxil-root-signature-analysis",
697 "DXIL Root Signature Analysis", true, true)
698