xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- HLSLRootSignatureValidations.cpp - HLSL Root Signature helpers -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helpers for working with HLSL Root Signatures.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
14 
15 #include <cmath>
16 
17 namespace llvm {
18 namespace hlsl {
19 namespace rootsig {
20 
/// Returns true when every bit set in Flags lies within the low 12 bits
/// (mask 0xfff); any higher bit makes the root flags invalid.
bool verifyRootFlag(uint32_t Flags) {
  constexpr uint32_t KnownFlagsMask = 0xfff;
  // OR-ing in the mask leaves it unchanged only if Flags adds no new bits.
  return (Flags | KnownFlagsMask) == KnownFlagsMask;
}
22 
/// Returns true only for root signature versions 1 and 2.
bool verifyVersion(uint32_t Version) {
  switch (Version) {
  case 1:
  case 2:
    return true;
  default:
    return false;
  }
}
24 
/// Returns false for the all-ones register value and true for every other
/// register number.
bool verifyRegisterValue(uint32_t RegisterValue) {
  constexpr uint32_t InvalidRegister = 0xFFFFFFFFu;
  return RegisterValue != InvalidRegister;
}
28 
// Register spaces in [0xFFFFFFF0; 0xFFFFFFFF] are reserved, and therefore
// invalid, according to the spec:
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
bool verifyRegisterSpace(uint32_t RegisterSpace) {
  // The reserved range runs up to the largest uint32_t value, so only its
  // lower bound needs to be checked; an upper-bound comparison against
  // 0xFFFFFFFF would always be true for an unsigned 32-bit value.
  return RegisterSpace < 0xFFFFFFF0u;
}
34 
35 bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
36   using FlagT = dxbc::RootDescriptorFlags;
37   FlagT Flags = FlagT(FlagsVal);
38   if (Version == 1)
39     return Flags == FlagT::DataVolatile;
40 
41   assert(Version == 2 && "Provided invalid root signature version");
42 
43   // The data-specific flags are mutually exclusive.
44   FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
45                     FlagT::DataStaticWhileSetAtExecute;
46 
47   if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
48     return false;
49 
50   // Only a data flag or no flags is valid
51   return (Flags | DataFlags) == DataFlags;
52 }
53 
54 bool verifyRangeType(uint32_t Type) {
55   switch (Type) {
56   case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
57   case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
58   case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
59   case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
60     return true;
61   };
62 
63   return false;
64 }
65 
66 bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
67                                uint32_t FlagsVal) {
68   using FlagT = dxbc::DescriptorRangeFlags;
69   FlagT Flags = FlagT(FlagsVal);
70 
71   const bool IsSampler =
72       (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler));
73 
74   if (Version == 1) {
75     // Since the metadata is unversioned, we expect to explicitly see the values
76     // that map to the version 1 behaviour here.
77     if (IsSampler)
78       return Flags == FlagT::DescriptorsVolatile;
79     return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile);
80   }
81 
82   // The data-specific flags are mutually exclusive.
83   FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
84                     FlagT::DataStaticWhileSetAtExecute;
85 
86   if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
87     return false;
88 
89   // The descriptor-specific flags are mutually exclusive.
90   FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks |
91                           FlagT::DescriptorsVolatile;
92   if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1)
93     return false;
94 
95   // For volatile descriptors, DATA_is never valid.
96   if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) {
97     FlagT Mask = FlagT::DescriptorsVolatile;
98     if (!IsSampler) {
99       Mask |= FlagT::DataVolatile;
100       Mask |= FlagT::DataStaticWhileSetAtExecute;
101     }
102     return (Flags & ~Mask) == FlagT::None;
103   }
104 
105   // For "KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
106   // the other data-specific flags may all be set.
107   if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) ==
108       FlagT::DescriptorsStaticKeepingBufferBoundsChecks) {
109     FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks;
110     if (!IsSampler) {
111       Mask |= FlagT::DataVolatile;
112       Mask |= FlagT::DataStatic;
113       Mask |= FlagT::DataStaticWhileSetAtExecute;
114     }
115     return (Flags & ~Mask) == FlagT::None;
116   }
117 
118   // When no descriptor flag is set, any data flag is allowed.
119   FlagT Mask = FlagT::None;
120   if (!IsSampler) {
121     Mask |= FlagT::DataVolatile;
122     Mask |= FlagT::DataStaticWhileSetAtExecute;
123     Mask |= FlagT::DataStatic;
124   }
125   return (Flags & ~Mask) == FlagT::None;
126 }
127 
/// Rejects an empty descriptor range; every non-zero count is accepted.
bool verifyNumDescriptors(uint32_t NumDescriptors) {
  return NumDescriptors != 0;
}
131 
132 bool verifySamplerFilter(uint32_t Value) {
133   switch (Value) {
134 #define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val):
135 #include "llvm/BinaryFormat/DXContainerConstants.def"
136     return true;
137   }
138   return false;
139 }
140 
141 // Values allowed here:
142 // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
143 bool verifyAddress(uint32_t Address) {
144   switch (Address) {
145 #define TEXTURE_ADDRESS_MODE(Num, Val)                                         \
146   case llvm::to_underlying(dxbc::TextureAddressMode::Val):
147 #include "llvm/BinaryFormat/DXContainerConstants.def"
148     return true;
149   }
150   return false;
151 }
152 
/// Accepts a mip LOD bias inside the closed interval [-16.0, 15.99].
/// NaN fails both comparisons and is therefore rejected.
bool verifyMipLODBias(float MipLODBias) {
  constexpr float MinBias = -16.f;
  constexpr float MaxBias = 15.99f;
  return MinBias <= MipLODBias && MipLODBias <= MaxBias;
}
156 
/// Accepts anisotropy levels from 0 through 16 inclusive.
bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
  constexpr uint32_t Limit = 16u;
  return !(MaxAnisotropy > Limit);
}
160 
161 bool verifyComparisonFunc(uint32_t ComparisonFunc) {
162   switch (ComparisonFunc) {
163 #define COMPARISON_FUNC(Num, Val)                                              \
164   case llvm::to_underlying(dxbc::ComparisonFunc::Val):
165 #include "llvm/BinaryFormat/DXContainerConstants.def"
166     return true;
167   }
168   return false;
169 }
170 
171 bool verifyBorderColor(uint32_t BorderColor) {
172   switch (BorderColor) {
173 #define STATIC_BORDER_COLOR(Num, Val)                                          \
174   case llvm::to_underlying(dxbc::StaticBorderColor::Val):
175 #include "llvm/BinaryFormat/DXContainerConstants.def"
176     return true;
177   }
178   return false;
179 }
180 
/// Accepts any LOD value except NaN (infinities are accepted).
bool verifyLOD(float LOD) {
  if (std::isnan(LOD))
    return false;
  return true;
}
182 
183 std::optional<const RangeInfo *>
184 ResourceRange::getOverlapping(const RangeInfo &Info) const {
185   MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
186   if (!Interval.valid() || Info.UpperBound < Interval.start())
187     return std::nullopt;
188   return Interval.value();
189 }
190 
191 const RangeInfo *ResourceRange::lookup(uint32_t X) const {
192   return Intervals.lookup(X, nullptr);
193 }
194 
195 void ResourceRange::clear() { return Intervals.clear(); }
196 
197 std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
198   uint32_t LowerBound = Info.LowerBound;
199   uint32_t UpperBound = Info.UpperBound;
200 
201   std::optional<const RangeInfo *> Res = std::nullopt;
202   MapT::iterator Interval = Intervals.begin();
203 
204   while (true) {
205     if (UpperBound < LowerBound)
206       break;
207 
208     Interval.advanceTo(LowerBound);
209     if (!Interval.valid()) // No interval found
210       break;
211 
212     // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
213     // a <= y implicitly from Intervals.find(LowerBound)
214     if (UpperBound < Interval.start())
215       break; // found interval does not overlap with inserted one
216 
217     if (!Res.has_value()) // Update to be the first found intersection
218       Res = Interval.value();
219 
220     if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
221       // x <= a <= b <= y implies that [a;b] is covered by [x;y]
222       //  -> so we don't need to insert this, report an overlap
223       return Res;
224     } else if (LowerBound <= Interval.start() &&
225                Interval.stop() <= UpperBound) {
226       // a <= x <= y <= b implies that [x;y] is covered by [a;b]
227       //  -> so remove the existing interval that we will cover with the
228       //  overwrite
229       Interval.erase();
230     } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
231       // a < x <= b <= y implies that [a; x] is not covered but [x;b] is
232       //  -> so set b = x - 1 such that [a;x-1] is now the interval to insert
233       UpperBound = Interval.start() - 1;
234     } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
235       // a < x <= b <= y implies that [y; b] is not covered but [a;y] is
236       //  -> so set a = y + 1 such that [y+1;b] is now the interval to insert
237       LowerBound = Interval.stop() + 1;
238     }
239   }
240 
241   assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
242   Intervals.insert(LowerBound, UpperBound, &Info);
243   return Res;
244 }
245 
246 llvm::SmallVector<OverlappingRanges>
247 findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
248   // It is expected that Infos is filled with valid RangeInfos and that
249   // they are sorted with respect to the RangeInfo <operator
250   assert(llvm::is_sorted(Infos) && "Ranges must be sorted");
251 
252   llvm::SmallVector<OverlappingRanges> Overlaps;
253   using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;
254 
255   // First we will init our state to track:
256   if (Infos.size() == 0)
257     return Overlaps; // No ranges to overlap
258   GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
259 
260   // Create a ResourceRange for each Visibility
261   ResourceRange::MapT::Allocator Allocator;
262   std::array<ResourceRange, 8> Ranges = {
263       ResourceRange(Allocator), // All
264       ResourceRange(Allocator), // Vertex
265       ResourceRange(Allocator), // Hull
266       ResourceRange(Allocator), // Domain
267       ResourceRange(Allocator), // Geometry
268       ResourceRange(Allocator), // Pixel
269       ResourceRange(Allocator), // Amplification
270       ResourceRange(Allocator), // Mesh
271   };
272 
273   // Reset the ResourceRanges for when we iterate through a new group
274   auto ClearRanges = [&Ranges]() {
275     for (ResourceRange &Range : Ranges)
276       Range.clear();
277   };
278 
279   // Iterate through collected RangeInfos
280   for (const RangeInfo &Info : Infos) {
281     GroupT InfoGroup = {Info.Class, Info.Space};
282     // Reset our ResourceRanges when we enter a new group
283     if (CurGroup != InfoGroup) {
284       ClearRanges();
285       CurGroup = InfoGroup;
286     }
287 
288     // Insert range info into corresponding Visibility ResourceRange
289     ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)];
290     if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
291       Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
292 
293     // Check for overlap in all overlapping Visibility ResourceRanges
294     //
295     // If the range that we are inserting has ShaderVisiblity::All it needs to
296     // check for an overlap in all other visibility types as well.
297     // Otherwise, the range that is inserted needs to check that it does not
298     // overlap with ShaderVisibility::All.
299     //
300     // OverlapRanges will be an ArrayRef to all non-all visibility
301     // ResourceRanges in the former case and it will be an ArrayRef to just the
302     // all visiblity ResourceRange in the latter case.
303     ArrayRef<ResourceRange> OverlapRanges =
304         Info.Visibility == llvm::dxbc::ShaderVisibility::All
305             ? ArrayRef<ResourceRange>{Ranges}.drop_front()
306             : ArrayRef<ResourceRange>{Ranges}.take_front();
307 
308     for (const ResourceRange &Range : OverlapRanges)
309       if (std::optional<const RangeInfo *> Overlapping =
310               Range.getOverlapping(Info))
311         Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
312   }
313 
314   return Overlaps;
315 }
316 
317 } // namespace rootsig
318 } // namespace hlsl
319 } // namespace llvm
320