// xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILRootSignature.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helper objects and APIs for working with DXIL
///       Root Signatures.
///
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>
#include <string>
#include <utility>

using namespace llvm;
using namespace llvm::dxil;
37 
reportError(LLVMContext * Ctx,Twine Message,DiagnosticSeverity Severity=DS_Error)38 static bool reportError(LLVMContext *Ctx, Twine Message,
39                         DiagnosticSeverity Severity = DS_Error) {
40   Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
41   return true;
42 }
43 
reportValueError(LLVMContext * Ctx,Twine ParamName,uint32_t Value)44 static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
45                              uint32_t Value) {
46   Ctx->diagnose(DiagnosticInfoGeneric(
47       "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
48   return true;
49 }
50 
extractMdIntValue(MDNode * Node,unsigned int OpId)51 static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
52                                                  unsigned int OpId) {
53   if (auto *CI =
54           mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
55     return CI->getZExtValue();
56   return std::nullopt;
57 }
58 
extractMdFloatValue(MDNode * Node,unsigned int OpId)59 static std::optional<float> extractMdFloatValue(MDNode *Node,
60                                                 unsigned int OpId) {
61   if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
62     return CI->getValueAPF().convertToFloat();
63   return std::nullopt;
64 }
65 
extractMdStringValue(MDNode * Node,unsigned int OpId)66 static std::optional<StringRef> extractMdStringValue(MDNode *Node,
67                                                      unsigned int OpId) {
68   MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
69   if (NodeText == nullptr)
70     return std::nullopt;
71   return NodeText->getString();
72 }
73 
parseRootFlags(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * RootFlagNode)74 static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
75                            MDNode *RootFlagNode) {
76 
77   if (RootFlagNode->getNumOperands() != 2)
78     return reportError(Ctx, "Invalid format for RootFlag Element");
79 
80   if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
81     RSD.Flags = *Val;
82   else
83     return reportError(Ctx, "Invalid value for RootFlag");
84 
85   return false;
86 }
87 
parseRootConstants(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * RootConstantNode)88 static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
89                                MDNode *RootConstantNode) {
90 
91   if (RootConstantNode->getNumOperands() != 5)
92     return reportError(Ctx, "Invalid format for RootConstants Element");
93 
94   dxbc::RTS0::v1::RootParameterHeader Header;
95   // The parameter offset doesn't matter here - we recalculate it during
96   // serialization  Header.ParameterOffset = 0;
97   Header.ParameterType =
98       llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
99 
100   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
101     Header.ShaderVisibility = *Val;
102   else
103     return reportError(Ctx, "Invalid value for ShaderVisibility");
104 
105   dxbc::RTS0::v1::RootConstants Constants;
106   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
107     Constants.ShaderRegister = *Val;
108   else
109     return reportError(Ctx, "Invalid value for ShaderRegister");
110 
111   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
112     Constants.RegisterSpace = *Val;
113   else
114     return reportError(Ctx, "Invalid value for RegisterSpace");
115 
116   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
117     Constants.Num32BitValues = *Val;
118   else
119     return reportError(Ctx, "Invalid value for Num32BitValues");
120 
121   RSD.ParametersContainer.addParameter(Header, Constants);
122 
123   return false;
124 }
125 
parseRootDescriptors(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * RootDescriptorNode,RootSignatureElementKind ElementKind)126 static bool parseRootDescriptors(LLVMContext *Ctx,
127                                  mcdxbc::RootSignatureDesc &RSD,
128                                  MDNode *RootDescriptorNode,
129                                  RootSignatureElementKind ElementKind) {
130   assert(ElementKind == RootSignatureElementKind::SRV ||
131          ElementKind == RootSignatureElementKind::UAV ||
132          ElementKind == RootSignatureElementKind::CBV &&
133              "parseRootDescriptors should only be called with RootDescriptor "
134              "element kind.");
135   if (RootDescriptorNode->getNumOperands() != 5)
136     return reportError(Ctx, "Invalid format for Root Descriptor Element");
137 
138   dxbc::RTS0::v1::RootParameterHeader Header;
139   switch (ElementKind) {
140   case RootSignatureElementKind::SRV:
141     Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
142     break;
143   case RootSignatureElementKind::UAV:
144     Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
145     break;
146   case RootSignatureElementKind::CBV:
147     Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
148     break;
149   default:
150     llvm_unreachable("invalid Root Descriptor kind");
151     break;
152   }
153 
154   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
155     Header.ShaderVisibility = *Val;
156   else
157     return reportError(Ctx, "Invalid value for ShaderVisibility");
158 
159   dxbc::RTS0::v2::RootDescriptor Descriptor;
160   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
161     Descriptor.ShaderRegister = *Val;
162   else
163     return reportError(Ctx, "Invalid value for ShaderRegister");
164 
165   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
166     Descriptor.RegisterSpace = *Val;
167   else
168     return reportError(Ctx, "Invalid value for RegisterSpace");
169 
170   if (RSD.Version == 1) {
171     RSD.ParametersContainer.addParameter(Header, Descriptor);
172     return false;
173   }
174   assert(RSD.Version > 1);
175 
176   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
177     Descriptor.Flags = *Val;
178   else
179     return reportError(Ctx, "Invalid value for Root Descriptor Flags");
180 
181   RSD.ParametersContainer.addParameter(Header, Descriptor);
182   return false;
183 }
184 
parseDescriptorRange(LLVMContext * Ctx,mcdxbc::DescriptorTable & Table,MDNode * RangeDescriptorNode)185 static bool parseDescriptorRange(LLVMContext *Ctx,
186                                  mcdxbc::DescriptorTable &Table,
187                                  MDNode *RangeDescriptorNode) {
188 
189   if (RangeDescriptorNode->getNumOperands() != 6)
190     return reportError(Ctx, "Invalid format for Descriptor Range");
191 
192   dxbc::RTS0::v2::DescriptorRange Range;
193 
194   std::optional<StringRef> ElementText =
195       extractMdStringValue(RangeDescriptorNode, 0);
196 
197   if (!ElementText.has_value())
198     return reportError(Ctx, "Descriptor Range, first element is not a string.");
199 
200   Range.RangeType =
201       StringSwitch<uint32_t>(*ElementText)
202           .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
203           .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
204           .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
205           .Case("Sampler",
206                 llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
207           .Default(~0U);
208 
209   if (Range.RangeType == ~0U)
210     return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
211 
212   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
213     Range.NumDescriptors = *Val;
214   else
215     return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
216 
217   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
218     Range.BaseShaderRegister = *Val;
219   else
220     return reportError(Ctx, "Invalid value for BaseShaderRegister");
221 
222   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
223     Range.RegisterSpace = *Val;
224   else
225     return reportError(Ctx, "Invalid value for RegisterSpace");
226 
227   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
228     Range.OffsetInDescriptorsFromTableStart = *Val;
229   else
230     return reportError(Ctx,
231                        "Invalid value for OffsetInDescriptorsFromTableStart");
232 
233   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
234     Range.Flags = *Val;
235   else
236     return reportError(Ctx, "Invalid value for Descriptor Range Flags");
237 
238   Table.Ranges.push_back(Range);
239   return false;
240 }
241 
parseDescriptorTable(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * DescriptorTableNode)242 static bool parseDescriptorTable(LLVMContext *Ctx,
243                                  mcdxbc::RootSignatureDesc &RSD,
244                                  MDNode *DescriptorTableNode) {
245   const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
246   if (NumOperands < 2)
247     return reportError(Ctx, "Invalid format for Descriptor Table");
248 
249   dxbc::RTS0::v1::RootParameterHeader Header;
250   if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
251     Header.ShaderVisibility = *Val;
252   else
253     return reportError(Ctx, "Invalid value for ShaderVisibility");
254 
255   mcdxbc::DescriptorTable Table;
256   Header.ParameterType =
257       llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
258 
259   for (unsigned int I = 2; I < NumOperands; I++) {
260     MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
261     if (Element == nullptr)
262       return reportError(Ctx, "Missing Root Element Metadata Node.");
263 
264     if (parseDescriptorRange(Ctx, Table, Element))
265       return true;
266   }
267 
268   RSD.ParametersContainer.addParameter(Header, Table);
269   return false;
270 }
271 
parseStaticSampler(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * StaticSamplerNode)272 static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
273                                MDNode *StaticSamplerNode) {
274   if (StaticSamplerNode->getNumOperands() != 14)
275     return reportError(Ctx, "Invalid format for Static Sampler");
276 
277   dxbc::RTS0::v1::StaticSampler Sampler;
278   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
279     Sampler.Filter = *Val;
280   else
281     return reportError(Ctx, "Invalid value for Filter");
282 
283   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
284     Sampler.AddressU = *Val;
285   else
286     return reportError(Ctx, "Invalid value for AddressU");
287 
288   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
289     Sampler.AddressV = *Val;
290   else
291     return reportError(Ctx, "Invalid value for AddressV");
292 
293   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
294     Sampler.AddressW = *Val;
295   else
296     return reportError(Ctx, "Invalid value for AddressW");
297 
298   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
299     Sampler.MipLODBias = *Val;
300   else
301     return reportError(Ctx, "Invalid value for MipLODBias");
302 
303   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
304     Sampler.MaxAnisotropy = *Val;
305   else
306     return reportError(Ctx, "Invalid value for MaxAnisotropy");
307 
308   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
309     Sampler.ComparisonFunc = *Val;
310   else
311     return reportError(Ctx, "Invalid value for ComparisonFunc ");
312 
313   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
314     Sampler.BorderColor = *Val;
315   else
316     return reportError(Ctx, "Invalid value for ComparisonFunc ");
317 
318   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
319     Sampler.MinLOD = *Val;
320   else
321     return reportError(Ctx, "Invalid value for MinLOD");
322 
323   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
324     Sampler.MaxLOD = *Val;
325   else
326     return reportError(Ctx, "Invalid value for MaxLOD");
327 
328   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
329     Sampler.ShaderRegister = *Val;
330   else
331     return reportError(Ctx, "Invalid value for ShaderRegister");
332 
333   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
334     Sampler.RegisterSpace = *Val;
335   else
336     return reportError(Ctx, "Invalid value for RegisterSpace");
337 
338   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
339     Sampler.ShaderVisibility = *Val;
340   else
341     return reportError(Ctx, "Invalid value for ShaderVisibility");
342 
343   RSD.StaticSamplers.push_back(Sampler);
344   return false;
345 }
346 
parseRootSignatureElement(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * Element)347 static bool parseRootSignatureElement(LLVMContext *Ctx,
348                                       mcdxbc::RootSignatureDesc &RSD,
349                                       MDNode *Element) {
350   std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
351   if (!ElementText.has_value())
352     return reportError(Ctx, "Invalid format for Root Element");
353 
354   RootSignatureElementKind ElementKind =
355       StringSwitch<RootSignatureElementKind>(*ElementText)
356           .Case("RootFlags", RootSignatureElementKind::RootFlags)
357           .Case("RootConstants", RootSignatureElementKind::RootConstants)
358           .Case("RootCBV", RootSignatureElementKind::CBV)
359           .Case("RootSRV", RootSignatureElementKind::SRV)
360           .Case("RootUAV", RootSignatureElementKind::UAV)
361           .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
362           .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
363           .Default(RootSignatureElementKind::Error);
364 
365   switch (ElementKind) {
366 
367   case RootSignatureElementKind::RootFlags:
368     return parseRootFlags(Ctx, RSD, Element);
369   case RootSignatureElementKind::RootConstants:
370     return parseRootConstants(Ctx, RSD, Element);
371   case RootSignatureElementKind::CBV:
372   case RootSignatureElementKind::SRV:
373   case RootSignatureElementKind::UAV:
374     return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
375   case RootSignatureElementKind::DescriptorTable:
376     return parseDescriptorTable(Ctx, RSD, Element);
377   case RootSignatureElementKind::StaticSamplers:
378     return parseStaticSampler(Ctx, RSD, Element);
379   case RootSignatureElementKind::Error:
380     return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
381   }
382 
383   llvm_unreachable("Unhandled RootSignatureElementKind enum.");
384 }
385 
parse(LLVMContext * Ctx,mcdxbc::RootSignatureDesc & RSD,MDNode * Node)386 static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
387                   MDNode *Node) {
388   bool HasError = false;
389 
390   // Loop through the Root Elements of the root signature.
391   for (const auto &Operand : Node->operands()) {
392     MDNode *Element = dyn_cast<MDNode>(Operand);
393     if (Element == nullptr)
394       return reportError(Ctx, "Missing Root Element Metadata Node.");
395 
396     HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
397   }
398 
399   return HasError;
400 }
401 
validate(LLVMContext * Ctx,const mcdxbc::RootSignatureDesc & RSD)402 static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
403 
404   if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
405     return reportValueError(Ctx, "Version", RSD.Version);
406   }
407 
408   if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
409     return reportValueError(Ctx, "RootFlags", RSD.Flags);
410   }
411 
412   for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
413     if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
414       return reportValueError(Ctx, "ShaderVisibility",
415                               Info.Header.ShaderVisibility);
416 
417     assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
418            "Invalid value for ParameterType");
419 
420     switch (Info.Header.ParameterType) {
421 
422     case llvm::to_underlying(dxbc::RootParameterType::CBV):
423     case llvm::to_underlying(dxbc::RootParameterType::UAV):
424     case llvm::to_underlying(dxbc::RootParameterType::SRV): {
425       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
426           RSD.ParametersContainer.getRootDescriptor(Info.Location);
427       if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
428         return reportValueError(Ctx, "ShaderRegister",
429                                 Descriptor.ShaderRegister);
430 
431       if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
432         return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
433 
434       if (RSD.Version > 1) {
435         if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
436                                                            Descriptor.Flags))
437           return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
438       }
439       break;
440     }
441     case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
442       const mcdxbc::DescriptorTable &Table =
443           RSD.ParametersContainer.getDescriptorTable(Info.Location);
444       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
445         if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
446           return reportValueError(Ctx, "RangeType", Range.RangeType);
447 
448         if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
449           return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
450 
451         if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
452           return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
453 
454         if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
455                 RSD.Version, Range.RangeType, Range.Flags))
456           return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
457       }
458       break;
459     }
460     }
461   }
462 
463   for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
464     if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
465       return reportValueError(Ctx, "Filter", Sampler.Filter);
466 
467     if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
468       return reportValueError(Ctx, "AddressU", Sampler.AddressU);
469 
470     if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
471       return reportValueError(Ctx, "AddressV", Sampler.AddressV);
472 
473     if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
474       return reportValueError(Ctx, "AddressW", Sampler.AddressW);
475 
476     if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
477       return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
478 
479     if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
480       return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
481 
482     if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
483       return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
484 
485     if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
486       return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
487 
488     if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
489       return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
490 
491     if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
492       return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
493 
494     if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
495       return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
496 
497     if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
498       return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
499 
500     if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
501       return reportValueError(Ctx, "ShaderVisibility",
502                               Sampler.ShaderVisibility);
503   }
504 
505   return false;
506 }
507 
508 static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
analyzeModule(Module & M)509 analyzeModule(Module &M) {
510 
511   /** Root Signature are specified as following in the metadata:
512 
513     !dx.rootsignatures = !{!2} ; list of function/root signature pairs
514     !2 = !{ ptr @main, !3 } ; function, root signature
515     !3 = !{ !4, !5, !6, !7 } ; list of root signature elements
516 
517     So for each MDNode inside dx.rootsignatures NamedMDNode
518     (the Root parameter of this function), the parsing process needs
519     to loop through each of its operands and process the function,
520     signature pair.
521  */
522 
523   LLVMContext *Ctx = &M.getContext();
524 
525   SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> RSDMap;
526 
527   NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
528   if (RootSignatureNode == nullptr)
529     return RSDMap;
530 
531   for (const auto &RSDefNode : RootSignatureNode->operands()) {
532     if (RSDefNode->getNumOperands() != 3) {
533       reportError(Ctx, "Invalid Root Signature metadata - expected function, "
534                        "signature, and version.");
535       continue;
536     }
537 
538     // Function was pruned during compilation.
539     const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
540     if (FunctionPointerMdNode == nullptr) {
541       reportError(
542           Ctx, "Function associated with Root Signature definition is null.");
543       continue;
544     }
545 
546     ValueAsMetadata *VAM =
547         llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
548     if (VAM == nullptr) {
549       reportError(Ctx, "First element of root signature is not a Value");
550       continue;
551     }
552 
553     Function *F = dyn_cast<Function>(VAM->getValue());
554     if (F == nullptr) {
555       reportError(Ctx, "First element of root signature is not a Function");
556       continue;
557     }
558 
559     Metadata *RootElementListOperand = RSDefNode->getOperand(1).get();
560 
561     if (RootElementListOperand == nullptr) {
562       reportError(Ctx, "Root Element mdnode is null.");
563       continue;
564     }
565 
566     MDNode *RootElementListNode = dyn_cast<MDNode>(RootElementListOperand);
567     if (RootElementListNode == nullptr) {
568       reportError(Ctx, "Root Element is not a metadata node.");
569       continue;
570     }
571     mcdxbc::RootSignatureDesc RSD;
572     if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2))
573       RSD.Version = *Version;
574     else {
575       reportError(Ctx, "Invalid RSDefNode value, expected constant int");
576       continue;
577     }
578 
579     // Clang emits the root signature data in dxcontainer following a specific
580     // sequence. First the header, then the root parameters. So the header
581     // offset will always equal to the header size.
582     RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);
583 
584     // static sampler offset is calculated when writting dxcontainer.
585     RSD.StaticSamplersOffset = 0u;
586 
587     if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
588       return RSDMap;
589     }
590 
591     RSDMap.insert(std::make_pair(F, RSD));
592   }
593 
594   return RSDMap;
595 }
596 
597 AnalysisKey RootSignatureAnalysis::Key;
598 
599 RootSignatureAnalysis::Result
run(Module & M,ModuleAnalysisManager & AM)600 RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
601   return RootSignatureBindingInfo(analyzeModule(M));
602 }
603 
604 //===----------------------------------------------------------------------===//
605 
run(Module & M,ModuleAnalysisManager & AM)606 PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
607                                                     ModuleAnalysisManager &AM) {
608 
609   RootSignatureBindingInfo &RSDMap = AM.getResult<RootSignatureAnalysis>(M);
610 
611   OS << "Root Signature Definitions"
612      << "\n";
613   for (const Function &F : M) {
614     auto It = RSDMap.find(&F);
615     if (It == RSDMap.end())
616       continue;
617     const auto &RS = It->second;
618     OS << "Definition for '" << F.getName() << "':\n";
619     // start root signature header
620     OS << "Flags: " << format_hex(RS.Flags, 8) << "\n"
621        << "Version: " << RS.Version << "\n"
622        << "RootParametersOffset: " << RS.RootParameterOffset << "\n"
623        << "NumParameters: " << RS.ParametersContainer.size() << "\n";
624     for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
625       const auto &[Type, Loc] =
626           RS.ParametersContainer.getTypeAndLocForParameter(I);
627       const dxbc::RTS0::v1::RootParameterHeader Header =
628           RS.ParametersContainer.getHeader(I);
629 
630       OS << "- Parameter Type: " << Type << "\n"
631          << "  Shader Visibility: " << Header.ShaderVisibility << "\n";
632 
633       switch (Type) {
634       case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
635         const dxbc::RTS0::v1::RootConstants &Constants =
636             RS.ParametersContainer.getConstant(Loc);
637         OS << "  Register Space: " << Constants.RegisterSpace << "\n"
638            << "  Shader Register: " << Constants.ShaderRegister << "\n"
639            << "  Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
640         break;
641       }
642       case llvm::to_underlying(dxbc::RootParameterType::CBV):
643       case llvm::to_underlying(dxbc::RootParameterType::UAV):
644       case llvm::to_underlying(dxbc::RootParameterType::SRV): {
645         const dxbc::RTS0::v2::RootDescriptor &Descriptor =
646             RS.ParametersContainer.getRootDescriptor(Loc);
647         OS << "  Register Space: " << Descriptor.RegisterSpace << "\n"
648            << "  Shader Register: " << Descriptor.ShaderRegister << "\n";
649         if (RS.Version > 1)
650           OS << "  Flags: " << Descriptor.Flags << "\n";
651         break;
652       }
653       case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
654         const mcdxbc::DescriptorTable &Table =
655             RS.ParametersContainer.getDescriptorTable(Loc);
656         OS << "  NumRanges: " << Table.Ranges.size() << "\n";
657 
658         for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
659           OS << "  - Range Type: " << Range.RangeType << "\n"
660              << "    Register Space: " << Range.RegisterSpace << "\n"
661              << "    Base Shader Register: " << Range.BaseShaderRegister << "\n"
662              << "    Num Descriptors: " << Range.NumDescriptors << "\n"
663              << "    Offset In Descriptors From Table Start: "
664              << Range.OffsetInDescriptorsFromTableStart << "\n";
665           if (RS.Version > 1)
666             OS << "    Flags: " << Range.Flags << "\n";
667         }
668         break;
669       }
670       }
671     }
672     OS << "NumStaticSamplers: " << 0 << "\n";
673     OS << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n";
674   }
675   return PreservedAnalyses::all();
676 }
677 
678 //===----------------------------------------------------------------------===//
runOnModule(Module & M)679 bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
680   FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(
681       RootSignatureBindingInfo(analyzeModule(M)));
682   return false;
683 }
684 
getAnalysisUsage(AnalysisUsage & AU) const685 void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
686   AU.setPreservesAll();
687   AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
688 }
689 
690 char RootSignatureAnalysisWrapper::ID = 0;
691 
692 INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
693                       "dxil-root-signature-analysis",
694                       "DXIL Root Signature Analysis", true, true)
695 INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
696                     "dxil-root-signature-analysis",
697                     "DXIL Root Signature Analysis", true, true)
698