1 //===- HLSLRootSignatureValidations.cpp - HLSL Root Signature helpers -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helpers for working with HLSL Root Signatures.
10 ///
11 //===----------------------------------------------------------------------===//
12
13 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
14
15 #include <cmath>
16
17 namespace llvm {
18 namespace hlsl {
19 namespace rootsig {
20
// Root signature flags may only use the low 12 bits (0xfff); any value with a
// higher bit set is rejected.
bool verifyRootFlag(uint32_t Flags) {
  constexpr uint32_t ValidBits = 0xfff;
  return (Flags | ValidBits) == ValidBits;
}
22
// Only root signature versions 1 and 2 are accepted.
bool verifyVersion(uint32_t Version) {
  switch (Version) {
  case 1:
  case 2:
    return true;
  default:
    return false;
  }
}
24
// Rejects only the all-ones value (~0U); every other register value is
// accepted.
bool verifyRegisterValue(uint32_t RegisterValue) {
  constexpr uint32_t AllOnes = 0xFFFFFFFFu;
  if (RegisterValue == AllOnes)
    return false;
  return true;
}
28
// Register spaces in the range [0xFFFFFFF0; 0xFFFFFFFF] are reserved, and
// therefore invalid, according to the spec:
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
bool verifyRegisterSpace(uint32_t RegisterSpace) {
  // The reserved range ends at the largest uint32_t value, so checking the
  // lower bound is sufficient. An explicit `<= 0xFFFFFFFF` would always be
  // true and only produce tautological-comparison warnings.
  return RegisterSpace < 0xFFFFFFF0u;
}
34
verifyRootDescriptorFlag(uint32_t Version,uint32_t FlagsVal)35 bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
36 using FlagT = dxbc::RootDescriptorFlags;
37 FlagT Flags = FlagT(FlagsVal);
38 if (Version == 1)
39 return Flags == FlagT::DataVolatile;
40
41 assert(Version == 2 && "Provided invalid root signature version");
42
43 // The data-specific flags are mutually exclusive.
44 FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
45 FlagT::DataStaticWhileSetAtExecute;
46
47 if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
48 return false;
49
50 // Only a data flag or no flags is valid
51 return (Flags | DataFlags) == DataFlags;
52 }
53
verifyRangeType(uint32_t Type)54 bool verifyRangeType(uint32_t Type) {
55 switch (Type) {
56 case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
57 case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
58 case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
59 case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
60 return true;
61 };
62
63 return false;
64 }
65
verifyDescriptorRangeFlag(uint32_t Version,uint32_t Type,uint32_t FlagsVal)66 bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
67 uint32_t FlagsVal) {
68 using FlagT = dxbc::DescriptorRangeFlags;
69 FlagT Flags = FlagT(FlagsVal);
70
71 const bool IsSampler =
72 (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler));
73
74 if (Version == 1) {
75 // Since the metadata is unversioned, we expect to explicitly see the values
76 // that map to the version 1 behaviour here.
77 if (IsSampler)
78 return Flags == FlagT::DescriptorsVolatile;
79 return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile);
80 }
81
82 // The data-specific flags are mutually exclusive.
83 FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
84 FlagT::DataStaticWhileSetAtExecute;
85
86 if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
87 return false;
88
89 // The descriptor-specific flags are mutually exclusive.
90 FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks |
91 FlagT::DescriptorsVolatile;
92 if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1)
93 return false;
94
95 // For volatile descriptors, DATA_is never valid.
96 if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) {
97 FlagT Mask = FlagT::DescriptorsVolatile;
98 if (!IsSampler) {
99 Mask |= FlagT::DataVolatile;
100 Mask |= FlagT::DataStaticWhileSetAtExecute;
101 }
102 return (Flags & ~Mask) == FlagT::None;
103 }
104
105 // For "KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
106 // the other data-specific flags may all be set.
107 if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) ==
108 FlagT::DescriptorsStaticKeepingBufferBoundsChecks) {
109 FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks;
110 if (!IsSampler) {
111 Mask |= FlagT::DataVolatile;
112 Mask |= FlagT::DataStatic;
113 Mask |= FlagT::DataStaticWhileSetAtExecute;
114 }
115 return (Flags & ~Mask) == FlagT::None;
116 }
117
118 // When no descriptor flag is set, any data flag is allowed.
119 FlagT Mask = FlagT::None;
120 if (!IsSampler) {
121 Mask |= FlagT::DataVolatile;
122 Mask |= FlagT::DataStaticWhileSetAtExecute;
123 Mask |= FlagT::DataStatic;
124 }
125 return (Flags & ~Mask) == FlagT::None;
126 }
127
// A descriptor count of zero is rejected; any positive count is accepted.
bool verifyNumDescriptors(uint32_t NumDescriptors) {
  return NumDescriptors != 0;
}
131
verifySamplerFilter(uint32_t Value)132 bool verifySamplerFilter(uint32_t Value) {
133 switch (Value) {
134 #define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val):
135 #include "llvm/BinaryFormat/DXContainerConstants.def"
136 return true;
137 }
138 return false;
139 }
140
141 // Values allowed here:
142 // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
verifyAddress(uint32_t Address)143 bool verifyAddress(uint32_t Address) {
144 switch (Address) {
145 #define TEXTURE_ADDRESS_MODE(Num, Val) \
146 case llvm::to_underlying(dxbc::TextureAddressMode::Val):
147 #include "llvm/BinaryFormat/DXContainerConstants.def"
148 return true;
149 }
150 return false;
151 }
152
// Accepts biases in the closed range [-16.0, 15.99]. NaN compares false
// against both bounds and is therefore rejected.
bool verifyMipLODBias(float MipLODBias) {
  constexpr float MinBias = -16.f;
  constexpr float MaxBias = 15.99f;
  return MinBias <= MipLODBias && MipLODBias <= MaxBias;
}
156
// Accepts anisotropy values from 0 up to and including 16.
bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
  constexpr uint32_t Limit = 16;
  return !(MaxAnisotropy > Limit);
}
160
verifyComparisonFunc(uint32_t ComparisonFunc)161 bool verifyComparisonFunc(uint32_t ComparisonFunc) {
162 switch (ComparisonFunc) {
163 #define COMPARISON_FUNC(Num, Val) \
164 case llvm::to_underlying(dxbc::ComparisonFunc::Val):
165 #include "llvm/BinaryFormat/DXContainerConstants.def"
166 return true;
167 }
168 return false;
169 }
170
verifyBorderColor(uint32_t BorderColor)171 bool verifyBorderColor(uint32_t BorderColor) {
172 switch (BorderColor) {
173 #define STATIC_BORDER_COLOR(Num, Val) \
174 case llvm::to_underlying(dxbc::StaticBorderColor::Val):
175 #include "llvm/BinaryFormat/DXContainerConstants.def"
176 return true;
177 }
178 return false;
179 }
180
// Every LOD value is accepted except NaN; infinities are allowed.
bool verifyLOD(float LOD) {
  if (std::isnan(LOD))
    return false;
  return true;
}
182
183 std::optional<const RangeInfo *>
getOverlapping(const RangeInfo & Info) const184 ResourceRange::getOverlapping(const RangeInfo &Info) const {
185 MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
186 if (!Interval.valid() || Info.UpperBound < Interval.start())
187 return std::nullopt;
188 return Interval.value();
189 }
190
lookup(uint32_t X) const191 const RangeInfo *ResourceRange::lookup(uint32_t X) const {
192 return Intervals.lookup(X, nullptr);
193 }
194
clear()195 void ResourceRange::clear() { return Intervals.clear(); }
196
// Inserts Info's closed interval [Info.LowerBound; Info.UpperBound].
//
// Parts already covered by existing intervals keep their existing value:
// the bounds being inserted are shrunk to exclude them. Existing intervals
// that lie entirely inside the bounds being inserted are erased. Whatever
// remains is inserted as a single interval mapped to &Info. If one existing
// interval already covers all of the remaining bounds, nothing is inserted.
//
// Returns the first existing RangeInfo found to overlap Info, or
// std::nullopt if Info overlaps nothing.
std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
  uint32_t LowerBound = Info.LowerBound;
  uint32_t UpperBound = Info.UpperBound;

  // First existing interval value found to overlap Info, if any.
  std::optional<const RangeInfo *> Res = std::nullopt;
  MapT::iterator Interval = Intervals.begin();

  // Each iteration looks at the next existing interval that may overlap
  // [LowerBound; UpperBound] and then either stops, erases that interval, or
  // shrinks [LowerBound; UpperBound] so it no longer overlaps it.
  while (true) {
    // Guard against an inverted interval; the assert after the loop reports
    // it.
    if (UpperBound < LowerBound)
      break;

    Interval.advanceTo(LowerBound);
    if (!Interval.valid()) // No interval found
      break;

    // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
    // a <= y implicitly from Interval.advanceTo(LowerBound)
    if (UpperBound < Interval.start())
      break; // found interval does not overlap with inserted one

    if (!Res.has_value()) // Update to be the first found intersection
      Res = Interval.value();

    if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
      // x <= a <= b <= y implies that [a;b] is covered by [x;y]
      // -> so we don't need to insert this, report an overlap
      return Res;
    } else if (LowerBound <= Interval.start() &&
               Interval.stop() <= UpperBound) {
      // a <= x <= y <= b implies that [x;y] is covered by [a;b]
      // -> so remove the existing interval that we will cover with the
      // overwrite
      Interval.erase();
    } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
      // a < x <= b <= y implies that [a;x-1] is not covered but [x;b] is
      // -> so set b = x - 1 such that [a;x-1] is now the interval to insert
      UpperBound = Interval.start() - 1;
    } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
      // x <= a <= y < b implies that [y+1;b] is not covered but [a;y] is
      // -> so set a = y + 1 such that [y+1;b] is now the interval to insert
      LowerBound = Interval.stop() + 1;
    }
  }

  assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
  Intervals.insert(LowerBound, UpperBound, &Info);
  return Res;
}
245
246 llvm::SmallVector<OverlappingRanges>
findOverlappingRanges(ArrayRef<RangeInfo> Infos)247 findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
248 // It is expected that Infos is filled with valid RangeInfos and that
249 // they are sorted with respect to the RangeInfo <operator
250 assert(llvm::is_sorted(Infos) && "Ranges must be sorted");
251
252 llvm::SmallVector<OverlappingRanges> Overlaps;
253 using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;
254
255 // First we will init our state to track:
256 if (Infos.size() == 0)
257 return Overlaps; // No ranges to overlap
258 GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
259
260 // Create a ResourceRange for each Visibility
261 ResourceRange::MapT::Allocator Allocator;
262 std::array<ResourceRange, 8> Ranges = {
263 ResourceRange(Allocator), // All
264 ResourceRange(Allocator), // Vertex
265 ResourceRange(Allocator), // Hull
266 ResourceRange(Allocator), // Domain
267 ResourceRange(Allocator), // Geometry
268 ResourceRange(Allocator), // Pixel
269 ResourceRange(Allocator), // Amplification
270 ResourceRange(Allocator), // Mesh
271 };
272
273 // Reset the ResourceRanges for when we iterate through a new group
274 auto ClearRanges = [&Ranges]() {
275 for (ResourceRange &Range : Ranges)
276 Range.clear();
277 };
278
279 // Iterate through collected RangeInfos
280 for (const RangeInfo &Info : Infos) {
281 GroupT InfoGroup = {Info.Class, Info.Space};
282 // Reset our ResourceRanges when we enter a new group
283 if (CurGroup != InfoGroup) {
284 ClearRanges();
285 CurGroup = InfoGroup;
286 }
287
288 // Insert range info into corresponding Visibility ResourceRange
289 ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)];
290 if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
291 Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
292
293 // Check for overlap in all overlapping Visibility ResourceRanges
294 //
295 // If the range that we are inserting has ShaderVisiblity::All it needs to
296 // check for an overlap in all other visibility types as well.
297 // Otherwise, the range that is inserted needs to check that it does not
298 // overlap with ShaderVisibility::All.
299 //
300 // OverlapRanges will be an ArrayRef to all non-all visibility
301 // ResourceRanges in the former case and it will be an ArrayRef to just the
302 // all visiblity ResourceRange in the latter case.
303 ArrayRef<ResourceRange> OverlapRanges =
304 Info.Visibility == llvm::dxbc::ShaderVisibility::All
305 ? ArrayRef<ResourceRange>{Ranges}.drop_front()
306 : ArrayRef<ResourceRange>{Ranges}.take_front();
307
308 for (const ResourceRange &Range : OverlapRanges)
309 if (std::optional<const RangeInfo *> Overlapping =
310 Range.getOverlapping(Info))
311 Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
312 }
313
314 return Overlaps;
315 }
316
317 } // namespace rootsig
318 } // namespace hlsl
319 } // namespace llvm
320