1 //===- HLSLRootSignature.cpp - HLSL Root Signature helpers ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helpers for working with HLSL Root Signatures.
10 ///
11 //===----------------------------------------------------------------------===//
12
13 #include "llvm/Frontend/HLSL/HLSLRootSignature.h"
14 #include "llvm/Support/ScopedPrinter.h"
15
16 namespace llvm {
17 namespace hlsl {
18 namespace rootsig {
19
20 template <typename T>
getEnumName(const T Value,ArrayRef<EnumEntry<T>> Enums)21 static std::optional<StringRef> getEnumName(const T Value,
22 ArrayRef<EnumEntry<T>> Enums) {
23 for (const auto &EnumItem : Enums)
24 if (EnumItem.Value == Value)
25 return EnumItem.Name;
26 return std::nullopt;
27 }
28
29 template <typename T>
printEnum(raw_ostream & OS,const T Value,ArrayRef<EnumEntry<T>> Enums)30 static raw_ostream &printEnum(raw_ostream &OS, const T Value,
31 ArrayRef<EnumEntry<T>> Enums) {
32 auto MaybeName = getEnumName(Value, Enums);
33 if (MaybeName)
34 OS << *MaybeName;
35 return OS;
36 }
37
38 template <typename T>
printFlags(raw_ostream & OS,const T Value,ArrayRef<EnumEntry<T>> Flags)39 static raw_ostream &printFlags(raw_ostream &OS, const T Value,
40 ArrayRef<EnumEntry<T>> Flags) {
41 bool FlagSet = false;
42 unsigned Remaining = llvm::to_underlying(Value);
43 while (Remaining) {
44 unsigned Bit = 1u << llvm::countr_zero(Remaining);
45 if (Remaining & Bit) {
46 if (FlagSet)
47 OS << " | ";
48
49 auto MaybeFlag = getEnumName(T(Bit), Flags);
50 if (MaybeFlag)
51 OS << *MaybeFlag;
52 else
53 OS << "invalid: " << Bit;
54
55 FlagSet = true;
56 }
57 Remaining &= ~Bit;
58 }
59
60 if (!FlagSet)
61 OS << "None";
62 return OS;
63 }
64
65 static const EnumEntry<RegisterType> RegisterNames[] = {
66 {"b", RegisterType::BReg},
67 {"t", RegisterType::TReg},
68 {"u", RegisterType::UReg},
69 {"s", RegisterType::SReg},
70 };
71
operator <<(raw_ostream & OS,const Register & Reg)72 static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
73 printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
74 OS << Reg.Number;
75
76 return OS;
77 }
78
operator <<(raw_ostream & OS,const llvm::dxbc::ShaderVisibility & Visibility)79 static raw_ostream &operator<<(raw_ostream &OS,
80 const llvm::dxbc::ShaderVisibility &Visibility) {
81 printEnum(OS, Visibility, dxbc::getShaderVisibility());
82
83 return OS;
84 }
85
operator <<(raw_ostream & OS,const llvm::dxbc::SamplerFilter & Filter)86 static raw_ostream &operator<<(raw_ostream &OS,
87 const llvm::dxbc::SamplerFilter &Filter) {
88 printEnum(OS, Filter, dxbc::getSamplerFilters());
89
90 return OS;
91 }
92
operator <<(raw_ostream & OS,const dxbc::TextureAddressMode & Address)93 static raw_ostream &operator<<(raw_ostream &OS,
94 const dxbc::TextureAddressMode &Address) {
95 printEnum(OS, Address, dxbc::getTextureAddressModes());
96
97 return OS;
98 }
99
operator <<(raw_ostream & OS,const dxbc::ComparisonFunc & CompFunc)100 static raw_ostream &operator<<(raw_ostream &OS,
101 const dxbc::ComparisonFunc &CompFunc) {
102 printEnum(OS, CompFunc, dxbc::getComparisonFuncs());
103
104 return OS;
105 }
106
operator <<(raw_ostream & OS,const dxbc::StaticBorderColor & BorderColor)107 static raw_ostream &operator<<(raw_ostream &OS,
108 const dxbc::StaticBorderColor &BorderColor) {
109 printEnum(OS, BorderColor, dxbc::getStaticBorderColors());
110
111 return OS;
112 }
113
114 static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
115 {"CBV", dxil::ResourceClass::CBuffer},
116 {"SRV", dxil::ResourceClass::SRV},
117 {"UAV", dxil::ResourceClass::UAV},
118 {"Sampler", dxil::ResourceClass::Sampler},
119 };
120
operator <<(raw_ostream & OS,const ClauseType & Type)121 static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
122 printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
123 ArrayRef(ResourceClassNames));
124
125 return OS;
126 }
127
operator <<(raw_ostream & OS,const dxbc::RootDescriptorFlags & Flags)128 static raw_ostream &operator<<(raw_ostream &OS,
129 const dxbc::RootDescriptorFlags &Flags) {
130 printFlags(OS, Flags, dxbc::getRootDescriptorFlags());
131
132 return OS;
133 }
134
operator <<(raw_ostream & OS,const llvm::dxbc::DescriptorRangeFlags & Flags)135 static raw_ostream &operator<<(raw_ostream &OS,
136 const llvm::dxbc::DescriptorRangeFlags &Flags) {
137 printFlags(OS, Flags, dxbc::getDescriptorRangeFlags());
138
139 return OS;
140 }
141
operator <<(raw_ostream & OS,const dxbc::RootFlags & Flags)142 raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) {
143 OS << "RootFlags(";
144 printFlags(OS, Flags, dxbc::getRootFlags());
145 OS << ")";
146
147 return OS;
148 }
149
operator <<(raw_ostream & OS,const RootConstants & Constants)150 raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
151 OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants
152 << ", " << Constants.Reg << ", space = " << Constants.Space
153 << ", visibility = " << Constants.Visibility << ")";
154
155 return OS;
156 }
157
operator <<(raw_ostream & OS,const DescriptorTable & Table)158 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
159 OS << "DescriptorTable(numClauses = " << Table.NumClauses
160 << ", visibility = " << Table.Visibility << ")";
161
162 return OS;
163 }
164
operator <<(raw_ostream & OS,const DescriptorTableClause & Clause)165 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
166 OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = ";
167 if (Clause.NumDescriptors == NumDescriptorsUnbounded)
168 OS << "unbounded";
169 else
170 OS << Clause.NumDescriptors;
171 OS << ", space = " << Clause.Space << ", offset = ";
172 if (Clause.Offset == DescriptorTableOffsetAppend)
173 OS << "DescriptorTableOffsetAppend";
174 else
175 OS << Clause.Offset;
176 OS << ", flags = " << Clause.Flags << ")";
177
178 return OS;
179 }
180
operator <<(raw_ostream & OS,const RootDescriptor & Descriptor)181 raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
182 ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type));
183 OS << "Root" << Type << "(" << Descriptor.Reg
184 << ", space = " << Descriptor.Space
185 << ", visibility = " << Descriptor.Visibility
186 << ", flags = " << Descriptor.Flags << ")";
187
188 return OS;
189 }
190
operator <<(raw_ostream & OS,const StaticSampler & Sampler)191 raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
192 OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter
193 << ", addressU = " << Sampler.AddressU
194 << ", addressV = " << Sampler.AddressV
195 << ", addressW = " << Sampler.AddressW
196 << ", mipLODBias = " << Sampler.MipLODBias
197 << ", maxAnisotropy = " << Sampler.MaxAnisotropy
198 << ", comparisonFunc = " << Sampler.CompFunc
199 << ", borderColor = " << Sampler.BorderColor
200 << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
201 << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
202 << ")";
203 return OS;
204 }
205
namespace {

// OverloadedVisit merges a set of callables into a single visitor by
// inheriting the call operator of each one. Using it with std::visit makes
// the compiler reject the code if a new RootElement variant type is added
// without a matching handler for its operator<<.
template <class... Callables> struct OverloadedVisit : Callables... {
  using Callables::operator()...;
};
// Deduction guide so the visitor can be built directly from a list of lambdas.
template <class... Callables>
OverloadedVisit(Callables...) -> OverloadedVisit<Callables...>;

} // namespace
216
operator <<(raw_ostream & OS,const RootElement & Element)217 raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) {
218 const auto Visitor = OverloadedVisit{
219 [&OS](const dxbc::RootFlags &Flags) { OS << Flags; },
220 [&OS](const RootConstants &Constants) { OS << Constants; },
221 [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; },
222 [&OS](const DescriptorTableClause &Clause) { OS << Clause; },
223 [&OS](const DescriptorTable &Table) { OS << Table; },
224 [&OS](const StaticSampler &Sampler) { OS << Sampler; },
225 };
226 std::visit(Visitor, Element);
227 return OS;
228 }
229
dumpRootElements(raw_ostream & OS,ArrayRef<RootElement> Elements)230 void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
231 OS << " RootElements{";
232 bool First = true;
233 for (const RootElement &Element : Elements) {
234 if (!First)
235 OS << ",";
236 OS << " " << Element;
237 First = false;
238 }
239 OS << "}";
240 }
241
242 } // namespace rootsig
243 } // namespace hlsl
244 } // namespace llvm
245