1 //===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the DXContainerFile class, which implements the ObjectFile
10 // interface for DXContainer files.
11 //
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef LLVM_OBJECT_DXCONTAINER_H
16 #define LLVM_OBJECT_DXCONTAINER_H
17
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Object/Error.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <variant>
32
33 namespace llvm {
34 namespace object {
35
36 namespace detail {
37 template <typename T>
swapBytes(T & value)38 std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) {
39 sys::swapByteOrder(value);
40 }
41
42 template <typename T>
swapBytes(T & value)43 std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) {
44 value.swapBytes();
45 }
46 } // namespace detail
47
48 // This class provides a view into the underlying resource array. The Resource
49 // data is little-endian encoded and may not be properly aligned to read
50 // directly from. The dereference operator creates a copy of the data and byte
51 // swaps it as appropriate.
52 template <typename T> struct ViewArray {
53 StringRef Data;
54 uint32_t Stride = sizeof(T); // size of each element in the list.
55
56 ViewArray() = default;
ViewArrayViewArray57 ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {}
58
59 using value_type = T;
MaxStrideViewArray60 static constexpr uint32_t MaxStride() {
61 return static_cast<uint32_t>(sizeof(value_type));
62 }
63
64 struct iterator {
65 StringRef Data;
66 uint32_t Stride; // size of each element in the list.
67 const char *Current;
68
iteratorViewArray::iterator69 iterator(const ViewArray &A, const char *C)
70 : Data(A.Data), Stride(A.Stride), Current(C) {}
71 iterator(const iterator &) = default;
72
73 value_type operator*() {
74 // Explicitly zero the structure so that unused fields are zeroed. It is
75 // up to the user to know if the fields are used by verifying the PSV
76 // version.
77 value_type Val;
78 std::memset(&Val, 0, sizeof(value_type));
79 if (Current >= Data.end())
80 return Val;
81 memcpy(static_cast<void *>(&Val), Current, std::min(Stride, MaxStride()));
82 if (sys::IsBigEndianHost)
83 detail::swapBytes(Val);
84 return Val;
85 }
86
87 iterator operator++() {
88 if (Current < Data.end())
89 Current += Stride;
90 return *this;
91 }
92
93 iterator operator++(int) {
94 iterator Tmp = *this;
95 ++*this;
96 return Tmp;
97 }
98
99 iterator operator--() {
100 if (Current > Data.begin())
101 Current -= Stride;
102 return *this;
103 }
104
105 iterator operator--(int) {
106 iterator Tmp = *this;
107 --*this;
108 return Tmp;
109 }
110
111 bool operator==(const iterator I) { return I.Current == Current; }
112 bool operator!=(const iterator I) { return !(*this == I); }
113 };
114
beginViewArray115 iterator begin() const { return iterator(*this, Data.begin()); }
116
endViewArray117 iterator end() const { return iterator(*this, Data.end()); }
118
sizeViewArray119 size_t size() const { return Data.size() / Stride; }
120
isEmptyViewArray121 bool isEmpty() const { return Data.empty(); }
122 };
123
124 namespace DirectX {
125 struct RootParameterView {
126 const dxbc::RTS0::v1::RootParameterHeader &Header;
127 StringRef ParamData;
128
RootParameterViewRootParameterView129 RootParameterView(const dxbc::RTS0::v1::RootParameterHeader &H, StringRef P)
130 : Header(H), ParamData(P) {}
131
readParameterRootParameterView132 template <typename T> Expected<T> readParameter() {
133 T Struct;
134 if (sizeof(T) != ParamData.size())
135 return make_error<GenericBinaryError>(
136 "Reading structure out of file bounds", object_error::parse_failed);
137
138 memcpy(&Struct, ParamData.data(), sizeof(T));
139 // DXContainer is always little endian
140 if (sys::IsBigEndianHost)
141 Struct.swapBytes();
142 return Struct;
143 }
144 };
145
146 struct RootConstantView : RootParameterView {
classofRootConstantView147 static bool classof(const RootParameterView *V) {
148 return V->Header.ParameterType ==
149 (uint32_t)dxbc::RootParameterType::Constants32Bit;
150 }
151
readRootConstantView152 llvm::Expected<dxbc::RTS0::v1::RootConstants> read() {
153 return readParameter<dxbc::RTS0::v1::RootConstants>();
154 }
155 };
156
157 struct RootDescriptorView : RootParameterView {
classofRootDescriptorView158 static bool classof(const RootParameterView *V) {
159 return (V->Header.ParameterType ==
160 llvm::to_underlying(dxbc::RootParameterType::CBV) ||
161 V->Header.ParameterType ==
162 llvm::to_underlying(dxbc::RootParameterType::SRV) ||
163 V->Header.ParameterType ==
164 llvm::to_underlying(dxbc::RootParameterType::UAV));
165 }
166
readRootDescriptorView167 llvm::Expected<dxbc::RTS0::v2::RootDescriptor> read(uint32_t Version) {
168 if (Version == 1) {
169 auto Descriptor = readParameter<dxbc::RTS0::v1::RootDescriptor>();
170 if (Error E = Descriptor.takeError())
171 return E;
172 return dxbc::RTS0::v2::RootDescriptor(*Descriptor);
173 }
174 if (Version != 2)
175 return make_error<GenericBinaryError>("Invalid Root Signature version: " +
176 Twine(Version),
177 object_error::parse_failed);
178 return readParameter<dxbc::RTS0::v2::RootDescriptor>();
179 }
180 };
181 template <typename T> struct DescriptorTable {
182 uint32_t NumRanges;
183 uint32_t RangesOffset;
184 ViewArray<T> Ranges;
185
beginDescriptorTable186 typename ViewArray<T>::iterator begin() const { return Ranges.begin(); }
187
endDescriptorTable188 typename ViewArray<T>::iterator end() const { return Ranges.end(); }
189 };
190
191 struct DescriptorTableView : RootParameterView {
classofDescriptorTableView192 static bool classof(const RootParameterView *V) {
193 return (V->Header.ParameterType ==
194 llvm::to_underlying(dxbc::RootParameterType::DescriptorTable));
195 }
196
197 // Define a type alias to access the template parameter from inside classof
readDescriptorTableView198 template <typename T> llvm::Expected<DescriptorTable<T>> read() {
199 const char *Current = ParamData.begin();
200 DescriptorTable<T> Table;
201
202 Table.NumRanges =
203 support::endian::read<uint32_t, llvm::endianness::little>(Current);
204 Current += sizeof(uint32_t);
205
206 Table.RangesOffset =
207 support::endian::read<uint32_t, llvm::endianness::little>(Current);
208 Current += sizeof(uint32_t);
209
210 Table.Ranges.Data = ParamData.substr(2 * sizeof(uint32_t),
211 Table.NumRanges * Table.Ranges.Stride);
212 return Table;
213 }
214 };
215
parseFailed(const Twine & Msg)216 static Error parseFailed(const Twine &Msg) {
217 return make_error<GenericBinaryError>(Msg.str(), object_error::parse_failed);
218 }
219
220 class RootSignature {
221 private:
222 uint32_t Version;
223 uint32_t NumParameters;
224 uint32_t RootParametersOffset;
225 uint32_t NumStaticSamplers;
226 uint32_t StaticSamplersOffset;
227 uint32_t Flags;
228 ViewArray<dxbc::RTS0::v1::RootParameterHeader> ParametersHeaders;
229 StringRef PartData;
230 ViewArray<dxbc::RTS0::v1::StaticSampler> StaticSamplers;
231
232 using param_header_iterator =
233 ViewArray<dxbc::RTS0::v1::RootParameterHeader>::iterator;
234 using samplers_iterator = ViewArray<dxbc::RTS0::v1::StaticSampler>::iterator;
235
236 public:
RootSignature(StringRef PD)237 RootSignature(StringRef PD) : PartData(PD) {}
238
239 LLVM_ABI Error parse();
getVersion()240 uint32_t getVersion() const { return Version; }
getNumParameters()241 uint32_t getNumParameters() const { return NumParameters; }
getRootParametersOffset()242 uint32_t getRootParametersOffset() const { return RootParametersOffset; }
getNumStaticSamplers()243 uint32_t getNumStaticSamplers() const { return NumStaticSamplers; }
getStaticSamplersOffset()244 uint32_t getStaticSamplersOffset() const { return StaticSamplersOffset; }
getNumRootParameters()245 uint32_t getNumRootParameters() const { return ParametersHeaders.size(); }
param_headers()246 llvm::iterator_range<param_header_iterator> param_headers() const {
247 return llvm::make_range(ParametersHeaders.begin(), ParametersHeaders.end());
248 }
samplers()249 llvm::iterator_range<samplers_iterator> samplers() const {
250 return llvm::make_range(StaticSamplers.begin(), StaticSamplers.end());
251 }
getFlags()252 uint32_t getFlags() const { return Flags; }
253
254 llvm::Expected<RootParameterView>
getParameter(const dxbc::RTS0::v1::RootParameterHeader & Header)255 getParameter(const dxbc::RTS0::v1::RootParameterHeader &Header) const {
256 size_t DataSize;
257 size_t EndOfSectionByte = getNumStaticSamplers() == 0
258 ? PartData.size()
259 : getStaticSamplersOffset();
260
261 if (!dxbc::isValidParameterType(Header.ParameterType))
262 return parseFailed("invalid parameter type");
263
264 switch (static_cast<dxbc::RootParameterType>(Header.ParameterType)) {
265 case dxbc::RootParameterType::Constants32Bit:
266 DataSize = sizeof(dxbc::RTS0::v1::RootConstants);
267 break;
268 case dxbc::RootParameterType::CBV:
269 case dxbc::RootParameterType::SRV:
270 case dxbc::RootParameterType::UAV:
271 if (Version == 1)
272 DataSize = sizeof(dxbc::RTS0::v1::RootDescriptor);
273 else
274 DataSize = sizeof(dxbc::RTS0::v2::RootDescriptor);
275 break;
276 case dxbc::RootParameterType::DescriptorTable:
277 if (Header.ParameterOffset + sizeof(uint32_t) > EndOfSectionByte)
278 return parseFailed("Reading structure out of file bounds");
279
280 uint32_t NumRanges =
281 support::endian::read<uint32_t, llvm::endianness::little>(
282 PartData.begin() + Header.ParameterOffset);
283 if (Version == 1)
284 DataSize = sizeof(dxbc::RTS0::v1::DescriptorRange) * NumRanges;
285 else
286 DataSize = sizeof(dxbc::RTS0::v2::DescriptorRange) * NumRanges;
287
288 // 4 bytes for the number of ranges in table and
289 // 4 bytes for the ranges offset
290 DataSize += 2 * sizeof(uint32_t);
291 break;
292 }
293 if (Header.ParameterOffset + DataSize > EndOfSectionByte)
294 return parseFailed("Reading structure out of file bounds");
295
296 StringRef Buff = PartData.substr(Header.ParameterOffset, DataSize);
297 RootParameterView View = RootParameterView(Header, Buff);
298 return View;
299 }
300 };
301
302 class PSVRuntimeInfo {
303
304 using ResourceArray = ViewArray<dxbc::PSV::v2::ResourceBindInfo>;
305 using SigElementArray = ViewArray<dxbc::PSV::v0::SignatureElement>;
306
307 StringRef Data;
308 uint32_t Size;
309 using InfoStruct =
310 std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo,
311 dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo,
312 dxbc::PSV::v3::RuntimeInfo>;
313 InfoStruct BasicInfo;
314 ResourceArray Resources;
315 StringRef StringTable;
316 SmallVector<uint32_t> SemanticIndexTable;
317 SigElementArray SigInputElements;
318 SigElementArray SigOutputElements;
319 SigElementArray SigPatchOrPrimElements;
320
321 std::array<ViewArray<uint32_t>, 4> OutputVectorMasks;
322 ViewArray<uint32_t> PatchOrPrimMasks;
323 std::array<ViewArray<uint32_t>, 4> InputOutputMap;
324 ViewArray<uint32_t> InputPatchMap;
325 ViewArray<uint32_t> PatchOutputMap;
326
327 public:
PSVRuntimeInfo(StringRef D)328 PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
329
330 // Parsing depends on the shader kind
331 LLVM_ABI Error parse(uint16_t ShaderKind);
332
getSize()333 uint32_t getSize() const { return Size; }
getResourceCount()334 uint32_t getResourceCount() const { return Resources.size(); }
getResources()335 ResourceArray getResources() const { return Resources; }
336
getVersion()337 uint32_t getVersion() const {
338 return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo)
339 ? 3
340 : (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo) ? 2
341 : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1
342 : 0);
343 }
344
getResourceStride()345 uint32_t getResourceStride() const { return Resources.Stride; }
346
getInfo()347 const InfoStruct &getInfo() const { return BasicInfo; }
348
getInfoAs()349 template <typename T> const T *getInfoAs() const {
350 if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
351 return static_cast<const T *>(P);
352 if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value)
353 return nullptr;
354
355 if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
356 return static_cast<const T *>(P);
357 if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)
358 return nullptr;
359
360 if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
361 return static_cast<const T *>(P);
362 if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value)
363 return nullptr;
364
365 if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(&BasicInfo))
366 return static_cast<const T *>(P);
367 return nullptr;
368 }
369
getStringTable()370 StringRef getStringTable() const { return StringTable; }
getSemanticIndexTable()371 ArrayRef<uint32_t> getSemanticIndexTable() const {
372 return SemanticIndexTable;
373 }
374
375 LLVM_ABI uint8_t getSigInputCount() const;
376 LLVM_ABI uint8_t getSigOutputCount() const;
377 LLVM_ABI uint8_t getSigPatchOrPrimCount() const;
378
getSigInputElements()379 SigElementArray getSigInputElements() const { return SigInputElements; }
getSigOutputElements()380 SigElementArray getSigOutputElements() const { return SigOutputElements; }
getSigPatchOrPrimElements()381 SigElementArray getSigPatchOrPrimElements() const {
382 return SigPatchOrPrimElements;
383 }
384
getOutputVectorMasks(size_t Idx)385 ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const {
386 assert(Idx < 4);
387 return OutputVectorMasks[Idx];
388 }
389
getPatchOrPrimMasks()390 ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; }
391
getInputOutputMap(size_t Idx)392 ViewArray<uint32_t> getInputOutputMap(size_t Idx) const {
393 assert(Idx < 4);
394 return InputOutputMap[Idx];
395 }
396
getInputPatchMap()397 ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; }
getPatchOutputMap()398 ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; }
399
getSigElementStride()400 uint32_t getSigElementStride() const { return SigInputElements.Stride; }
401
usesViewID()402 bool usesViewID() const {
403 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
404 return P->UsesViewID != 0;
405 return false;
406 }
407
getInputVectorCount()408 uint8_t getInputVectorCount() const {
409 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
410 return P->SigInputVectors;
411 return 0;
412 }
413
getOutputVectorCounts()414 ArrayRef<uint8_t> getOutputVectorCounts() const {
415 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
416 return ArrayRef<uint8_t>(P->SigOutputVectors);
417 return ArrayRef<uint8_t>();
418 }
419
getPatchConstOrPrimVectorCount()420 uint8_t getPatchConstOrPrimVectorCount() const {
421 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
422 return P->GeomData.SigPatchConstOrPrimVectors;
423 return 0;
424 }
425 };
426
427 class Signature {
428 ViewArray<dxbc::ProgramSignatureElement> Parameters;
429 uint32_t StringTableOffset;
430 StringRef StringTable;
431
432 public:
begin()433 ViewArray<dxbc::ProgramSignatureElement>::iterator begin() const {
434 return Parameters.begin();
435 }
436
end()437 ViewArray<dxbc::ProgramSignatureElement>::iterator end() const {
438 return Parameters.end();
439 }
440
getName(uint32_t Offset)441 StringRef getName(uint32_t Offset) const {
442 assert(Offset >= StringTableOffset &&
443 Offset < StringTableOffset + StringTable.size() &&
444 "Offset out of range.");
445 // Name offsets are from the start of the signature data, not from the start
446 // of the string table. The header encodes the start offset of the sting
447 // table, so we convert the offset here.
448 uint32_t TableOffset = Offset - StringTableOffset;
449 return StringTable.slice(TableOffset, StringTable.find('\0', TableOffset));
450 }
451
isEmpty()452 bool isEmpty() const { return Parameters.isEmpty(); }
453
454 LLVM_ABI Error initialize(StringRef Part);
455 };
456
457 } // namespace DirectX
458
459 class DXContainer {
460 public:
461 using DXILData = std::pair<dxbc::ProgramHeader, const char *>;
462
463 private:
464 DXContainer(MemoryBufferRef O);
465
466 MemoryBufferRef Data;
467 dxbc::Header Header;
468 SmallVector<uint32_t, 4> PartOffsets;
469 std::optional<DXILData> DXIL;
470 std::optional<uint64_t> ShaderFeatureFlags;
471 std::optional<dxbc::ShaderHash> Hash;
472 std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
473 std::optional<DirectX::RootSignature> RootSignature;
474 DirectX::Signature InputSignature;
475 DirectX::Signature OutputSignature;
476 DirectX::Signature PatchConstantSignature;
477
478 Error parseHeader();
479 Error parsePartOffsets();
480 Error parseDXILHeader(StringRef Part);
481 Error parseShaderFeatureFlags(StringRef Part);
482 Error parseHash(StringRef Part);
483 Error parseRootSignature(StringRef Part);
484 Error parsePSVInfo(StringRef Part);
485 Error parseSignature(StringRef Part, DirectX::Signature &Array);
486 friend class PartIterator;
487
488 public:
489 // The PartIterator is a wrapper around the iterator for the PartOffsets
490 // member of the DXContainer. It contains a refernce to the container, and the
491 // current iterator value, as well as storage for a parsed part header.
492 class PartIterator {
493 const DXContainer &Container;
494 SmallVectorImpl<uint32_t>::const_iterator OffsetIt;
495 struct PartData {
496 dxbc::PartHeader Part;
497 uint32_t Offset;
498 StringRef Data;
499 } IteratorState;
500
501 friend class DXContainer;
502
PartIterator(const DXContainer & C,SmallVectorImpl<uint32_t>::const_iterator It)503 PartIterator(const DXContainer &C,
504 SmallVectorImpl<uint32_t>::const_iterator It)
505 : Container(C), OffsetIt(It) {
506 if (OffsetIt == Container.PartOffsets.end())
507 updateIteratorImpl(Container.PartOffsets.back());
508 else
509 updateIterator();
510 }
511
512 // Updates the iterator's state data. This results in copying the part
513 // header into the iterator and handling any required byte swapping. This is
514 // called when incrementing or decrementing the iterator.
updateIterator()515 void updateIterator() {
516 if (OffsetIt != Container.PartOffsets.end())
517 updateIteratorImpl(*OffsetIt);
518 }
519
520 // Implementation for updating the iterator state based on a specified
521 // offest.
522 LLVM_ABI void updateIteratorImpl(const uint32_t Offset);
523
524 public:
525 PartIterator &operator++() {
526 if (OffsetIt == Container.PartOffsets.end())
527 return *this;
528 ++OffsetIt;
529 updateIterator();
530 return *this;
531 }
532
533 PartIterator operator++(int) {
534 PartIterator Tmp = *this;
535 ++(*this);
536 return Tmp;
537 }
538
539 bool operator==(const PartIterator &RHS) const {
540 return OffsetIt == RHS.OffsetIt;
541 }
542
543 bool operator!=(const PartIterator &RHS) const {
544 return OffsetIt != RHS.OffsetIt;
545 }
546
547 const PartData &operator*() { return IteratorState; }
548 const PartData *operator->() { return &IteratorState; }
549 };
550
begin()551 PartIterator begin() const {
552 return PartIterator(*this, PartOffsets.begin());
553 }
554
end()555 PartIterator end() const { return PartIterator(*this, PartOffsets.end()); }
556
getData()557 StringRef getData() const { return Data.getBuffer(); }
558 LLVM_ABI static Expected<DXContainer> create(MemoryBufferRef Object);
559
getHeader()560 const dxbc::Header &getHeader() const { return Header; }
561
getDXIL()562 const std::optional<DXILData> &getDXIL() const { return DXIL; }
563
getShaderFeatureFlags()564 std::optional<uint64_t> getShaderFeatureFlags() const {
565 return ShaderFeatureFlags;
566 }
567
getShaderHash()568 std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
569
getRootSignature()570 std::optional<DirectX::RootSignature> getRootSignature() const {
571 return RootSignature;
572 }
573
getPSVInfo()574 const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
575 return PSVInfo;
576 };
577
getInputSignature()578 const DirectX::Signature &getInputSignature() const { return InputSignature; }
getOutputSignature()579 const DirectX::Signature &getOutputSignature() const {
580 return OutputSignature;
581 }
getPatchConstantSignature()582 const DirectX::Signature &getPatchConstantSignature() const {
583 return PatchConstantSignature;
584 }
585 };
586
587 } // namespace object
588 } // namespace llvm
589
590 #endif // LLVM_OBJECT_DXCONTAINER_H
591