xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Object/DXContainer.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the DXContainerFile class, which implements the ObjectFile
10 // interface for DXContainer files.
11 //
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_OBJECT_DXCONTAINER_H
16 #define LLVM_OBJECT_DXCONTAINER_H
17 
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/BinaryFormat/DXContainer.h"
21 #include "llvm/Support/Error.h"
22 #include "llvm/Support/MemoryBufferRef.h"
23 #include "llvm/TargetParser/Triple.h"
24 #include <array>
25 #include <variant>
26 
27 namespace llvm {
28 namespace object {
29 
30 namespace detail {
31 template <typename T>
swapBytes(T & value)32 std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) {
33   sys::swapByteOrder(value);
34 }
35 
36 template <typename T>
swapBytes(T & value)37 std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) {
38   value.swapBytes();
39 }
40 } // namespace detail
41 
42 // This class provides a view into the underlying resource array. The Resource
43 // data is little-endian encoded and may not be properly aligned to read
44 // directly from. The dereference operator creates a copy of the data and byte
45 // swaps it as appropriate.
46 template <typename T> struct ViewArray {
47   StringRef Data;
48   uint32_t Stride = sizeof(T); // size of each element in the list.
49 
50   ViewArray() = default;
ViewArrayViewArray51   ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {}
52 
53   using value_type = T;
MaxStrideViewArray54   static constexpr uint32_t MaxStride() {
55     return static_cast<uint32_t>(sizeof(value_type));
56   }
57 
58   struct iterator {
59     StringRef Data;
60     uint32_t Stride; // size of each element in the list.
61     const char *Current;
62 
iteratorViewArray::iterator63     iterator(const ViewArray &A, const char *C)
64         : Data(A.Data), Stride(A.Stride), Current(C) {}
65     iterator(const iterator &) = default;
66 
67     value_type operator*() {
68       // Explicitly zero the structure so that unused fields are zeroed. It is
69       // up to the user to know if the fields are used by verifying the PSV
70       // version.
71       value_type Val;
72       std::memset(&Val, 0, sizeof(value_type));
73       if (Current >= Data.end())
74         return Val;
75       memcpy(static_cast<void *>(&Val), Current, std::min(Stride, MaxStride()));
76       if (sys::IsBigEndianHost)
77         detail::swapBytes(Val);
78       return Val;
79     }
80 
81     iterator operator++() {
82       if (Current < Data.end())
83         Current += Stride;
84       return *this;
85     }
86 
87     iterator operator++(int) {
88       iterator Tmp = *this;
89       ++*this;
90       return Tmp;
91     }
92 
93     iterator operator--() {
94       if (Current > Data.begin())
95         Current -= Stride;
96       return *this;
97     }
98 
99     iterator operator--(int) {
100       iterator Tmp = *this;
101       --*this;
102       return Tmp;
103     }
104 
105     bool operator==(const iterator I) { return I.Current == Current; }
106     bool operator!=(const iterator I) { return !(*this == I); }
107   };
108 
beginViewArray109   iterator begin() const { return iterator(*this, Data.begin()); }
110 
endViewArray111   iterator end() const { return iterator(*this, Data.end()); }
112 
sizeViewArray113   size_t size() const { return Data.size() / Stride; }
114 
isEmptyViewArray115   bool isEmpty() const { return Data.empty(); }
116 };
117 
118 namespace DirectX {
119 class PSVRuntimeInfo {
120 
121   using ResourceArray = ViewArray<dxbc::PSV::v2::ResourceBindInfo>;
122   using SigElementArray = ViewArray<dxbc::PSV::v0::SignatureElement>;
123 
124   StringRef Data;
125   uint32_t Size;
126   using InfoStruct =
127       std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo,
128                    dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo,
129                    dxbc::PSV::v3::RuntimeInfo>;
130   InfoStruct BasicInfo;
131   ResourceArray Resources;
132   StringRef StringTable;
133   SmallVector<uint32_t> SemanticIndexTable;
134   SigElementArray SigInputElements;
135   SigElementArray SigOutputElements;
136   SigElementArray SigPatchOrPrimElements;
137 
138   std::array<ViewArray<uint32_t>, 4> OutputVectorMasks;
139   ViewArray<uint32_t> PatchOrPrimMasks;
140   std::array<ViewArray<uint32_t>, 4> InputOutputMap;
141   ViewArray<uint32_t> InputPatchMap;
142   ViewArray<uint32_t> PatchOutputMap;
143 
144 public:
PSVRuntimeInfo(StringRef D)145   PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
146 
147   // Parsing depends on the shader kind
148   Error parse(uint16_t ShaderKind);
149 
getSize()150   uint32_t getSize() const { return Size; }
getResourceCount()151   uint32_t getResourceCount() const { return Resources.size(); }
getResources()152   ResourceArray getResources() const { return Resources; }
153 
getVersion()154   uint32_t getVersion() const {
155     return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo)
156                ? 3
157                : (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo)     ? 2
158                   : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1
159                                                                  : 0);
160   }
161 
getResourceStride()162   uint32_t getResourceStride() const { return Resources.Stride; }
163 
getInfo()164   const InfoStruct &getInfo() const { return BasicInfo; }
165 
getInfoAs()166   template <typename T> const T *getInfoAs() const {
167     if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
168       return static_cast<const T *>(P);
169     if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value)
170       return nullptr;
171 
172     if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
173       return static_cast<const T *>(P);
174     if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)
175       return nullptr;
176 
177     if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
178       return static_cast<const T *>(P);
179     if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value)
180       return nullptr;
181 
182     if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(&BasicInfo))
183       return static_cast<const T *>(P);
184     return nullptr;
185   }
186 
getStringTable()187   StringRef getStringTable() const { return StringTable; }
getSemanticIndexTable()188   ArrayRef<uint32_t> getSemanticIndexTable() const {
189     return SemanticIndexTable;
190   }
191 
192   uint8_t getSigInputCount() const;
193   uint8_t getSigOutputCount() const;
194   uint8_t getSigPatchOrPrimCount() const;
195 
getSigInputElements()196   SigElementArray getSigInputElements() const { return SigInputElements; }
getSigOutputElements()197   SigElementArray getSigOutputElements() const { return SigOutputElements; }
getSigPatchOrPrimElements()198   SigElementArray getSigPatchOrPrimElements() const {
199     return SigPatchOrPrimElements;
200   }
201 
getOutputVectorMasks(size_t Idx)202   ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const {
203     assert(Idx < 4);
204     return OutputVectorMasks[Idx];
205   }
206 
getPatchOrPrimMasks()207   ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; }
208 
getInputOutputMap(size_t Idx)209   ViewArray<uint32_t> getInputOutputMap(size_t Idx) const {
210     assert(Idx < 4);
211     return InputOutputMap[Idx];
212   }
213 
getInputPatchMap()214   ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; }
getPatchOutputMap()215   ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; }
216 
getSigElementStride()217   uint32_t getSigElementStride() const { return SigInputElements.Stride; }
218 
usesViewID()219   bool usesViewID() const {
220     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
221       return P->UsesViewID != 0;
222     return false;
223   }
224 
getInputVectorCount()225   uint8_t getInputVectorCount() const {
226     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
227       return P->SigInputVectors;
228     return 0;
229   }
230 
getOutputVectorCounts()231   ArrayRef<uint8_t> getOutputVectorCounts() const {
232     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
233       return ArrayRef<uint8_t>(P->SigOutputVectors);
234     return ArrayRef<uint8_t>();
235   }
236 
getPatchConstOrPrimVectorCount()237   uint8_t getPatchConstOrPrimVectorCount() const {
238     if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
239       return P->GeomData.SigPatchConstOrPrimVectors;
240     return 0;
241   }
242 };
243 
244 class Signature {
245   ViewArray<dxbc::ProgramSignatureElement> Parameters;
246   uint32_t StringTableOffset;
247   StringRef StringTable;
248 
249 public:
begin()250   ViewArray<dxbc::ProgramSignatureElement>::iterator begin() const {
251     return Parameters.begin();
252   }
253 
end()254   ViewArray<dxbc::ProgramSignatureElement>::iterator end() const {
255     return Parameters.end();
256   }
257 
getName(uint32_t Offset)258   StringRef getName(uint32_t Offset) const {
259     assert(Offset >= StringTableOffset &&
260            Offset < StringTableOffset + StringTable.size() &&
261            "Offset out of range.");
262     // Name offsets are from the start of the signature data, not from the start
263     // of the string table. The header encodes the start offset of the sting
264     // table, so we convert the offset here.
265     uint32_t TableOffset = Offset - StringTableOffset;
266     return StringTable.slice(TableOffset, StringTable.find('\0', TableOffset));
267   }
268 
isEmpty()269   bool isEmpty() const { return Parameters.isEmpty(); }
270 
271   Error initialize(StringRef Part);
272 };
273 
274 } // namespace DirectX
275 
276 class DXContainer {
277 public:
278   using DXILData = std::pair<dxbc::ProgramHeader, const char *>;
279 
280 private:
281   DXContainer(MemoryBufferRef O);
282 
283   MemoryBufferRef Data;
284   dxbc::Header Header;
285   SmallVector<uint32_t, 4> PartOffsets;
286   std::optional<DXILData> DXIL;
287   std::optional<uint64_t> ShaderFeatureFlags;
288   std::optional<dxbc::ShaderHash> Hash;
289   std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
290   DirectX::Signature InputSignature;
291   DirectX::Signature OutputSignature;
292   DirectX::Signature PatchConstantSignature;
293 
294   Error parseHeader();
295   Error parsePartOffsets();
296   Error parseDXILHeader(StringRef Part);
297   Error parseShaderFeatureFlags(StringRef Part);
298   Error parseHash(StringRef Part);
299   Error parsePSVInfo(StringRef Part);
300   Error parseSignature(StringRef Part, DirectX::Signature &Array);
301   friend class PartIterator;
302 
303 public:
304   // The PartIterator is a wrapper around the iterator for the PartOffsets
305   // member of the DXContainer. It contains a refernce to the container, and the
306   // current iterator value, as well as storage for a parsed part header.
307   class PartIterator {
308     const DXContainer &Container;
309     SmallVectorImpl<uint32_t>::const_iterator OffsetIt;
310     struct PartData {
311       dxbc::PartHeader Part;
312       uint32_t Offset;
313       StringRef Data;
314     } IteratorState;
315 
316     friend class DXContainer;
317 
PartIterator(const DXContainer & C,SmallVectorImpl<uint32_t>::const_iterator It)318     PartIterator(const DXContainer &C,
319                  SmallVectorImpl<uint32_t>::const_iterator It)
320         : Container(C), OffsetIt(It) {
321       if (OffsetIt == Container.PartOffsets.end())
322         updateIteratorImpl(Container.PartOffsets.back());
323       else
324         updateIterator();
325     }
326 
327     // Updates the iterator's state data. This results in copying the part
328     // header into the iterator and handling any required byte swapping. This is
329     // called when incrementing or decrementing the iterator.
updateIterator()330     void updateIterator() {
331       if (OffsetIt != Container.PartOffsets.end())
332         updateIteratorImpl(*OffsetIt);
333     }
334 
335     // Implementation for updating the iterator state based on a specified
336     // offest.
337     void updateIteratorImpl(const uint32_t Offset);
338 
339   public:
340     PartIterator &operator++() {
341       if (OffsetIt == Container.PartOffsets.end())
342         return *this;
343       ++OffsetIt;
344       updateIterator();
345       return *this;
346     }
347 
348     PartIterator operator++(int) {
349       PartIterator Tmp = *this;
350       ++(*this);
351       return Tmp;
352     }
353 
354     bool operator==(const PartIterator &RHS) const {
355       return OffsetIt == RHS.OffsetIt;
356     }
357 
358     bool operator!=(const PartIterator &RHS) const {
359       return OffsetIt != RHS.OffsetIt;
360     }
361 
362     const PartData &operator*() { return IteratorState; }
363     const PartData *operator->() { return &IteratorState; }
364   };
365 
begin()366   PartIterator begin() const {
367     return PartIterator(*this, PartOffsets.begin());
368   }
369 
end()370   PartIterator end() const { return PartIterator(*this, PartOffsets.end()); }
371 
getData()372   StringRef getData() const { return Data.getBuffer(); }
373   static Expected<DXContainer> create(MemoryBufferRef Object);
374 
getHeader()375   const dxbc::Header &getHeader() const { return Header; }
376 
getDXIL()377   const std::optional<DXILData> &getDXIL() const { return DXIL; }
378 
getShaderFeatureFlags()379   std::optional<uint64_t> getShaderFeatureFlags() const {
380     return ShaderFeatureFlags;
381   }
382 
getShaderHash()383   std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
384 
getPSVInfo()385   const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
386     return PSVInfo;
387   };
388 
getInputSignature()389   const DirectX::Signature &getInputSignature() const { return InputSignature; }
getOutputSignature()390   const DirectX::Signature &getOutputSignature() const {
391     return OutputSignature;
392   }
getPatchConstantSignature()393   const DirectX::Signature &getPatchConstantSignature() const {
394     return PatchConstantSignature;
395   }
396 };
397 
398 } // namespace object
399 } // namespace llvm
400 
401 #endif // LLVM_OBJECT_DXCONTAINER_H
402