xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Object/DXContainer.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file declares the DXContainer class and the supporting view types used
// to read the individual parts of a DXContainer file.
11 //
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_OBJECT_DXCONTAINER_H
16 #define LLVM_OBJECT_DXCONTAINER_H
17 
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/BinaryFormat/DXContainer.h"
22 #include "llvm/Object/Error.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Endian.h"
25 #include "llvm/Support/Error.h"
26 #include "llvm/Support/MemoryBufferRef.h"
27 #include "llvm/TargetParser/Triple.h"
28 #include <array>
29 #include <cstddef>
30 #include <cstdint>
31 #include <variant>
32 
33 namespace llvm {
34 namespace object {
35 
36 namespace detail {
37 template <typename T>
swapBytes(T & value)38 std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) {
39   sys::swapByteOrder(value);
40 }
41 
/// Byte-swaps a class-type value in place by delegating to the value's own
/// swapBytes() member. This overload only participates in overload resolution
/// when T is a class type.
template <typename T>
std::enable_if_t<std::is_class_v<T>> swapBytes(T &value) {
  value.swapBytes();
}
46 } // namespace detail
47 
48 // This class provides a view into the underlying resource array. The Resource
49 // data is little-endian encoded and may not be properly aligned to read
50 // directly from. The dereference operator creates a copy of the data and byte
51 // swaps it as appropriate.
52 template <typename T> struct ViewArray {
53   StringRef Data;
54   uint32_t Stride = sizeof(T); // size of each element in the list.
55 
56   ViewArray() = default;
ViewArrayViewArray57   ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {}
58 
59   using value_type = T;
MaxStrideViewArray60   static constexpr uint32_t MaxStride() {
61     return static_cast<uint32_t>(sizeof(value_type));
62   }
63 
64   struct iterator {
65     StringRef Data;
66     uint32_t Stride; // size of each element in the list.
67     const char *Current;
68 
iteratorViewArray::iterator69     iterator(const ViewArray &A, const char *C)
70         : Data(A.Data), Stride(A.Stride), Current(C) {}
71     iterator(const iterator &) = default;
72 
73     value_type operator*() {
74       // Explicitly zero the structure so that unused fields are zeroed. It is
75       // up to the user to know if the fields are used by verifying the PSV
76       // version.
77       value_type Val;
78       std::memset(&Val, 0, sizeof(value_type));
79       if (Current >= Data.end())
80         return Val;
81       memcpy(static_cast<void *>(&Val), Current, std::min(Stride, MaxStride()));
82       if (sys::IsBigEndianHost)
83         detail::swapBytes(Val);
84       return Val;
85     }
86 
87     iterator operator++() {
88       if (Current < Data.end())
89         Current += Stride;
90       return *this;
91     }
92 
93     iterator operator++(int) {
94       iterator Tmp = *this;
95       ++*this;
96       return Tmp;
97     }
98 
99     iterator operator--() {
100       if (Current > Data.begin())
101         Current -= Stride;
102       return *this;
103     }
104 
105     iterator operator--(int) {
106       iterator Tmp = *this;
107       --*this;
108       return Tmp;
109     }
110 
111     bool operator==(const iterator I) { return I.Current == Current; }
112     bool operator!=(const iterator I) { return !(*this == I); }
113   };
114 
beginViewArray115   iterator begin() const { return iterator(*this, Data.begin()); }
116 
endViewArray117   iterator end() const { return iterator(*this, Data.end()); }
118 
sizeViewArray119   size_t size() const { return Data.size() / Stride; }
120 
isEmptyViewArray121   bool isEmpty() const { return Data.empty(); }
122 };
123 
124 namespace DirectX {
125 struct RootParameterView {
126   const dxbc::RTS0::v1::RootParameterHeader &Header;
127   StringRef ParamData;
128 
RootParameterViewRootParameterView129   RootParameterView(const dxbc::RTS0::v1::RootParameterHeader &H, StringRef P)
130       : Header(H), ParamData(P) {}
131 
readParameterRootParameterView132   template <typename T> Expected<T> readParameter() {
133     T Struct;
134     if (sizeof(T) != ParamData.size())
135       return make_error<GenericBinaryError>(
136           "Reading structure out of file bounds", object_error::parse_failed);
137 
138     memcpy(&Struct, ParamData.data(), sizeof(T));
139     // DXContainer is always little endian
140     if (sys::IsBigEndianHost)
141       Struct.swapBytes();
142     return Struct;
143   }
144 };
145 
146 struct RootConstantView : RootParameterView {
classofRootConstantView147   static bool classof(const RootParameterView *V) {
148     return V->Header.ParameterType ==
149            (uint32_t)dxbc::RootParameterType::Constants32Bit;
150   }
151 
readRootConstantView152   llvm::Expected<dxbc::RTS0::v1::RootConstants> read() {
153     return readParameter<dxbc::RTS0::v1::RootConstants>();
154   }
155 };
156 
157 struct RootDescriptorView : RootParameterView {
classofRootDescriptorView158   static bool classof(const RootParameterView *V) {
159     return (V->Header.ParameterType ==
160                 llvm::to_underlying(dxbc::RootParameterType::CBV) ||
161             V->Header.ParameterType ==
162                 llvm::to_underlying(dxbc::RootParameterType::SRV) ||
163             V->Header.ParameterType ==
164                 llvm::to_underlying(dxbc::RootParameterType::UAV));
165   }
166 
readRootDescriptorView167   llvm::Expected<dxbc::RTS0::v2::RootDescriptor> read(uint32_t Version) {
168     if (Version == 1) {
169       auto Descriptor = readParameter<dxbc::RTS0::v1::RootDescriptor>();
170       if (Error E = Descriptor.takeError())
171         return E;
172       return dxbc::RTS0::v2::RootDescriptor(*Descriptor);
173     }
174     if (Version != 2)
175       return make_error<GenericBinaryError>("Invalid Root Signature version: " +
176                                                 Twine(Version),
177                                             object_error::parse_failed);
178     return readParameter<dxbc::RTS0::v2::RootDescriptor>();
179   }
180 };
181 template <typename T> struct DescriptorTable {
182   uint32_t NumRanges;
183   uint32_t RangesOffset;
184   ViewArray<T> Ranges;
185 
beginDescriptorTable186   typename ViewArray<T>::iterator begin() const { return Ranges.begin(); }
187 
endDescriptorTable188   typename ViewArray<T>::iterator end() const { return Ranges.end(); }
189 };
190 
191 struct DescriptorTableView : RootParameterView {
classofDescriptorTableView192   static bool classof(const RootParameterView *V) {
193     return (V->Header.ParameterType ==
194             llvm::to_underlying(dxbc::RootParameterType::DescriptorTable));
195   }
196 
197   // Define a type alias to access the template parameter from inside classof
readDescriptorTableView198   template <typename T> llvm::Expected<DescriptorTable<T>> read() {
199     const char *Current = ParamData.begin();
200     DescriptorTable<T> Table;
201 
202     Table.NumRanges =
203         support::endian::read<uint32_t, llvm::endianness::little>(Current);
204     Current += sizeof(uint32_t);
205 
206     Table.RangesOffset =
207         support::endian::read<uint32_t, llvm::endianness::little>(Current);
208     Current += sizeof(uint32_t);
209 
210     Table.Ranges.Data = ParamData.substr(2 * sizeof(uint32_t),
211                                          Table.NumRanges * Table.Ranges.Stride);
212     return Table;
213   }
214 };
215 
parseFailed(const Twine & Msg)216 static Error parseFailed(const Twine &Msg) {
217   return make_error<GenericBinaryError>(Msg.str(), object_error::parse_failed);
218 }
219 
220 class RootSignature {
221 private:
222   uint32_t Version;
223   uint32_t NumParameters;
224   uint32_t RootParametersOffset;
225   uint32_t NumStaticSamplers;
226   uint32_t StaticSamplersOffset;
227   uint32_t Flags;
228   ViewArray<dxbc::RTS0::v1::RootParameterHeader> ParametersHeaders;
229   StringRef PartData;
230   ViewArray<dxbc::RTS0::v1::StaticSampler> StaticSamplers;
231 
232   using param_header_iterator =
233       ViewArray<dxbc::RTS0::v1::RootParameterHeader>::iterator;
234   using samplers_iterator = ViewArray<dxbc::RTS0::v1::StaticSampler>::iterator;
235 
236 public:
RootSignature(StringRef PD)237   RootSignature(StringRef PD) : PartData(PD) {}
238 
239   LLVM_ABI Error parse();
getVersion()240   uint32_t getVersion() const { return Version; }
getNumParameters()241   uint32_t getNumParameters() const { return NumParameters; }
getRootParametersOffset()242   uint32_t getRootParametersOffset() const { return RootParametersOffset; }
getNumStaticSamplers()243   uint32_t getNumStaticSamplers() const { return NumStaticSamplers; }
getStaticSamplersOffset()244   uint32_t getStaticSamplersOffset() const { return StaticSamplersOffset; }
getNumRootParameters()245   uint32_t getNumRootParameters() const { return ParametersHeaders.size(); }
param_headers()246   llvm::iterator_range<param_header_iterator> param_headers() const {
247     return llvm::make_range(ParametersHeaders.begin(), ParametersHeaders.end());
248   }
samplers()249   llvm::iterator_range<samplers_iterator> samplers() const {
250     return llvm::make_range(StaticSamplers.begin(), StaticSamplers.end());
251   }
getFlags()252   uint32_t getFlags() const { return Flags; }
253 
254   llvm::Expected<RootParameterView>
getParameter(const dxbc::RTS0::v1::RootParameterHeader & Header)255   getParameter(const dxbc::RTS0::v1::RootParameterHeader &Header) const {
256     size_t DataSize;
257     size_t EndOfSectionByte = getNumStaticSamplers() == 0
258                                   ? PartData.size()
259                                   : getStaticSamplersOffset();
260 
261     if (!dxbc::isValidParameterType(Header.ParameterType))
262       return parseFailed("invalid parameter type");
263 
264     switch (static_cast<dxbc::RootParameterType>(Header.ParameterType)) {
265     case dxbc::RootParameterType::Constants32Bit:
266       DataSize = sizeof(dxbc::RTS0::v1::RootConstants);
267       break;
268     case dxbc::RootParameterType::CBV:
269     case dxbc::RootParameterType::SRV:
270     case dxbc::RootParameterType::UAV:
271       if (Version == 1)
272         DataSize = sizeof(dxbc::RTS0::v1::RootDescriptor);
273       else
274         DataSize = sizeof(dxbc::RTS0::v2::RootDescriptor);
275       break;
276     case dxbc::RootParameterType::DescriptorTable:
277       if (Header.ParameterOffset + sizeof(uint32_t) > EndOfSectionByte)
278         return parseFailed("Reading structure out of file bounds");
279 
280       uint32_t NumRanges =
281           support::endian::read<uint32_t, llvm::endianness::little>(
282               PartData.begin() + Header.ParameterOffset);
283       if (Version == 1)
284         DataSize = sizeof(dxbc::RTS0::v1::DescriptorRange) * NumRanges;
285       else
286         DataSize = sizeof(dxbc::RTS0::v2::DescriptorRange) * NumRanges;
287 
288       // 4 bytes for the number of ranges in table and
289       // 4 bytes for the ranges offset
290       DataSize += 2 * sizeof(uint32_t);
291       break;
292     }
293     if (Header.ParameterOffset + DataSize > EndOfSectionByte)
294       return parseFailed("Reading structure out of file bounds");
295 
296     StringRef Buff = PartData.substr(Header.ParameterOffset, DataSize);
297     RootParameterView View = RootParameterView(Header, Buff);
298     return View;
299   }
300 };
301 
302 class PSVRuntimeInfo {
303 
304   using ResourceArray = ViewArray<dxbc::PSV::v2::ResourceBindInfo>;
305   using SigElementArray = ViewArray<dxbc::PSV::v0::SignatureElement>;
306 
307   StringRef Data;
308   uint32_t Size;
309   using InfoStruct =
310       std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo,
311                    dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo,
312                    dxbc::PSV::v3::RuntimeInfo>;
313   InfoStruct BasicInfo;
314   ResourceArray Resources;
315   StringRef StringTable;
316   SmallVector<uint32_t> SemanticIndexTable;
317   SigElementArray SigInputElements;
318   SigElementArray SigOutputElements;
319   SigElementArray SigPatchOrPrimElements;
320 
321   std::array<ViewArray<uint32_t>, 4> OutputVectorMasks;
322   ViewArray<uint32_t> PatchOrPrimMasks;
323   std::array<ViewArray<uint32_t>, 4> InputOutputMap;
324   ViewArray<uint32_t> InputPatchMap;
325   ViewArray<uint32_t> PatchOutputMap;
326 
327 public:
PSVRuntimeInfo(StringRef D)328   PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
329 
330   // Parsing depends on the shader kind
331   LLVM_ABI Error parse(uint16_t ShaderKind);
332 
getSize()333   uint32_t getSize() const { return Size; }
getResourceCount()334   uint32_t getResourceCount() const { return Resources.size(); }
getResources()335   ResourceArray getResources() const { return Resources; }
336 
getVersion()337   uint32_t getVersion() const {
338     return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo)
339                ? 3
340                : (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo)     ? 2
341                   : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1
342                                                                  : 0);
343   }
344 
getResourceStride()345   uint32_t getResourceStride() const { return Resources.Stride; }
346 
getInfo()347   const InfoStruct &getInfo() const { return BasicInfo; }
348 
getInfoAs()349   template <typename T> const T *getInfoAs() const {
350     if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
351       return static_cast<const T *>(P);
352     if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value)
353       return nullptr;
354 
355     if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
356       return static_cast<const T *>(P);
357     if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)
358       return nullptr;
359 
360     if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
361       return static_cast<const T *>(P);
362     if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value)
363       return nullptr;
364 
365     if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(&BasicInfo))
366       return static_cast<const T *>(P);
367     return nullptr;
368   }
369 
getStringTable()370   StringRef getStringTable() const { return StringTable; }
getSemanticIndexTable()371   ArrayRef<uint32_t> getSemanticIndexTable() const {
372     return SemanticIndexTable;
373   }
374 
375   LLVM_ABI uint8_t getSigInputCount() const;
376   LLVM_ABI uint8_t getSigOutputCount() const;
377   LLVM_ABI uint8_t getSigPatchOrPrimCount() const;
378 
getSigInputElements()379   SigElementArray getSigInputElements() const { return SigInputElements; }
getSigOutputElements()380   SigElementArray getSigOutputElements() const { return SigOutputElements; }
getSigPatchOrPrimElements()381   SigElementArray getSigPatchOrPrimElements() const {
382     return SigPatchOrPrimElements;
383   }
384 
getOutputVectorMasks(size_t Idx)385   ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const {
386     assert(Idx < 4);
387     return OutputVectorMasks[Idx];
388   }
389 
getPatchOrPrimMasks()390   ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; }
391 
getInputOutputMap(size_t Idx)392   ViewArray<uint32_t> getInputOutputMap(size_t Idx) const {
393     assert(Idx < 4);
394     return InputOutputMap[Idx];
395   }
396 
getInputPatchMap()397   ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; }
getPatchOutputMap()398   ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; }
399 
getSigElementStride()400   uint32_t getSigElementStride() const { return SigInputElements.Stride; }
401 
usesViewID()402   bool usesViewID() const {
403     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
404       return P->UsesViewID != 0;
405     return false;
406   }
407 
getInputVectorCount()408   uint8_t getInputVectorCount() const {
409     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
410       return P->SigInputVectors;
411     return 0;
412   }
413 
getOutputVectorCounts()414   ArrayRef<uint8_t> getOutputVectorCounts() const {
415     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
416       return ArrayRef<uint8_t>(P->SigOutputVectors);
417     return ArrayRef<uint8_t>();
418   }
419 
getPatchConstOrPrimVectorCount()420   uint8_t getPatchConstOrPrimVectorCount() const {
421     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
422       return P->GeomData.SigPatchConstOrPrimVectors;
423     return 0;
424   }
425 };
426 
427 class Signature {
428   ViewArray<dxbc::ProgramSignatureElement> Parameters;
429   uint32_t StringTableOffset;
430   StringRef StringTable;
431 
432 public:
begin()433   ViewArray<dxbc::ProgramSignatureElement>::iterator begin() const {
434     return Parameters.begin();
435   }
436 
end()437   ViewArray<dxbc::ProgramSignatureElement>::iterator end() const {
438     return Parameters.end();
439   }
440 
getName(uint32_t Offset)441   StringRef getName(uint32_t Offset) const {
442     assert(Offset >= StringTableOffset &&
443            Offset < StringTableOffset + StringTable.size() &&
444            "Offset out of range.");
445     // Name offsets are from the start of the signature data, not from the start
446     // of the string table. The header encodes the start offset of the sting
447     // table, so we convert the offset here.
448     uint32_t TableOffset = Offset - StringTableOffset;
449     return StringTable.slice(TableOffset, StringTable.find('\0', TableOffset));
450   }
451 
isEmpty()452   bool isEmpty() const { return Parameters.isEmpty(); }
453 
454   LLVM_ABI Error initialize(StringRef Part);
455 };
456 
457 } // namespace DirectX
458 
459 class DXContainer {
460 public:
461   using DXILData = std::pair<dxbc::ProgramHeader, const char *>;
462 
463 private:
464   DXContainer(MemoryBufferRef O);
465 
466   MemoryBufferRef Data;
467   dxbc::Header Header;
468   SmallVector<uint32_t, 4> PartOffsets;
469   std::optional<DXILData> DXIL;
470   std::optional<uint64_t> ShaderFeatureFlags;
471   std::optional<dxbc::ShaderHash> Hash;
472   std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
473   std::optional<DirectX::RootSignature> RootSignature;
474   DirectX::Signature InputSignature;
475   DirectX::Signature OutputSignature;
476   DirectX::Signature PatchConstantSignature;
477 
478   Error parseHeader();
479   Error parsePartOffsets();
480   Error parseDXILHeader(StringRef Part);
481   Error parseShaderFeatureFlags(StringRef Part);
482   Error parseHash(StringRef Part);
483   Error parseRootSignature(StringRef Part);
484   Error parsePSVInfo(StringRef Part);
485   Error parseSignature(StringRef Part, DirectX::Signature &Array);
486   friend class PartIterator;
487 
488 public:
489   // The PartIterator is a wrapper around the iterator for the PartOffsets
490   // member of the DXContainer. It contains a refernce to the container, and the
491   // current iterator value, as well as storage for a parsed part header.
492   class PartIterator {
493     const DXContainer &Container;
494     SmallVectorImpl<uint32_t>::const_iterator OffsetIt;
495     struct PartData {
496       dxbc::PartHeader Part;
497       uint32_t Offset;
498       StringRef Data;
499     } IteratorState;
500 
501     friend class DXContainer;
502 
PartIterator(const DXContainer & C,SmallVectorImpl<uint32_t>::const_iterator It)503     PartIterator(const DXContainer &C,
504                  SmallVectorImpl<uint32_t>::const_iterator It)
505         : Container(C), OffsetIt(It) {
506       if (OffsetIt == Container.PartOffsets.end())
507         updateIteratorImpl(Container.PartOffsets.back());
508       else
509         updateIterator();
510     }
511 
512     // Updates the iterator's state data. This results in copying the part
513     // header into the iterator and handling any required byte swapping. This is
514     // called when incrementing or decrementing the iterator.
updateIterator()515     void updateIterator() {
516       if (OffsetIt != Container.PartOffsets.end())
517         updateIteratorImpl(*OffsetIt);
518     }
519 
520     // Implementation for updating the iterator state based on a specified
521     // offest.
522     LLVM_ABI void updateIteratorImpl(const uint32_t Offset);
523 
524   public:
525     PartIterator &operator++() {
526       if (OffsetIt == Container.PartOffsets.end())
527         return *this;
528       ++OffsetIt;
529       updateIterator();
530       return *this;
531     }
532 
533     PartIterator operator++(int) {
534       PartIterator Tmp = *this;
535       ++(*this);
536       return Tmp;
537     }
538 
539     bool operator==(const PartIterator &RHS) const {
540       return OffsetIt == RHS.OffsetIt;
541     }
542 
543     bool operator!=(const PartIterator &RHS) const {
544       return OffsetIt != RHS.OffsetIt;
545     }
546 
547     const PartData &operator*() { return IteratorState; }
548     const PartData *operator->() { return &IteratorState; }
549   };
550 
begin()551   PartIterator begin() const {
552     return PartIterator(*this, PartOffsets.begin());
553   }
554 
end()555   PartIterator end() const { return PartIterator(*this, PartOffsets.end()); }
556 
getData()557   StringRef getData() const { return Data.getBuffer(); }
558   LLVM_ABI static Expected<DXContainer> create(MemoryBufferRef Object);
559 
getHeader()560   const dxbc::Header &getHeader() const { return Header; }
561 
getDXIL()562   const std::optional<DXILData> &getDXIL() const { return DXIL; }
563 
getShaderFeatureFlags()564   std::optional<uint64_t> getShaderFeatureFlags() const {
565     return ShaderFeatureFlags;
566   }
567 
getShaderHash()568   std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
569 
getRootSignature()570   std::optional<DirectX::RootSignature> getRootSignature() const {
571     return RootSignature;
572   }
573 
getPSVInfo()574   const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
575     return PSVInfo;
576   };
577 
getInputSignature()578   const DirectX::Signature &getInputSignature() const { return InputSignature; }
getOutputSignature()579   const DirectX::Signature &getOutputSignature() const {
580     return OutputSignature;
581   }
getPatchConstantSignature()582   const DirectX::Signature &getPatchConstantSignature() const {
583     return PatchConstantSignature;
584   }
585 };
586 
587 } // namespace object
588 } // namespace llvm
589 
590 #endif // LLVM_OBJECT_DXCONTAINER_H
591