1 //===- HLSLRootSignature.cpp - HLSL Root Signature helpers ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains helpers for working with HLSL Root Signatures. 10 /// 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Frontend/HLSL/HLSLRootSignature.h" 14 #include "llvm/Support/ScopedPrinter.h" 15 16 namespace llvm { 17 namespace hlsl { 18 namespace rootsig { 19 20 template <typename T> 21 static std::optional<StringRef> getEnumName(const T Value, 22 ArrayRef<EnumEntry<T>> Enums) { 23 for (const auto &EnumItem : Enums) 24 if (EnumItem.Value == Value) 25 return EnumItem.Name; 26 return std::nullopt; 27 } 28 29 template <typename T> 30 static raw_ostream &printEnum(raw_ostream &OS, const T Value, 31 ArrayRef<EnumEntry<T>> Enums) { 32 auto MaybeName = getEnumName(Value, Enums); 33 if (MaybeName) 34 OS << *MaybeName; 35 return OS; 36 } 37 38 template <typename T> 39 static raw_ostream &printFlags(raw_ostream &OS, const T Value, 40 ArrayRef<EnumEntry<T>> Flags) { 41 bool FlagSet = false; 42 unsigned Remaining = llvm::to_underlying(Value); 43 while (Remaining) { 44 unsigned Bit = 1u << llvm::countr_zero(Remaining); 45 if (Remaining & Bit) { 46 if (FlagSet) 47 OS << " | "; 48 49 auto MaybeFlag = getEnumName(T(Bit), Flags); 50 if (MaybeFlag) 51 OS << *MaybeFlag; 52 else 53 OS << "invalid: " << Bit; 54 55 FlagSet = true; 56 } 57 Remaining &= ~Bit; 58 } 59 60 if (!FlagSet) 61 OS << "None"; 62 return OS; 63 } 64 65 static const EnumEntry<RegisterType> RegisterNames[] = { 66 {"b", RegisterType::BReg}, 67 {"t", RegisterType::TReg}, 68 {"u", RegisterType::UReg}, 69 {"s", RegisterType::SReg}, 70 }; 71 72 static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { 73 printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); 74 OS << Reg.Number; 75 76 return OS; 77 } 78 79 static raw_ostream &operator<<(raw_ostream &OS, 80 const llvm::dxbc::ShaderVisibility &Visibility) { 81 printEnum(OS, Visibility, dxbc::getShaderVisibility()); 82 83 return OS; 84 } 85 86 static raw_ostream &operator<<(raw_ostream &OS, 87 const llvm::dxbc::SamplerFilter &Filter) { 88 printEnum(OS, Filter, dxbc::getSamplerFilters()); 89 90 return OS; 91 } 92 93 static raw_ostream &operator<<(raw_ostream &OS, 94 const dxbc::TextureAddressMode &Address) { 95 printEnum(OS, Address, dxbc::getTextureAddressModes()); 96 97 return OS; 98 } 99 100 static raw_ostream &operator<<(raw_ostream &OS, 101 const dxbc::ComparisonFunc &CompFunc) { 102 printEnum(OS, CompFunc, dxbc::getComparisonFuncs()); 103 104 return OS; 105 } 106 107 static raw_ostream &operator<<(raw_ostream &OS, 108 const dxbc::StaticBorderColor &BorderColor) { 109 printEnum(OS, BorderColor, dxbc::getStaticBorderColors()); 110 111 return OS; 112 } 113 114 static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = { 115 {"CBV", dxil::ResourceClass::CBuffer}, 116 {"SRV", dxil::ResourceClass::SRV}, 117 {"UAV", dxil::ResourceClass::UAV}, 118 {"Sampler", dxil::ResourceClass::Sampler}, 119 }; 120 121 static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { 122 printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), 123 ArrayRef(ResourceClassNames)); 124 125 return OS; 126 } 127 128 static raw_ostream &operator<<(raw_ostream &OS, 129 const dxbc::RootDescriptorFlags &Flags) { 130 printFlags(OS, Flags, dxbc::getRootDescriptorFlags()); 131 132 return OS; 133 } 134 135 static raw_ostream &operator<<(raw_ostream &OS, 136 const llvm::dxbc::DescriptorRangeFlags &Flags) { 137 printFlags(OS, Flags, dxbc::getDescriptorRangeFlags()); 138 139 return OS; 140 } 141 142 raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) { 143 OS << "RootFlags("; 144 printFlags(OS, Flags, dxbc::getRootFlags()); 145 OS << ")"; 146 147 return OS; 148 } 149 150 raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) { 151 OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants 152 << ", " << Constants.Reg << ", space = " << Constants.Space 153 << ", visibility = " << Constants.Visibility << ")"; 154 155 return OS; 156 } 157 158 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) { 159 OS << "DescriptorTable(numClauses = " << Table.NumClauses 160 << ", visibility = " << Table.Visibility << ")"; 161 162 return OS; 163 } 164 165 raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) { 166 OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = "; 167 if (Clause.NumDescriptors == NumDescriptorsUnbounded) 168 OS << "unbounded"; 169 else 170 OS << Clause.NumDescriptors; 171 OS << ", space = " << Clause.Space << ", offset = "; 172 if (Clause.Offset == DescriptorTableOffsetAppend) 173 OS << "DescriptorTableOffsetAppend"; 174 else 175 OS << Clause.Offset; 176 OS << ", flags = " << Clause.Flags << ")"; 177 178 return OS; 179 } 180 181 raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) { 182 ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type)); 183 OS << "Root" << Type << "(" << Descriptor.Reg 184 << ", space = " << Descriptor.Space 185 << ", visibility = " << Descriptor.Visibility 186 << ", flags = " << Descriptor.Flags << ")"; 187 188 return OS; 189 } 190 191 raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) { 192 OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter 193 << ", addressU = " << Sampler.AddressU 194 << ", addressV = " << Sampler.AddressV 195 << ", addressW = " << Sampler.AddressW 196 << ", mipLODBias = " << Sampler.MipLODBias 197 << ", maxAnisotropy = " << Sampler.MaxAnisotropy 198 << ", comparisonFunc = " << Sampler.CompFunc 199 << ", borderColor = " << Sampler.BorderColor 200 << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD 201 << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility 202 << ")"; 203 return OS; 204 } 205 206 namespace { 207 208 // We use the OverloadVisit with std::visit to ensure the compiler catches if a 209 // new RootElement variant type is added but it's operator<< isn't handled. 210 template <class... Ts> struct OverloadedVisit : Ts... { 211 using Ts::operator()...; 212 }; 213 template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>; 214 215 } // namespace 216 217 raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) { 218 const auto Visitor = OverloadedVisit{ 219 [&OS](const dxbc::RootFlags &Flags) { OS << Flags; }, 220 [&OS](const RootConstants &Constants) { OS << Constants; }, 221 [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; }, 222 [&OS](const DescriptorTableClause &Clause) { OS << Clause; }, 223 [&OS](const DescriptorTable &Table) { OS << Table; }, 224 [&OS](const StaticSampler &Sampler) { OS << Sampler; }, 225 }; 226 std::visit(Visitor, Element); 227 return OS; 228 } 229 230 void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) { 231 OS << " RootElements{"; 232 bool First = true; 233 for (const RootElement &Element : Elements) { 234 if (!First) 235 OS << ","; 236 OS << " " << Element; 237 First = false; 238 } 239 OS << "}"; 240 } 241 242 } // namespace rootsig 243 } // namespace hlsl 244 } // namespace llvm 245