xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1*700637cbSDimitry Andric //===- HLSLRootSignatureValidations.cpp - HLSL Root Signature helpers -----===//
2*700637cbSDimitry Andric //
3*700637cbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*700637cbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*700637cbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*700637cbSDimitry Andric //
7*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
8*700637cbSDimitry Andric ///
9*700637cbSDimitry Andric /// \file This file contains helpers for working with HLSL Root Signatures.
10*700637cbSDimitry Andric ///
11*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
12*700637cbSDimitry Andric 
13*700637cbSDimitry Andric #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
14*700637cbSDimitry Andric 
15*700637cbSDimitry Andric #include <cmath>
16*700637cbSDimitry Andric 
17*700637cbSDimitry Andric namespace llvm {
18*700637cbSDimitry Andric namespace hlsl {
19*700637cbSDimitry Andric namespace rootsig {
20*700637cbSDimitry Andric 
/// Returns true when \p Flags uses only bits of the low 12-bit root flag field,
/// i.e. when its numeric value does not exceed 0xfff.
bool verifyRootFlag(uint32_t Flags) {
  constexpr uint32_t LargestRootFlags = 0xfff;
  return Flags <= LargestRootFlags;
}
22*700637cbSDimitry Andric 
/// Returns true for the root signature versions accepted here: 1 and 2.
bool verifyVersion(uint32_t Version) {
  switch (Version) {
  case 1:
  case 2:
    return true;
  default:
    return false;
  }
}
24*700637cbSDimitry Andric 
/// Returns true unless \p RegisterValue is the all-ones value 0xFFFFFFFF.
bool verifyRegisterValue(uint32_t RegisterValue) {
  const bool IsAllOnes = RegisterValue == 0xFFFFFFFFu;
  return !IsAllOnes;
}
28*700637cbSDimitry Andric 
// Register spaces 0xFFFFFFF0 through 0xFFFFFFFF are reserved, and therefore
// invalid, according to the spec:
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
bool verifyRegisterSpace(uint32_t RegisterSpace) {
  // The reserved block runs up to the largest uint32_t value, so checking the
  // lower bound is sufficient; comparing against 0xFFFFFFFF as an upper bound
  // would always be true and only provokes tautological-compare warnings.
  return RegisterSpace < 0xFFFFFFF0u;
}
34*700637cbSDimitry Andric 
verifyRootDescriptorFlag(uint32_t Version,uint32_t FlagsVal)35*700637cbSDimitry Andric bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
36*700637cbSDimitry Andric   using FlagT = dxbc::RootDescriptorFlags;
37*700637cbSDimitry Andric   FlagT Flags = FlagT(FlagsVal);
38*700637cbSDimitry Andric   if (Version == 1)
39*700637cbSDimitry Andric     return Flags == FlagT::DataVolatile;
40*700637cbSDimitry Andric 
41*700637cbSDimitry Andric   assert(Version == 2 && "Provided invalid root signature version");
42*700637cbSDimitry Andric 
43*700637cbSDimitry Andric   // The data-specific flags are mutually exclusive.
44*700637cbSDimitry Andric   FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
45*700637cbSDimitry Andric                     FlagT::DataStaticWhileSetAtExecute;
46*700637cbSDimitry Andric 
47*700637cbSDimitry Andric   if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
48*700637cbSDimitry Andric     return false;
49*700637cbSDimitry Andric 
50*700637cbSDimitry Andric   // Only a data flag or no flags is valid
51*700637cbSDimitry Andric   return (Flags | DataFlags) == DataFlags;
52*700637cbSDimitry Andric }
53*700637cbSDimitry Andric 
verifyRangeType(uint32_t Type)54*700637cbSDimitry Andric bool verifyRangeType(uint32_t Type) {
55*700637cbSDimitry Andric   switch (Type) {
56*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
57*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
58*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
59*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
60*700637cbSDimitry Andric     return true;
61*700637cbSDimitry Andric   };
62*700637cbSDimitry Andric 
63*700637cbSDimitry Andric   return false;
64*700637cbSDimitry Andric }
65*700637cbSDimitry Andric 
verifyDescriptorRangeFlag(uint32_t Version,uint32_t Type,uint32_t FlagsVal)66*700637cbSDimitry Andric bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
67*700637cbSDimitry Andric                                uint32_t FlagsVal) {
68*700637cbSDimitry Andric   using FlagT = dxbc::DescriptorRangeFlags;
69*700637cbSDimitry Andric   FlagT Flags = FlagT(FlagsVal);
70*700637cbSDimitry Andric 
71*700637cbSDimitry Andric   const bool IsSampler =
72*700637cbSDimitry Andric       (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler));
73*700637cbSDimitry Andric 
74*700637cbSDimitry Andric   if (Version == 1) {
75*700637cbSDimitry Andric     // Since the metadata is unversioned, we expect to explicitly see the values
76*700637cbSDimitry Andric     // that map to the version 1 behaviour here.
77*700637cbSDimitry Andric     if (IsSampler)
78*700637cbSDimitry Andric       return Flags == FlagT::DescriptorsVolatile;
79*700637cbSDimitry Andric     return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile);
80*700637cbSDimitry Andric   }
81*700637cbSDimitry Andric 
82*700637cbSDimitry Andric   // The data-specific flags are mutually exclusive.
83*700637cbSDimitry Andric   FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
84*700637cbSDimitry Andric                     FlagT::DataStaticWhileSetAtExecute;
85*700637cbSDimitry Andric 
86*700637cbSDimitry Andric   if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
87*700637cbSDimitry Andric     return false;
88*700637cbSDimitry Andric 
89*700637cbSDimitry Andric   // The descriptor-specific flags are mutually exclusive.
90*700637cbSDimitry Andric   FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks |
91*700637cbSDimitry Andric                           FlagT::DescriptorsVolatile;
92*700637cbSDimitry Andric   if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1)
93*700637cbSDimitry Andric     return false;
94*700637cbSDimitry Andric 
95*700637cbSDimitry Andric   // For volatile descriptors, DATA_is never valid.
96*700637cbSDimitry Andric   if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) {
97*700637cbSDimitry Andric     FlagT Mask = FlagT::DescriptorsVolatile;
98*700637cbSDimitry Andric     if (!IsSampler) {
99*700637cbSDimitry Andric       Mask |= FlagT::DataVolatile;
100*700637cbSDimitry Andric       Mask |= FlagT::DataStaticWhileSetAtExecute;
101*700637cbSDimitry Andric     }
102*700637cbSDimitry Andric     return (Flags & ~Mask) == FlagT::None;
103*700637cbSDimitry Andric   }
104*700637cbSDimitry Andric 
105*700637cbSDimitry Andric   // For "KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
106*700637cbSDimitry Andric   // the other data-specific flags may all be set.
107*700637cbSDimitry Andric   if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) ==
108*700637cbSDimitry Andric       FlagT::DescriptorsStaticKeepingBufferBoundsChecks) {
109*700637cbSDimitry Andric     FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks;
110*700637cbSDimitry Andric     if (!IsSampler) {
111*700637cbSDimitry Andric       Mask |= FlagT::DataVolatile;
112*700637cbSDimitry Andric       Mask |= FlagT::DataStatic;
113*700637cbSDimitry Andric       Mask |= FlagT::DataStaticWhileSetAtExecute;
114*700637cbSDimitry Andric     }
115*700637cbSDimitry Andric     return (Flags & ~Mask) == FlagT::None;
116*700637cbSDimitry Andric   }
117*700637cbSDimitry Andric 
118*700637cbSDimitry Andric   // When no descriptor flag is set, any data flag is allowed.
119*700637cbSDimitry Andric   FlagT Mask = FlagT::None;
120*700637cbSDimitry Andric   if (!IsSampler) {
121*700637cbSDimitry Andric     Mask |= FlagT::DataVolatile;
122*700637cbSDimitry Andric     Mask |= FlagT::DataStaticWhileSetAtExecute;
123*700637cbSDimitry Andric     Mask |= FlagT::DataStatic;
124*700637cbSDimitry Andric   }
125*700637cbSDimitry Andric   return (Flags & ~Mask) == FlagT::None;
126*700637cbSDimitry Andric }
127*700637cbSDimitry Andric 
/// Returns true when the descriptor count is non-zero.
bool verifyNumDescriptors(uint32_t NumDescriptors) {
  const bool IsEmpty = NumDescriptors == 0;
  return !IsEmpty;
}
131*700637cbSDimitry Andric 
verifySamplerFilter(uint32_t Value)132*700637cbSDimitry Andric bool verifySamplerFilter(uint32_t Value) {
133*700637cbSDimitry Andric   switch (Value) {
134*700637cbSDimitry Andric #define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val):
135*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
136*700637cbSDimitry Andric     return true;
137*700637cbSDimitry Andric   }
138*700637cbSDimitry Andric   return false;
139*700637cbSDimitry Andric }
140*700637cbSDimitry Andric 
141*700637cbSDimitry Andric // Values allowed here:
142*700637cbSDimitry Andric // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
verifyAddress(uint32_t Address)143*700637cbSDimitry Andric bool verifyAddress(uint32_t Address) {
144*700637cbSDimitry Andric   switch (Address) {
145*700637cbSDimitry Andric #define TEXTURE_ADDRESS_MODE(Num, Val)                                         \
146*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::TextureAddressMode::Val):
147*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
148*700637cbSDimitry Andric     return true;
149*700637cbSDimitry Andric   }
150*700637cbSDimitry Andric   return false;
151*700637cbSDimitry Andric }
152*700637cbSDimitry Andric 
/// Returns true when \p MipLODBias lies in the closed range [-16.0, 15.99].
/// A NaN fails both comparisons and is therefore rejected.
bool verifyMipLODBias(float MipLODBias) {
  const bool AtLeastMin = MipLODBias >= -16.f;
  const bool AtMostMax = MipLODBias <= 15.99f;
  return AtLeastMin && AtMostMax;
}
156*700637cbSDimitry Andric 
/// Returns true when \p MaxAnisotropy does not exceed 16.
bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
  constexpr uint32_t AnisotropyLimit = 16u;
  return !(MaxAnisotropy > AnisotropyLimit);
}
160*700637cbSDimitry Andric 
verifyComparisonFunc(uint32_t ComparisonFunc)161*700637cbSDimitry Andric bool verifyComparisonFunc(uint32_t ComparisonFunc) {
162*700637cbSDimitry Andric   switch (ComparisonFunc) {
163*700637cbSDimitry Andric #define COMPARISON_FUNC(Num, Val)                                              \
164*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::ComparisonFunc::Val):
165*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
166*700637cbSDimitry Andric     return true;
167*700637cbSDimitry Andric   }
168*700637cbSDimitry Andric   return false;
169*700637cbSDimitry Andric }
170*700637cbSDimitry Andric 
verifyBorderColor(uint32_t BorderColor)171*700637cbSDimitry Andric bool verifyBorderColor(uint32_t BorderColor) {
172*700637cbSDimitry Andric   switch (BorderColor) {
173*700637cbSDimitry Andric #define STATIC_BORDER_COLOR(Num, Val)                                          \
174*700637cbSDimitry Andric   case llvm::to_underlying(dxbc::StaticBorderColor::Val):
175*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
176*700637cbSDimitry Andric     return true;
177*700637cbSDimitry Andric   }
178*700637cbSDimitry Andric   return false;
179*700637cbSDimitry Andric }
180*700637cbSDimitry Andric 
/// Returns true for any LOD value except NaN; infinities are accepted.
bool verifyLOD(float LOD) {
  if (std::isnan(LOD))
    return false;
  return true;
}
182*700637cbSDimitry Andric 
183*700637cbSDimitry Andric std::optional<const RangeInfo *>
getOverlapping(const RangeInfo & Info) const184*700637cbSDimitry Andric ResourceRange::getOverlapping(const RangeInfo &Info) const {
185*700637cbSDimitry Andric   MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
186*700637cbSDimitry Andric   if (!Interval.valid() || Info.UpperBound < Interval.start())
187*700637cbSDimitry Andric     return std::nullopt;
188*700637cbSDimitry Andric   return Interval.value();
189*700637cbSDimitry Andric }
190*700637cbSDimitry Andric 
lookup(uint32_t X) const191*700637cbSDimitry Andric const RangeInfo *ResourceRange::lookup(uint32_t X) const {
192*700637cbSDimitry Andric   return Intervals.lookup(X, nullptr);
193*700637cbSDimitry Andric }
194*700637cbSDimitry Andric 
clear()195*700637cbSDimitry Andric void ResourceRange::clear() { return Intervals.clear(); }
196*700637cbSDimitry Andric 
insert(const RangeInfo & Info)197*700637cbSDimitry Andric std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
198*700637cbSDimitry Andric   uint32_t LowerBound = Info.LowerBound;
199*700637cbSDimitry Andric   uint32_t UpperBound = Info.UpperBound;
200*700637cbSDimitry Andric 
201*700637cbSDimitry Andric   std::optional<const RangeInfo *> Res = std::nullopt;
202*700637cbSDimitry Andric   MapT::iterator Interval = Intervals.begin();
203*700637cbSDimitry Andric 
204*700637cbSDimitry Andric   while (true) {
205*700637cbSDimitry Andric     if (UpperBound < LowerBound)
206*700637cbSDimitry Andric       break;
207*700637cbSDimitry Andric 
208*700637cbSDimitry Andric     Interval.advanceTo(LowerBound);
209*700637cbSDimitry Andric     if (!Interval.valid()) // No interval found
210*700637cbSDimitry Andric       break;
211*700637cbSDimitry Andric 
212*700637cbSDimitry Andric     // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
213*700637cbSDimitry Andric     // a <= y implicitly from Intervals.find(LowerBound)
214*700637cbSDimitry Andric     if (UpperBound < Interval.start())
215*700637cbSDimitry Andric       break; // found interval does not overlap with inserted one
216*700637cbSDimitry Andric 
217*700637cbSDimitry Andric     if (!Res.has_value()) // Update to be the first found intersection
218*700637cbSDimitry Andric       Res = Interval.value();
219*700637cbSDimitry Andric 
220*700637cbSDimitry Andric     if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
221*700637cbSDimitry Andric       // x <= a <= b <= y implies that [a;b] is covered by [x;y]
222*700637cbSDimitry Andric       //  -> so we don't need to insert this, report an overlap
223*700637cbSDimitry Andric       return Res;
224*700637cbSDimitry Andric     } else if (LowerBound <= Interval.start() &&
225*700637cbSDimitry Andric                Interval.stop() <= UpperBound) {
226*700637cbSDimitry Andric       // a <= x <= y <= b implies that [x;y] is covered by [a;b]
227*700637cbSDimitry Andric       //  -> so remove the existing interval that we will cover with the
228*700637cbSDimitry Andric       //  overwrite
229*700637cbSDimitry Andric       Interval.erase();
230*700637cbSDimitry Andric     } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
231*700637cbSDimitry Andric       // a < x <= b <= y implies that [a; x] is not covered but [x;b] is
232*700637cbSDimitry Andric       //  -> so set b = x - 1 such that [a;x-1] is now the interval to insert
233*700637cbSDimitry Andric       UpperBound = Interval.start() - 1;
234*700637cbSDimitry Andric     } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
235*700637cbSDimitry Andric       // a < x <= b <= y implies that [y; b] is not covered but [a;y] is
236*700637cbSDimitry Andric       //  -> so set a = y + 1 such that [y+1;b] is now the interval to insert
237*700637cbSDimitry Andric       LowerBound = Interval.stop() + 1;
238*700637cbSDimitry Andric     }
239*700637cbSDimitry Andric   }
240*700637cbSDimitry Andric 
241*700637cbSDimitry Andric   assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
242*700637cbSDimitry Andric   Intervals.insert(LowerBound, UpperBound, &Info);
243*700637cbSDimitry Andric   return Res;
244*700637cbSDimitry Andric }
245*700637cbSDimitry Andric 
246*700637cbSDimitry Andric llvm::SmallVector<OverlappingRanges>
findOverlappingRanges(ArrayRef<RangeInfo> Infos)247*700637cbSDimitry Andric findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
248*700637cbSDimitry Andric   // It is expected that Infos is filled with valid RangeInfos and that
249*700637cbSDimitry Andric   // they are sorted with respect to the RangeInfo <operator
250*700637cbSDimitry Andric   assert(llvm::is_sorted(Infos) && "Ranges must be sorted");
251*700637cbSDimitry Andric 
252*700637cbSDimitry Andric   llvm::SmallVector<OverlappingRanges> Overlaps;
253*700637cbSDimitry Andric   using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;
254*700637cbSDimitry Andric 
255*700637cbSDimitry Andric   // First we will init our state to track:
256*700637cbSDimitry Andric   if (Infos.size() == 0)
257*700637cbSDimitry Andric     return Overlaps; // No ranges to overlap
258*700637cbSDimitry Andric   GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
259*700637cbSDimitry Andric 
260*700637cbSDimitry Andric   // Create a ResourceRange for each Visibility
261*700637cbSDimitry Andric   ResourceRange::MapT::Allocator Allocator;
262*700637cbSDimitry Andric   std::array<ResourceRange, 8> Ranges = {
263*700637cbSDimitry Andric       ResourceRange(Allocator), // All
264*700637cbSDimitry Andric       ResourceRange(Allocator), // Vertex
265*700637cbSDimitry Andric       ResourceRange(Allocator), // Hull
266*700637cbSDimitry Andric       ResourceRange(Allocator), // Domain
267*700637cbSDimitry Andric       ResourceRange(Allocator), // Geometry
268*700637cbSDimitry Andric       ResourceRange(Allocator), // Pixel
269*700637cbSDimitry Andric       ResourceRange(Allocator), // Amplification
270*700637cbSDimitry Andric       ResourceRange(Allocator), // Mesh
271*700637cbSDimitry Andric   };
272*700637cbSDimitry Andric 
273*700637cbSDimitry Andric   // Reset the ResourceRanges for when we iterate through a new group
274*700637cbSDimitry Andric   auto ClearRanges = [&Ranges]() {
275*700637cbSDimitry Andric     for (ResourceRange &Range : Ranges)
276*700637cbSDimitry Andric       Range.clear();
277*700637cbSDimitry Andric   };
278*700637cbSDimitry Andric 
279*700637cbSDimitry Andric   // Iterate through collected RangeInfos
280*700637cbSDimitry Andric   for (const RangeInfo &Info : Infos) {
281*700637cbSDimitry Andric     GroupT InfoGroup = {Info.Class, Info.Space};
282*700637cbSDimitry Andric     // Reset our ResourceRanges when we enter a new group
283*700637cbSDimitry Andric     if (CurGroup != InfoGroup) {
284*700637cbSDimitry Andric       ClearRanges();
285*700637cbSDimitry Andric       CurGroup = InfoGroup;
286*700637cbSDimitry Andric     }
287*700637cbSDimitry Andric 
288*700637cbSDimitry Andric     // Insert range info into corresponding Visibility ResourceRange
289*700637cbSDimitry Andric     ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)];
290*700637cbSDimitry Andric     if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
291*700637cbSDimitry Andric       Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
292*700637cbSDimitry Andric 
293*700637cbSDimitry Andric     // Check for overlap in all overlapping Visibility ResourceRanges
294*700637cbSDimitry Andric     //
295*700637cbSDimitry Andric     // If the range that we are inserting has ShaderVisiblity::All it needs to
296*700637cbSDimitry Andric     // check for an overlap in all other visibility types as well.
297*700637cbSDimitry Andric     // Otherwise, the range that is inserted needs to check that it does not
298*700637cbSDimitry Andric     // overlap with ShaderVisibility::All.
299*700637cbSDimitry Andric     //
300*700637cbSDimitry Andric     // OverlapRanges will be an ArrayRef to all non-all visibility
301*700637cbSDimitry Andric     // ResourceRanges in the former case and it will be an ArrayRef to just the
302*700637cbSDimitry Andric     // all visiblity ResourceRange in the latter case.
303*700637cbSDimitry Andric     ArrayRef<ResourceRange> OverlapRanges =
304*700637cbSDimitry Andric         Info.Visibility == llvm::dxbc::ShaderVisibility::All
305*700637cbSDimitry Andric             ? ArrayRef<ResourceRange>{Ranges}.drop_front()
306*700637cbSDimitry Andric             : ArrayRef<ResourceRange>{Ranges}.take_front();
307*700637cbSDimitry Andric 
308*700637cbSDimitry Andric     for (const ResourceRange &Range : OverlapRanges)
309*700637cbSDimitry Andric       if (std::optional<const RangeInfo *> Overlapping =
310*700637cbSDimitry Andric               Range.getOverlapping(Info))
311*700637cbSDimitry Andric         Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
312*700637cbSDimitry Andric   }
313*700637cbSDimitry Andric 
314*700637cbSDimitry Andric   return Overlaps;
315*700637cbSDimitry Andric }
316*700637cbSDimitry Andric 
317*700637cbSDimitry Andric } // namespace rootsig
318*700637cbSDimitry Andric } // namespace hlsl
319*700637cbSDimitry Andric } // namespace llvm
320