1*700637cbSDimitry Andric //===- HLSLRootSignatureValidations.cpp - HLSL Root Signature helpers -----===//
2*700637cbSDimitry Andric //
3*700637cbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*700637cbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*700637cbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*700637cbSDimitry Andric //
7*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
8*700637cbSDimitry Andric ///
9*700637cbSDimitry Andric /// \file This file contains helpers for working with HLSL Root Signatures.
10*700637cbSDimitry Andric ///
11*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
12*700637cbSDimitry Andric
13*700637cbSDimitry Andric #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
14*700637cbSDimitry Andric
15*700637cbSDimitry Andric #include <cmath>
16*700637cbSDimitry Andric
17*700637cbSDimitry Andric namespace llvm {
18*700637cbSDimitry Andric namespace hlsl {
19*700637cbSDimitry Andric namespace rootsig {
20*700637cbSDimitry Andric
/// Returns true when \p Flags has no bit set outside the low 12 bits (0xfff).
bool verifyRootFlag(uint32_t Flags) {
  constexpr uint32_t ValidMask = 0xfffu;
  return (Flags | ValidMask) == ValidMask;
}
22*700637cbSDimitry Andric
/// Returns true only for root signature versions 1 and 2.
bool verifyVersion(uint32_t Version) {
  switch (Version) {
  case 1:
  case 2:
    return true;
  default:
    return false;
  }
}
24*700637cbSDimitry Andric
/// Returns false only for the all-ones register value (~0U); every other
/// value is accepted.
bool verifyRegisterValue(uint32_t RegisterValue) {
  constexpr uint32_t InvalidRegister = ~uint32_t(0);
  return InvalidRegister != RegisterValue;
}
28*700637cbSDimitry Andric
// Register spaces 0xFFFFFFF0 through 0xFFFFFFFF are reserved, and therefore
// invalid, according to the spec:
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
bool verifyRegisterSpace(uint32_t RegisterSpace) {
  // The reserved block runs up to the largest uint32_t, so only its lower edge
  // needs to be checked.
  constexpr uint32_t FirstReservedSpace = 0xFFFFFFF0u;
  return RegisterSpace < FirstReservedSpace;
}
34*700637cbSDimitry Andric
verifyRootDescriptorFlag(uint32_t Version,uint32_t FlagsVal)35*700637cbSDimitry Andric bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
36*700637cbSDimitry Andric using FlagT = dxbc::RootDescriptorFlags;
37*700637cbSDimitry Andric FlagT Flags = FlagT(FlagsVal);
38*700637cbSDimitry Andric if (Version == 1)
39*700637cbSDimitry Andric return Flags == FlagT::DataVolatile;
40*700637cbSDimitry Andric
41*700637cbSDimitry Andric assert(Version == 2 && "Provided invalid root signature version");
42*700637cbSDimitry Andric
43*700637cbSDimitry Andric // The data-specific flags are mutually exclusive.
44*700637cbSDimitry Andric FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
45*700637cbSDimitry Andric FlagT::DataStaticWhileSetAtExecute;
46*700637cbSDimitry Andric
47*700637cbSDimitry Andric if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
48*700637cbSDimitry Andric return false;
49*700637cbSDimitry Andric
50*700637cbSDimitry Andric // Only a data flag or no flags is valid
51*700637cbSDimitry Andric return (Flags | DataFlags) == DataFlags;
52*700637cbSDimitry Andric }
53*700637cbSDimitry Andric
verifyRangeType(uint32_t Type)54*700637cbSDimitry Andric bool verifyRangeType(uint32_t Type) {
55*700637cbSDimitry Andric switch (Type) {
56*700637cbSDimitry Andric case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
57*700637cbSDimitry Andric case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
58*700637cbSDimitry Andric case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
59*700637cbSDimitry Andric case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
60*700637cbSDimitry Andric return true;
61*700637cbSDimitry Andric };
62*700637cbSDimitry Andric
63*700637cbSDimitry Andric return false;
64*700637cbSDimitry Andric }
65*700637cbSDimitry Andric
verifyDescriptorRangeFlag(uint32_t Version,uint32_t Type,uint32_t FlagsVal)66*700637cbSDimitry Andric bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
67*700637cbSDimitry Andric uint32_t FlagsVal) {
68*700637cbSDimitry Andric using FlagT = dxbc::DescriptorRangeFlags;
69*700637cbSDimitry Andric FlagT Flags = FlagT(FlagsVal);
70*700637cbSDimitry Andric
71*700637cbSDimitry Andric const bool IsSampler =
72*700637cbSDimitry Andric (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler));
73*700637cbSDimitry Andric
74*700637cbSDimitry Andric if (Version == 1) {
75*700637cbSDimitry Andric // Since the metadata is unversioned, we expect to explicitly see the values
76*700637cbSDimitry Andric // that map to the version 1 behaviour here.
77*700637cbSDimitry Andric if (IsSampler)
78*700637cbSDimitry Andric return Flags == FlagT::DescriptorsVolatile;
79*700637cbSDimitry Andric return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile);
80*700637cbSDimitry Andric }
81*700637cbSDimitry Andric
82*700637cbSDimitry Andric // The data-specific flags are mutually exclusive.
83*700637cbSDimitry Andric FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic |
84*700637cbSDimitry Andric FlagT::DataStaticWhileSetAtExecute;
85*700637cbSDimitry Andric
86*700637cbSDimitry Andric if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
87*700637cbSDimitry Andric return false;
88*700637cbSDimitry Andric
89*700637cbSDimitry Andric // The descriptor-specific flags are mutually exclusive.
90*700637cbSDimitry Andric FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks |
91*700637cbSDimitry Andric FlagT::DescriptorsVolatile;
92*700637cbSDimitry Andric if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1)
93*700637cbSDimitry Andric return false;
94*700637cbSDimitry Andric
95*700637cbSDimitry Andric // For volatile descriptors, DATA_is never valid.
96*700637cbSDimitry Andric if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) {
97*700637cbSDimitry Andric FlagT Mask = FlagT::DescriptorsVolatile;
98*700637cbSDimitry Andric if (!IsSampler) {
99*700637cbSDimitry Andric Mask |= FlagT::DataVolatile;
100*700637cbSDimitry Andric Mask |= FlagT::DataStaticWhileSetAtExecute;
101*700637cbSDimitry Andric }
102*700637cbSDimitry Andric return (Flags & ~Mask) == FlagT::None;
103*700637cbSDimitry Andric }
104*700637cbSDimitry Andric
105*700637cbSDimitry Andric // For "KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
106*700637cbSDimitry Andric // the other data-specific flags may all be set.
107*700637cbSDimitry Andric if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) ==
108*700637cbSDimitry Andric FlagT::DescriptorsStaticKeepingBufferBoundsChecks) {
109*700637cbSDimitry Andric FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks;
110*700637cbSDimitry Andric if (!IsSampler) {
111*700637cbSDimitry Andric Mask |= FlagT::DataVolatile;
112*700637cbSDimitry Andric Mask |= FlagT::DataStatic;
113*700637cbSDimitry Andric Mask |= FlagT::DataStaticWhileSetAtExecute;
114*700637cbSDimitry Andric }
115*700637cbSDimitry Andric return (Flags & ~Mask) == FlagT::None;
116*700637cbSDimitry Andric }
117*700637cbSDimitry Andric
118*700637cbSDimitry Andric // When no descriptor flag is set, any data flag is allowed.
119*700637cbSDimitry Andric FlagT Mask = FlagT::None;
120*700637cbSDimitry Andric if (!IsSampler) {
121*700637cbSDimitry Andric Mask |= FlagT::DataVolatile;
122*700637cbSDimitry Andric Mask |= FlagT::DataStaticWhileSetAtExecute;
123*700637cbSDimitry Andric Mask |= FlagT::DataStatic;
124*700637cbSDimitry Andric }
125*700637cbSDimitry Andric return (Flags & ~Mask) == FlagT::None;
126*700637cbSDimitry Andric }
127*700637cbSDimitry Andric
/// A descriptor range must contain at least one descriptor.
bool verifyNumDescriptors(uint32_t NumDescriptors) {
  return NumDescriptors != 0u;
}
131*700637cbSDimitry Andric
verifySamplerFilter(uint32_t Value)132*700637cbSDimitry Andric bool verifySamplerFilter(uint32_t Value) {
133*700637cbSDimitry Andric switch (Value) {
134*700637cbSDimitry Andric #define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val):
135*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
136*700637cbSDimitry Andric return true;
137*700637cbSDimitry Andric }
138*700637cbSDimitry Andric return false;
139*700637cbSDimitry Andric }
140*700637cbSDimitry Andric
141*700637cbSDimitry Andric // Values allowed here:
142*700637cbSDimitry Andric // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
verifyAddress(uint32_t Address)143*700637cbSDimitry Andric bool verifyAddress(uint32_t Address) {
144*700637cbSDimitry Andric switch (Address) {
145*700637cbSDimitry Andric #define TEXTURE_ADDRESS_MODE(Num, Val) \
146*700637cbSDimitry Andric case llvm::to_underlying(dxbc::TextureAddressMode::Val):
147*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
148*700637cbSDimitry Andric return true;
149*700637cbSDimitry Andric }
150*700637cbSDimitry Andric return false;
151*700637cbSDimitry Andric }
152*700637cbSDimitry Andric
/// Returns true if \p MipLODBias lies in the closed range [-16.0, 15.99].
/// NaN fails both comparisons and is therefore rejected.
bool verifyMipLODBias(float MipLODBias) {
  constexpr float MinBias = -16.f;
  constexpr float MaxBias = 15.99f;
  return MinBias <= MipLODBias && MipLODBias <= MaxBias;
}
156*700637cbSDimitry Andric
/// Anisotropy values from 0 up to and including 16 are accepted.
bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
  constexpr uint32_t Limit = 16u;
  return !(MaxAnisotropy > Limit);
}
160*700637cbSDimitry Andric
verifyComparisonFunc(uint32_t ComparisonFunc)161*700637cbSDimitry Andric bool verifyComparisonFunc(uint32_t ComparisonFunc) {
162*700637cbSDimitry Andric switch (ComparisonFunc) {
163*700637cbSDimitry Andric #define COMPARISON_FUNC(Num, Val) \
164*700637cbSDimitry Andric case llvm::to_underlying(dxbc::ComparisonFunc::Val):
165*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
166*700637cbSDimitry Andric return true;
167*700637cbSDimitry Andric }
168*700637cbSDimitry Andric return false;
169*700637cbSDimitry Andric }
170*700637cbSDimitry Andric
verifyBorderColor(uint32_t BorderColor)171*700637cbSDimitry Andric bool verifyBorderColor(uint32_t BorderColor) {
172*700637cbSDimitry Andric switch (BorderColor) {
173*700637cbSDimitry Andric #define STATIC_BORDER_COLOR(Num, Val) \
174*700637cbSDimitry Andric case llvm::to_underlying(dxbc::StaticBorderColor::Val):
175*700637cbSDimitry Andric #include "llvm/BinaryFormat/DXContainerConstants.def"
176*700637cbSDimitry Andric return true;
177*700637cbSDimitry Andric }
178*700637cbSDimitry Andric return false;
179*700637cbSDimitry Andric }
180*700637cbSDimitry Andric
/// Any LOD value other than NaN is accepted, including infinities.
bool verifyLOD(float LOD) {
  if (std::isnan(LOD))
    return false;
  return true;
}
182*700637cbSDimitry Andric
183*700637cbSDimitry Andric std::optional<const RangeInfo *>
getOverlapping(const RangeInfo & Info) const184*700637cbSDimitry Andric ResourceRange::getOverlapping(const RangeInfo &Info) const {
185*700637cbSDimitry Andric MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
186*700637cbSDimitry Andric if (!Interval.valid() || Info.UpperBound < Interval.start())
187*700637cbSDimitry Andric return std::nullopt;
188*700637cbSDimitry Andric return Interval.value();
189*700637cbSDimitry Andric }
190*700637cbSDimitry Andric
lookup(uint32_t X) const191*700637cbSDimitry Andric const RangeInfo *ResourceRange::lookup(uint32_t X) const {
192*700637cbSDimitry Andric return Intervals.lookup(X, nullptr);
193*700637cbSDimitry Andric }
194*700637cbSDimitry Andric
clear()195*700637cbSDimitry Andric void ResourceRange::clear() { return Intervals.clear(); }
196*700637cbSDimitry Andric
insert(const RangeInfo & Info)197*700637cbSDimitry Andric std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
198*700637cbSDimitry Andric uint32_t LowerBound = Info.LowerBound;
199*700637cbSDimitry Andric uint32_t UpperBound = Info.UpperBound;
200*700637cbSDimitry Andric
201*700637cbSDimitry Andric std::optional<const RangeInfo *> Res = std::nullopt;
202*700637cbSDimitry Andric MapT::iterator Interval = Intervals.begin();
203*700637cbSDimitry Andric
204*700637cbSDimitry Andric while (true) {
205*700637cbSDimitry Andric if (UpperBound < LowerBound)
206*700637cbSDimitry Andric break;
207*700637cbSDimitry Andric
208*700637cbSDimitry Andric Interval.advanceTo(LowerBound);
209*700637cbSDimitry Andric if (!Interval.valid()) // No interval found
210*700637cbSDimitry Andric break;
211*700637cbSDimitry Andric
212*700637cbSDimitry Andric // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
213*700637cbSDimitry Andric // a <= y implicitly from Intervals.find(LowerBound)
214*700637cbSDimitry Andric if (UpperBound < Interval.start())
215*700637cbSDimitry Andric break; // found interval does not overlap with inserted one
216*700637cbSDimitry Andric
217*700637cbSDimitry Andric if (!Res.has_value()) // Update to be the first found intersection
218*700637cbSDimitry Andric Res = Interval.value();
219*700637cbSDimitry Andric
220*700637cbSDimitry Andric if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
221*700637cbSDimitry Andric // x <= a <= b <= y implies that [a;b] is covered by [x;y]
222*700637cbSDimitry Andric // -> so we don't need to insert this, report an overlap
223*700637cbSDimitry Andric return Res;
224*700637cbSDimitry Andric } else if (LowerBound <= Interval.start() &&
225*700637cbSDimitry Andric Interval.stop() <= UpperBound) {
226*700637cbSDimitry Andric // a <= x <= y <= b implies that [x;y] is covered by [a;b]
227*700637cbSDimitry Andric // -> so remove the existing interval that we will cover with the
228*700637cbSDimitry Andric // overwrite
229*700637cbSDimitry Andric Interval.erase();
230*700637cbSDimitry Andric } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
231*700637cbSDimitry Andric // a < x <= b <= y implies that [a; x] is not covered but [x;b] is
232*700637cbSDimitry Andric // -> so set b = x - 1 such that [a;x-1] is now the interval to insert
233*700637cbSDimitry Andric UpperBound = Interval.start() - 1;
234*700637cbSDimitry Andric } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
235*700637cbSDimitry Andric // a < x <= b <= y implies that [y; b] is not covered but [a;y] is
236*700637cbSDimitry Andric // -> so set a = y + 1 such that [y+1;b] is now the interval to insert
237*700637cbSDimitry Andric LowerBound = Interval.stop() + 1;
238*700637cbSDimitry Andric }
239*700637cbSDimitry Andric }
240*700637cbSDimitry Andric
241*700637cbSDimitry Andric assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
242*700637cbSDimitry Andric Intervals.insert(LowerBound, UpperBound, &Info);
243*700637cbSDimitry Andric return Res;
244*700637cbSDimitry Andric }
245*700637cbSDimitry Andric
246*700637cbSDimitry Andric llvm::SmallVector<OverlappingRanges>
findOverlappingRanges(ArrayRef<RangeInfo> Infos)247*700637cbSDimitry Andric findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
248*700637cbSDimitry Andric // It is expected that Infos is filled with valid RangeInfos and that
249*700637cbSDimitry Andric // they are sorted with respect to the RangeInfo <operator
250*700637cbSDimitry Andric assert(llvm::is_sorted(Infos) && "Ranges must be sorted");
251*700637cbSDimitry Andric
252*700637cbSDimitry Andric llvm::SmallVector<OverlappingRanges> Overlaps;
253*700637cbSDimitry Andric using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;
254*700637cbSDimitry Andric
255*700637cbSDimitry Andric // First we will init our state to track:
256*700637cbSDimitry Andric if (Infos.size() == 0)
257*700637cbSDimitry Andric return Overlaps; // No ranges to overlap
258*700637cbSDimitry Andric GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
259*700637cbSDimitry Andric
260*700637cbSDimitry Andric // Create a ResourceRange for each Visibility
261*700637cbSDimitry Andric ResourceRange::MapT::Allocator Allocator;
262*700637cbSDimitry Andric std::array<ResourceRange, 8> Ranges = {
263*700637cbSDimitry Andric ResourceRange(Allocator), // All
264*700637cbSDimitry Andric ResourceRange(Allocator), // Vertex
265*700637cbSDimitry Andric ResourceRange(Allocator), // Hull
266*700637cbSDimitry Andric ResourceRange(Allocator), // Domain
267*700637cbSDimitry Andric ResourceRange(Allocator), // Geometry
268*700637cbSDimitry Andric ResourceRange(Allocator), // Pixel
269*700637cbSDimitry Andric ResourceRange(Allocator), // Amplification
270*700637cbSDimitry Andric ResourceRange(Allocator), // Mesh
271*700637cbSDimitry Andric };
272*700637cbSDimitry Andric
273*700637cbSDimitry Andric // Reset the ResourceRanges for when we iterate through a new group
274*700637cbSDimitry Andric auto ClearRanges = [&Ranges]() {
275*700637cbSDimitry Andric for (ResourceRange &Range : Ranges)
276*700637cbSDimitry Andric Range.clear();
277*700637cbSDimitry Andric };
278*700637cbSDimitry Andric
279*700637cbSDimitry Andric // Iterate through collected RangeInfos
280*700637cbSDimitry Andric for (const RangeInfo &Info : Infos) {
281*700637cbSDimitry Andric GroupT InfoGroup = {Info.Class, Info.Space};
282*700637cbSDimitry Andric // Reset our ResourceRanges when we enter a new group
283*700637cbSDimitry Andric if (CurGroup != InfoGroup) {
284*700637cbSDimitry Andric ClearRanges();
285*700637cbSDimitry Andric CurGroup = InfoGroup;
286*700637cbSDimitry Andric }
287*700637cbSDimitry Andric
288*700637cbSDimitry Andric // Insert range info into corresponding Visibility ResourceRange
289*700637cbSDimitry Andric ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)];
290*700637cbSDimitry Andric if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
291*700637cbSDimitry Andric Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
292*700637cbSDimitry Andric
293*700637cbSDimitry Andric // Check for overlap in all overlapping Visibility ResourceRanges
294*700637cbSDimitry Andric //
295*700637cbSDimitry Andric // If the range that we are inserting has ShaderVisiblity::All it needs to
296*700637cbSDimitry Andric // check for an overlap in all other visibility types as well.
297*700637cbSDimitry Andric // Otherwise, the range that is inserted needs to check that it does not
298*700637cbSDimitry Andric // overlap with ShaderVisibility::All.
299*700637cbSDimitry Andric //
300*700637cbSDimitry Andric // OverlapRanges will be an ArrayRef to all non-all visibility
301*700637cbSDimitry Andric // ResourceRanges in the former case and it will be an ArrayRef to just the
302*700637cbSDimitry Andric // all visiblity ResourceRange in the latter case.
303*700637cbSDimitry Andric ArrayRef<ResourceRange> OverlapRanges =
304*700637cbSDimitry Andric Info.Visibility == llvm::dxbc::ShaderVisibility::All
305*700637cbSDimitry Andric ? ArrayRef<ResourceRange>{Ranges}.drop_front()
306*700637cbSDimitry Andric : ArrayRef<ResourceRange>{Ranges}.take_front();
307*700637cbSDimitry Andric
308*700637cbSDimitry Andric for (const ResourceRange &Range : OverlapRanges)
309*700637cbSDimitry Andric if (std::optional<const RangeInfo *> Overlapping =
310*700637cbSDimitry Andric Range.getOverlapping(Info))
311*700637cbSDimitry Andric Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
312*700637cbSDimitry Andric }
313*700637cbSDimitry Andric
314*700637cbSDimitry Andric return Overlaps;
315*700637cbSDimitry Andric }
316*700637cbSDimitry Andric
317*700637cbSDimitry Andric } // namespace rootsig
318*700637cbSDimitry Andric } // namespace hlsl
319*700637cbSDimitry Andric } // namespace llvm
320