xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- HLSLRootSignature.cpp - HLSL Root Signature helpers ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helpers for working with HLSL Root Signatures.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Frontend/HLSL/HLSLRootSignature.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 
16 namespace llvm {
17 namespace hlsl {
18 namespace rootsig {
19 
20 template <typename T>
getEnumName(const T Value,ArrayRef<EnumEntry<T>> Enums)21 static std::optional<StringRef> getEnumName(const T Value,
22                                             ArrayRef<EnumEntry<T>> Enums) {
23   for (const auto &EnumItem : Enums)
24     if (EnumItem.Value == Value)
25       return EnumItem.Name;
26   return std::nullopt;
27 }
28 
29 template <typename T>
printEnum(raw_ostream & OS,const T Value,ArrayRef<EnumEntry<T>> Enums)30 static raw_ostream &printEnum(raw_ostream &OS, const T Value,
31                               ArrayRef<EnumEntry<T>> Enums) {
32   auto MaybeName = getEnumName(Value, Enums);
33   if (MaybeName)
34     OS << *MaybeName;
35   return OS;
36 }
37 
38 template <typename T>
printFlags(raw_ostream & OS,const T Value,ArrayRef<EnumEntry<T>> Flags)39 static raw_ostream &printFlags(raw_ostream &OS, const T Value,
40                                ArrayRef<EnumEntry<T>> Flags) {
41   bool FlagSet = false;
42   unsigned Remaining = llvm::to_underlying(Value);
43   while (Remaining) {
44     unsigned Bit = 1u << llvm::countr_zero(Remaining);
45     if (Remaining & Bit) {
46       if (FlagSet)
47         OS << " | ";
48 
49       auto MaybeFlag = getEnumName(T(Bit), Flags);
50       if (MaybeFlag)
51         OS << *MaybeFlag;
52       else
53         OS << "invalid: " << Bit;
54 
55       FlagSet = true;
56     }
57     Remaining &= ~Bit;
58   }
59 
60   if (!FlagSet)
61     OS << "None";
62   return OS;
63 }
64 
65 static const EnumEntry<RegisterType> RegisterNames[] = {
66     {"b", RegisterType::BReg},
67     {"t", RegisterType::TReg},
68     {"u", RegisterType::UReg},
69     {"s", RegisterType::SReg},
70 };
71 
operator <<(raw_ostream & OS,const Register & Reg)72 static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
73   printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
74   OS << Reg.Number;
75 
76   return OS;
77 }
78 
operator <<(raw_ostream & OS,const llvm::dxbc::ShaderVisibility & Visibility)79 static raw_ostream &operator<<(raw_ostream &OS,
80                                const llvm::dxbc::ShaderVisibility &Visibility) {
81   printEnum(OS, Visibility, dxbc::getShaderVisibility());
82 
83   return OS;
84 }
85 
operator <<(raw_ostream & OS,const llvm::dxbc::SamplerFilter & Filter)86 static raw_ostream &operator<<(raw_ostream &OS,
87                                const llvm::dxbc::SamplerFilter &Filter) {
88   printEnum(OS, Filter, dxbc::getSamplerFilters());
89 
90   return OS;
91 }
92 
operator <<(raw_ostream & OS,const dxbc::TextureAddressMode & Address)93 static raw_ostream &operator<<(raw_ostream &OS,
94                                const dxbc::TextureAddressMode &Address) {
95   printEnum(OS, Address, dxbc::getTextureAddressModes());
96 
97   return OS;
98 }
99 
operator <<(raw_ostream & OS,const dxbc::ComparisonFunc & CompFunc)100 static raw_ostream &operator<<(raw_ostream &OS,
101                                const dxbc::ComparisonFunc &CompFunc) {
102   printEnum(OS, CompFunc, dxbc::getComparisonFuncs());
103 
104   return OS;
105 }
106 
operator <<(raw_ostream & OS,const dxbc::StaticBorderColor & BorderColor)107 static raw_ostream &operator<<(raw_ostream &OS,
108                                const dxbc::StaticBorderColor &BorderColor) {
109   printEnum(OS, BorderColor, dxbc::getStaticBorderColors());
110 
111   return OS;
112 }
113 
114 static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
115     {"CBV", dxil::ResourceClass::CBuffer},
116     {"SRV", dxil::ResourceClass::SRV},
117     {"UAV", dxil::ResourceClass::UAV},
118     {"Sampler", dxil::ResourceClass::Sampler},
119 };
120 
operator <<(raw_ostream & OS,const ClauseType & Type)121 static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
122   printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
123             ArrayRef(ResourceClassNames));
124 
125   return OS;
126 }
127 
operator <<(raw_ostream & OS,const dxbc::RootDescriptorFlags & Flags)128 static raw_ostream &operator<<(raw_ostream &OS,
129                                const dxbc::RootDescriptorFlags &Flags) {
130   printFlags(OS, Flags, dxbc::getRootDescriptorFlags());
131 
132   return OS;
133 }
134 
operator <<(raw_ostream & OS,const llvm::dxbc::DescriptorRangeFlags & Flags)135 static raw_ostream &operator<<(raw_ostream &OS,
136                                const llvm::dxbc::DescriptorRangeFlags &Flags) {
137   printFlags(OS, Flags, dxbc::getDescriptorRangeFlags());
138 
139   return OS;
140 }
141 
operator <<(raw_ostream & OS,const dxbc::RootFlags & Flags)142 raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) {
143   OS << "RootFlags(";
144   printFlags(OS, Flags, dxbc::getRootFlags());
145   OS << ")";
146 
147   return OS;
148 }
149 
operator <<(raw_ostream & OS,const RootConstants & Constants)150 raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
151   OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants
152      << ", " << Constants.Reg << ", space = " << Constants.Space
153      << ", visibility = " << Constants.Visibility << ")";
154 
155   return OS;
156 }
157 
operator <<(raw_ostream & OS,const DescriptorTable & Table)158 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
159   OS << "DescriptorTable(numClauses = " << Table.NumClauses
160      << ", visibility = " << Table.Visibility << ")";
161 
162   return OS;
163 }
164 
operator <<(raw_ostream & OS,const DescriptorTableClause & Clause)165 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
166   OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = ";
167   if (Clause.NumDescriptors == NumDescriptorsUnbounded)
168     OS << "unbounded";
169   else
170     OS << Clause.NumDescriptors;
171   OS << ", space = " << Clause.Space << ", offset = ";
172   if (Clause.Offset == DescriptorTableOffsetAppend)
173     OS << "DescriptorTableOffsetAppend";
174   else
175     OS << Clause.Offset;
176   OS << ", flags = " << Clause.Flags << ")";
177 
178   return OS;
179 }
180 
operator <<(raw_ostream & OS,const RootDescriptor & Descriptor)181 raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
182   ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type));
183   OS << "Root" << Type << "(" << Descriptor.Reg
184      << ", space = " << Descriptor.Space
185      << ", visibility = " << Descriptor.Visibility
186      << ", flags = " << Descriptor.Flags << ")";
187 
188   return OS;
189 }
190 
operator <<(raw_ostream & OS,const StaticSampler & Sampler)191 raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
192   OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter
193      << ", addressU = " << Sampler.AddressU
194      << ", addressV = " << Sampler.AddressV
195      << ", addressW = " << Sampler.AddressW
196      << ", mipLODBias = " << Sampler.MipLODBias
197      << ", maxAnisotropy = " << Sampler.MaxAnisotropy
198      << ", comparisonFunc = " << Sampler.CompFunc
199      << ", borderColor = " << Sampler.BorderColor
200      << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
201      << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
202      << ")";
203   return OS;
204 }
205 
namespace {

// Combines several callables into one overload set (the "overloaded" idiom for
// std::visit). Each RootElement alternative is handled by its own explicitly
// typed lambda, so if a new alternative is added to the RootElement variant
// without a matching handler, the std::visit call fails to compile instead of
// silently missing a case.
template <class... Ts> struct OverloadedVisit : Ts... {
  using Ts::operator()...;
};

// Deduction guide so `OverloadedVisit{lambda1, lambda2, ...}` deduces Ts from
// the lambda types.
template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>;

} // namespace
216 
operator <<(raw_ostream & OS,const RootElement & Element)217 raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) {
218   const auto Visitor = OverloadedVisit{
219       [&OS](const dxbc::RootFlags &Flags) { OS << Flags; },
220       [&OS](const RootConstants &Constants) { OS << Constants; },
221       [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; },
222       [&OS](const DescriptorTableClause &Clause) { OS << Clause; },
223       [&OS](const DescriptorTable &Table) { OS << Table; },
224       [&OS](const StaticSampler &Sampler) { OS << Sampler; },
225   };
226   std::visit(Visitor, Element);
227   return OS;
228 }
229 
dumpRootElements(raw_ostream & OS,ArrayRef<RootElement> Elements)230 void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
231   OS << " RootElements{";
232   bool First = true;
233   for (const RootElement &Element : Elements) {
234     if (!First)
235       OS << ",";
236     OS << " " << Element;
237     First = false;
238   }
239   OS << "}";
240 }
241 
242 } // namespace rootsig
243 } // namespace hlsl
244 } // namespace llvm
245