1 //===- HLSLRootSignatureValidations.cpp - HLSL Root Signature helpers -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains helpers for working with HLSL Root Signatures. 10 /// 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Frontend/HLSL/RootSignatureValidations.h" 14 15 #include <cmath> 16 17 namespace llvm { 18 namespace hlsl { 19 namespace rootsig { 20 21 bool verifyRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; } 22 23 bool verifyVersion(uint32_t Version) { return (Version == 1 || Version == 2); } 24 25 bool verifyRegisterValue(uint32_t RegisterValue) { 26 return RegisterValue != ~0U; 27 } 28 29 // This Range is reserverved, therefore invalid, according to the spec 30 // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal 31 bool verifyRegisterSpace(uint32_t RegisterSpace) { 32 return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF); 33 } 34 35 bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { 36 using FlagT = dxbc::RootDescriptorFlags; 37 FlagT Flags = FlagT(FlagsVal); 38 if (Version == 1) 39 return Flags == FlagT::DataVolatile; 40 41 assert(Version == 2 && "Provided invalid root signature version"); 42 43 // The data-specific flags are mutually exclusive. 44 FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic | 45 FlagT::DataStaticWhileSetAtExecute; 46 47 if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1) 48 return false; 49 50 // Only a data flag or no flags is valid 51 return (Flags | DataFlags) == DataFlags; 52 } 53 54 bool verifyRangeType(uint32_t Type) { 55 switch (Type) { 56 case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): 57 case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): 58 case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): 59 case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler): 60 return true; 61 }; 62 63 return false; 64 } 65 66 bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type, 67 uint32_t FlagsVal) { 68 using FlagT = dxbc::DescriptorRangeFlags; 69 FlagT Flags = FlagT(FlagsVal); 70 71 const bool IsSampler = 72 (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)); 73 74 if (Version == 1) { 75 // Since the metadata is unversioned, we expect to explicitly see the values 76 // that map to the version 1 behaviour here. 77 if (IsSampler) 78 return Flags == FlagT::DescriptorsVolatile; 79 return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile); 80 } 81 82 // The data-specific flags are mutually exclusive. 83 FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic | 84 FlagT::DataStaticWhileSetAtExecute; 85 86 if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1) 87 return false; 88 89 // The descriptor-specific flags are mutually exclusive. 90 FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks | 91 FlagT::DescriptorsVolatile; 92 if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1) 93 return false; 94 95 // For volatile descriptors, DATA_is never valid. 96 if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) { 97 FlagT Mask = FlagT::DescriptorsVolatile; 98 if (!IsSampler) { 99 Mask |= FlagT::DataVolatile; 100 Mask |= FlagT::DataStaticWhileSetAtExecute; 101 } 102 return (Flags & ~Mask) == FlagT::None; 103 } 104 105 // For "KEEPING_BUFFER_BOUNDS_CHECKS" descriptors, 106 // the other data-specific flags may all be set. 107 if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) == 108 FlagT::DescriptorsStaticKeepingBufferBoundsChecks) { 109 FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks; 110 if (!IsSampler) { 111 Mask |= FlagT::DataVolatile; 112 Mask |= FlagT::DataStatic; 113 Mask |= FlagT::DataStaticWhileSetAtExecute; 114 } 115 return (Flags & ~Mask) == FlagT::None; 116 } 117 118 // When no descriptor flag is set, any data flag is allowed. 119 FlagT Mask = FlagT::None; 120 if (!IsSampler) { 121 Mask |= FlagT::DataVolatile; 122 Mask |= FlagT::DataStaticWhileSetAtExecute; 123 Mask |= FlagT::DataStatic; 124 } 125 return (Flags & ~Mask) == FlagT::None; 126 } 127 128 bool verifyNumDescriptors(uint32_t NumDescriptors) { 129 return NumDescriptors > 0; 130 } 131 132 bool verifySamplerFilter(uint32_t Value) { 133 switch (Value) { 134 #define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val): 135 #include "llvm/BinaryFormat/DXContainerConstants.def" 136 return true; 137 } 138 return false; 139 } 140 141 // Values allowed here: 142 // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax 143 bool verifyAddress(uint32_t Address) { 144 switch (Address) { 145 #define TEXTURE_ADDRESS_MODE(Num, Val) \ 146 case llvm::to_underlying(dxbc::TextureAddressMode::Val): 147 #include "llvm/BinaryFormat/DXContainerConstants.def" 148 return true; 149 } 150 return false; 151 } 152 153 bool verifyMipLODBias(float MipLODBias) { 154 return MipLODBias >= -16.f && MipLODBias <= 15.99f; 155 } 156 157 bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) { 158 return MaxAnisotropy <= 16u; 159 } 160 161 bool verifyComparisonFunc(uint32_t ComparisonFunc) { 162 switch (ComparisonFunc) { 163 #define COMPARISON_FUNC(Num, Val) \ 164 case llvm::to_underlying(dxbc::ComparisonFunc::Val): 165 #include "llvm/BinaryFormat/DXContainerConstants.def" 166 return true; 167 } 168 return false; 169 } 170 171 bool verifyBorderColor(uint32_t BorderColor) { 172 switch (BorderColor) { 173 #define STATIC_BORDER_COLOR(Num, Val) \ 174 case llvm::to_underlying(dxbc::StaticBorderColor::Val): 175 #include "llvm/BinaryFormat/DXContainerConstants.def" 176 return true; 177 } 178 return false; 179 } 180 181 bool verifyLOD(float LOD) { return !std::isnan(LOD); } 182 183 std::optional<const RangeInfo *> 184 ResourceRange::getOverlapping(const RangeInfo &Info) const { 185 MapT::const_iterator Interval = Intervals.find(Info.LowerBound); 186 if (!Interval.valid() || Info.UpperBound < Interval.start()) 187 return std::nullopt; 188 return Interval.value(); 189 } 190 191 const RangeInfo *ResourceRange::lookup(uint32_t X) const { 192 return Intervals.lookup(X, nullptr); 193 } 194 195 void ResourceRange::clear() { return Intervals.clear(); } 196 197 std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) { 198 uint32_t LowerBound = Info.LowerBound; 199 uint32_t UpperBound = Info.UpperBound; 200 201 std::optional<const RangeInfo *> Res = std::nullopt; 202 MapT::iterator Interval = Intervals.begin(); 203 204 while (true) { 205 if (UpperBound < LowerBound) 206 break; 207 208 Interval.advanceTo(LowerBound); 209 if (!Interval.valid()) // No interval found 210 break; 211 212 // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that 213 // a <= y implicitly from Intervals.find(LowerBound) 214 if (UpperBound < Interval.start()) 215 break; // found interval does not overlap with inserted one 216 217 if (!Res.has_value()) // Update to be the first found intersection 218 Res = Interval.value(); 219 220 if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) { 221 // x <= a <= b <= y implies that [a;b] is covered by [x;y] 222 // -> so we don't need to insert this, report an overlap 223 return Res; 224 } else if (LowerBound <= Interval.start() && 225 Interval.stop() <= UpperBound) { 226 // a <= x <= y <= b implies that [x;y] is covered by [a;b] 227 // -> so remove the existing interval that we will cover with the 228 // overwrite 229 Interval.erase(); 230 } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) { 231 // a < x <= b <= y implies that [a; x] is not covered but [x;b] is 232 // -> so set b = x - 1 such that [a;x-1] is now the interval to insert 233 UpperBound = Interval.start() - 1; 234 } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) { 235 // a < x <= b <= y implies that [y; b] is not covered but [a;y] is 236 // -> so set a = y + 1 such that [y+1;b] is now the interval to insert 237 LowerBound = Interval.stop() + 1; 238 } 239 } 240 241 assert(LowerBound <= UpperBound && "Attempting to insert an empty interval"); 242 Intervals.insert(LowerBound, UpperBound, &Info); 243 return Res; 244 } 245 246 llvm::SmallVector<OverlappingRanges> 247 findOverlappingRanges(ArrayRef<RangeInfo> Infos) { 248 // It is expected that Infos is filled with valid RangeInfos and that 249 // they are sorted with respect to the RangeInfo <operator 250 assert(llvm::is_sorted(Infos) && "Ranges must be sorted"); 251 252 llvm::SmallVector<OverlappingRanges> Overlaps; 253 using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>; 254 255 // First we will init our state to track: 256 if (Infos.size() == 0) 257 return Overlaps; // No ranges to overlap 258 GroupT CurGroup = {Infos[0].Class, Infos[0].Space}; 259 260 // Create a ResourceRange for each Visibility 261 ResourceRange::MapT::Allocator Allocator; 262 std::array<ResourceRange, 8> Ranges = { 263 ResourceRange(Allocator), // All 264 ResourceRange(Allocator), // Vertex 265 ResourceRange(Allocator), // Hull 266 ResourceRange(Allocator), // Domain 267 ResourceRange(Allocator), // Geometry 268 ResourceRange(Allocator), // Pixel 269 ResourceRange(Allocator), // Amplification 270 ResourceRange(Allocator), // Mesh 271 }; 272 273 // Reset the ResourceRanges for when we iterate through a new group 274 auto ClearRanges = [&Ranges]() { 275 for (ResourceRange &Range : Ranges) 276 Range.clear(); 277 }; 278 279 // Iterate through collected RangeInfos 280 for (const RangeInfo &Info : Infos) { 281 GroupT InfoGroup = {Info.Class, Info.Space}; 282 // Reset our ResourceRanges when we enter a new group 283 if (CurGroup != InfoGroup) { 284 ClearRanges(); 285 CurGroup = InfoGroup; 286 } 287 288 // Insert range info into corresponding Visibility ResourceRange 289 ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)]; 290 if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info)) 291 Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value())); 292 293 // Check for overlap in all overlapping Visibility ResourceRanges 294 // 295 // If the range that we are inserting has ShaderVisiblity::All it needs to 296 // check for an overlap in all other visibility types as well. 297 // Otherwise, the range that is inserted needs to check that it does not 298 // overlap with ShaderVisibility::All. 299 // 300 // OverlapRanges will be an ArrayRef to all non-all visibility 301 // ResourceRanges in the former case and it will be an ArrayRef to just the 302 // all visiblity ResourceRange in the latter case. 303 ArrayRef<ResourceRange> OverlapRanges = 304 Info.Visibility == llvm::dxbc::ShaderVisibility::All 305 ? ArrayRef<ResourceRange>{Ranges}.drop_front() 306 : ArrayRef<ResourceRange>{Ranges}.take_front(); 307 308 for (const ResourceRange &Range : OverlapRanges) 309 if (std::optional<const RangeInfo *> Overlapping = 310 Range.getOverlapping(Info)) 311 Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value())); 312 } 313 314 return Overlaps; 315 } 316 317 } // namespace rootsig 318 } // namespace hlsl 319 } // namespace llvm 320