xref: /freebsd/contrib/llvm-project/llvm/lib/Object/DXContainer.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXContainer.cpp - DXContainer object file implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Object/DXContainer.h"
10 #include "llvm/BinaryFormat/DXContainer.h"
11 #include "llvm/Object/Error.h"
12 #include "llvm/Support/Endian.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 using namespace llvm;
16 using namespace llvm::object;
17 
parseFailed(const Twine & Msg)18 static Error parseFailed(const Twine &Msg) {
19   return make_error<GenericBinaryError>(Msg.str(), object_error::parse_failed);
20 }
21 
22 template <typename T>
readStruct(StringRef Buffer,const char * Src,T & Struct)23 static Error readStruct(StringRef Buffer, const char *Src, T &Struct) {
24   // Don't read before the beginning or past the end of the file
25   if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end())
26     return parseFailed("Reading structure out of file bounds");
27 
28   memcpy(&Struct, Src, sizeof(T));
29   // DXContainer is always little endian
30   if (sys::IsBigEndianHost)
31     Struct.swapBytes();
32   return Error::success();
33 }
34 
35 template <typename T>
readInteger(StringRef Buffer,const char * Src,T & Val,Twine Str="structure")36 static Error readInteger(StringRef Buffer, const char *Src, T &Val,
37                          Twine Str = "structure") {
38   static_assert(std::is_integral_v<T>,
39                 "Cannot call readInteger on non-integral type.");
40   // Don't read before the beginning or past the end of the file
41   if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end())
42     return parseFailed(Twine("Reading ") + Str + " out of file bounds");
43 
44   // The DXContainer offset table is comprised of uint32_t values but not padded
45   // to a 64-bit boundary. So Parts may start unaligned if there is an odd
46   // number of parts and part data itself is not required to be padded.
47   if (reinterpret_cast<uintptr_t>(Src) % alignof(T) != 0)
48     memcpy(reinterpret_cast<char *>(&Val), Src, sizeof(T));
49   else
50     Val = *reinterpret_cast<const T *>(Src);
51   // DXContainer is always little endian
52   if (sys::IsBigEndianHost)
53     sys::swapByteOrder(Val);
54   return Error::success();
55 }
56 
DXContainer(MemoryBufferRef O)57 DXContainer::DXContainer(MemoryBufferRef O) : Data(O) {}
58 
parseHeader()59 Error DXContainer::parseHeader() {
60   return readStruct(Data.getBuffer(), Data.getBuffer().data(), Header);
61 }
62 
parseDXILHeader(StringRef Part)63 Error DXContainer::parseDXILHeader(StringRef Part) {
64   if (DXIL)
65     return parseFailed("More than one DXIL part is present in the file");
66   const char *Current = Part.begin();
67   dxbc::ProgramHeader Header;
68   if (Error Err = readStruct(Part, Current, Header))
69     return Err;
70   Current += offsetof(dxbc::ProgramHeader, Bitcode) + Header.Bitcode.Offset;
71   DXIL.emplace(std::make_pair(Header, Current));
72   return Error::success();
73 }
74 
parseShaderFeatureFlags(StringRef Part)75 Error DXContainer::parseShaderFeatureFlags(StringRef Part) {
76   if (ShaderFeatureFlags)
77     return parseFailed("More than one SFI0 part is present in the file");
78   uint64_t FlagValue = 0;
79   if (Error Err = readInteger(Part, Part.begin(), FlagValue))
80     return Err;
81   ShaderFeatureFlags = FlagValue;
82   return Error::success();
83 }
84 
parseHash(StringRef Part)85 Error DXContainer::parseHash(StringRef Part) {
86   if (Hash)
87     return parseFailed("More than one HASH part is present in the file");
88   dxbc::ShaderHash ReadHash;
89   if (Error Err = readStruct(Part, Part.begin(), ReadHash))
90     return Err;
91   Hash = ReadHash;
92   return Error::success();
93 }
94 
parseRootSignature(StringRef Part)95 Error DXContainer::parseRootSignature(StringRef Part) {
96   if (RootSignature)
97     return parseFailed("More than one RTS0 part is present in the file");
98   RootSignature = DirectX::RootSignature(Part);
99   if (Error Err = RootSignature->parse())
100     return Err;
101   return Error::success();
102 }
103 
parsePSVInfo(StringRef Part)104 Error DXContainer::parsePSVInfo(StringRef Part) {
105   if (PSVInfo)
106     return parseFailed("More than one PSV0 part is present in the file");
107   PSVInfo = DirectX::PSVRuntimeInfo(Part);
108   // Parsing the PSVRuntime info occurs late because we need to read data from
109   // other parts first.
110   return Error::success();
111 }
112 
initialize(StringRef Part)113 Error DirectX::Signature::initialize(StringRef Part) {
114   dxbc::ProgramSignatureHeader SigHeader;
115   if (Error Err = readStruct(Part, Part.begin(), SigHeader))
116     return Err;
117   size_t Size = sizeof(dxbc::ProgramSignatureElement) * SigHeader.ParamCount;
118 
119   if (Part.size() < Size + SigHeader.FirstParamOffset)
120     return parseFailed("Signature parameters extend beyond the part boundary");
121 
122   Parameters.Data = Part.substr(SigHeader.FirstParamOffset, Size);
123 
124   StringTableOffset = SigHeader.FirstParamOffset + static_cast<uint32_t>(Size);
125   StringTable = Part.substr(SigHeader.FirstParamOffset + Size);
126 
127   for (const auto &Param : Parameters) {
128     if (Param.NameOffset < StringTableOffset)
129       return parseFailed("Invalid parameter name offset: name starts before "
130                          "the first name offset");
131     if (Param.NameOffset - StringTableOffset > StringTable.size())
132       return parseFailed("Invalid parameter name offset: name starts after the "
133                          "end of the part data");
134   }
135   return Error::success();
136 }
137 
parsePartOffsets()138 Error DXContainer::parsePartOffsets() {
139   uint32_t LastOffset =
140       sizeof(dxbc::Header) + (Header.PartCount * sizeof(uint32_t));
141   const char *Current = Data.getBuffer().data() + sizeof(dxbc::Header);
142   for (uint32_t Part = 0; Part < Header.PartCount; ++Part) {
143     uint32_t PartOffset;
144     if (Error Err = readInteger(Data.getBuffer(), Current, PartOffset))
145       return Err;
146     if (PartOffset < LastOffset)
147       return parseFailed(
148           formatv(
149               "Part offset for part {0} begins before the previous part ends",
150               Part)
151               .str());
152     Current += sizeof(uint32_t);
153     if (PartOffset >= Data.getBufferSize())
154       return parseFailed("Part offset points beyond boundary of the file");
155     // To prevent overflow when reading the part name, we subtract the part name
156     // size from the buffer size, rather than adding to the offset. Since the
157     // file header is larger than the part header we can't reach this code
158     // unless the buffer is at least as large as a part header, so this
159     // subtraction can't underflow.
160     if (PartOffset >= Data.getBufferSize() - sizeof(dxbc::PartHeader::Name))
161       return parseFailed("File not large enough to read part name");
162     PartOffsets.push_back(PartOffset);
163 
164     dxbc::PartType PT =
165         dxbc::parsePartType(Data.getBuffer().substr(PartOffset, 4));
166     uint32_t PartDataStart = PartOffset + sizeof(dxbc::PartHeader);
167     uint32_t PartSize;
168     if (Error Err = readInteger(Data.getBuffer(),
169                                 Data.getBufferStart() + PartOffset + 4,
170                                 PartSize, "part size"))
171       return Err;
172     StringRef PartData = Data.getBuffer().substr(PartDataStart, PartSize);
173     LastOffset = PartOffset + PartSize;
174     switch (PT) {
175     case dxbc::PartType::DXIL:
176       if (Error Err = parseDXILHeader(PartData))
177         return Err;
178       break;
179     case dxbc::PartType::SFI0:
180       if (Error Err = parseShaderFeatureFlags(PartData))
181         return Err;
182       break;
183     case dxbc::PartType::HASH:
184       if (Error Err = parseHash(PartData))
185         return Err;
186       break;
187     case dxbc::PartType::PSV0:
188       if (Error Err = parsePSVInfo(PartData))
189         return Err;
190       break;
191     case dxbc::PartType::ISG1:
192       if (Error Err = InputSignature.initialize(PartData))
193         return Err;
194       break;
195     case dxbc::PartType::OSG1:
196       if (Error Err = OutputSignature.initialize(PartData))
197         return Err;
198       break;
199     case dxbc::PartType::PSG1:
200       if (Error Err = PatchConstantSignature.initialize(PartData))
201         return Err;
202       break;
203     case dxbc::PartType::Unknown:
204       break;
205     case dxbc::PartType::RTS0:
206       if (Error Err = parseRootSignature(PartData))
207         return Err;
208       break;
209     }
210   }
211 
212   // Fully parsing the PSVInfo requires knowing the shader kind which we read
213   // out of the program header in the DXIL part.
214   if (PSVInfo) {
215     if (!DXIL)
216       return parseFailed("Cannot fully parse pipeline state validation "
217                          "information without DXIL part.");
218     if (Error Err = PSVInfo->parse(DXIL->first.ShaderKind))
219       return Err;
220   }
221   return Error::success();
222 }
223 
create(MemoryBufferRef Object)224 Expected<DXContainer> DXContainer::create(MemoryBufferRef Object) {
225   DXContainer Container(Object);
226   if (Error Err = Container.parseHeader())
227     return std::move(Err);
228   if (Error Err = Container.parsePartOffsets())
229     return std::move(Err);
230   return Container;
231 }
232 
updateIteratorImpl(const uint32_t Offset)233 void DXContainer::PartIterator::updateIteratorImpl(const uint32_t Offset) {
234   StringRef Buffer = Container.Data.getBuffer();
235   const char *Current = Buffer.data() + Offset;
236   // Offsets are validated during parsing, so all offsets in the container are
237   // valid and contain enough readable data to read a header.
238   cantFail(readStruct(Buffer, Current, IteratorState.Part));
239   IteratorState.Data =
240       StringRef(Current + sizeof(dxbc::PartHeader), IteratorState.Part.Size);
241   IteratorState.Offset = Offset;
242 }
243 
parse()244 Error DirectX::RootSignature::parse() {
245   const char *Current = PartData.begin();
246 
247   // Root Signature headers expects 6 integers to be present.
248   if (PartData.size() < 6 * sizeof(uint32_t))
249     return parseFailed(
250         "Invalid root signature, insufficient space for header.");
251 
252   Version = support::endian::read<uint32_t, llvm::endianness::little>(Current);
253   Current += sizeof(uint32_t);
254 
255   NumParameters =
256       support::endian::read<uint32_t, llvm::endianness::little>(Current);
257   Current += sizeof(uint32_t);
258 
259   RootParametersOffset =
260       support::endian::read<uint32_t, llvm::endianness::little>(Current);
261   Current += sizeof(uint32_t);
262 
263   NumStaticSamplers =
264       support::endian::read<uint32_t, llvm::endianness::little>(Current);
265   Current += sizeof(uint32_t);
266 
267   StaticSamplersOffset =
268       support::endian::read<uint32_t, llvm::endianness::little>(Current);
269   Current += sizeof(uint32_t);
270 
271   Flags = support::endian::read<uint32_t, llvm::endianness::little>(Current);
272   Current += sizeof(uint32_t);
273 
274   ParametersHeaders.Data = PartData.substr(
275       RootParametersOffset,
276       NumParameters * sizeof(dxbc::RTS0::v1::RootParameterHeader));
277 
278   StaticSamplers.Stride = sizeof(dxbc::RTS0::v1::StaticSampler);
279   StaticSamplers.Data = PartData.substr(
280       StaticSamplersOffset,
281       NumStaticSamplers * sizeof(dxbc::RTS0::v1::StaticSampler));
282 
283   return Error::success();
284 }
285 
parse(uint16_t ShaderKind)286 Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
287   Triple::EnvironmentType ShaderStage = dxbc::getShaderStage(ShaderKind);
288 
289   const char *Current = Data.begin();
290   if (Error Err = readInteger(Data, Current, Size))
291     return Err;
292   Current += sizeof(uint32_t);
293 
294   StringRef PSVInfoData = Data.substr(sizeof(uint32_t), Size);
295 
296   if (PSVInfoData.size() < Size)
297     return parseFailed(
298         "Pipeline state data extends beyond the bounds of the part");
299 
300   using namespace dxbc::PSV;
301 
302   const uint32_t PSVVersion = getVersion();
303 
304   // Detect the PSVVersion by looking at the size field.
305   if (PSVVersion == 3) {
306     v3::RuntimeInfo Info;
307     if (Error Err = readStruct(PSVInfoData, Current, Info))
308       return Err;
309     if (sys::IsBigEndianHost)
310       Info.swapBytes(ShaderStage);
311     BasicInfo = Info;
312   } else if (PSVVersion == 2) {
313     v2::RuntimeInfo Info;
314     if (Error Err = readStruct(PSVInfoData, Current, Info))
315       return Err;
316     if (sys::IsBigEndianHost)
317       Info.swapBytes(ShaderStage);
318     BasicInfo = Info;
319   } else if (PSVVersion == 1) {
320     v1::RuntimeInfo Info;
321     if (Error Err = readStruct(PSVInfoData, Current, Info))
322       return Err;
323     if (sys::IsBigEndianHost)
324       Info.swapBytes(ShaderStage);
325     BasicInfo = Info;
326   } else if (PSVVersion == 0) {
327     v0::RuntimeInfo Info;
328     if (Error Err = readStruct(PSVInfoData, Current, Info))
329       return Err;
330     if (sys::IsBigEndianHost)
331       Info.swapBytes(ShaderStage);
332     BasicInfo = Info;
333   } else
334     return parseFailed(
335         "Cannot read PSV Runtime Info, unsupported PSV version.");
336 
337   Current += Size;
338 
339   uint32_t ResourceCount = 0;
340   if (Error Err = readInteger(Data, Current, ResourceCount))
341     return Err;
342   Current += sizeof(uint32_t);
343 
344   if (ResourceCount > 0) {
345     if (Error Err = readInteger(Data, Current, Resources.Stride))
346       return Err;
347     Current += sizeof(uint32_t);
348 
349     size_t BindingDataSize = Resources.Stride * ResourceCount;
350     Resources.Data = Data.substr(Current - Data.begin(), BindingDataSize);
351 
352     if (Resources.Data.size() < BindingDataSize)
353       return parseFailed(
354           "Resource binding data extends beyond the bounds of the part");
355 
356     Current += BindingDataSize;
357   } else
358     Resources.Stride = sizeof(v2::ResourceBindInfo);
359 
360   // PSV version 0 ends after the resource bindings.
361   if (PSVVersion == 0)
362     return Error::success();
363 
364   // String table starts at a 4-byte offset.
365   Current = reinterpret_cast<const char *>(
366       alignTo<4>(reinterpret_cast<uintptr_t>(Current)));
367 
368   uint32_t StringTableSize = 0;
369   if (Error Err = readInteger(Data, Current, StringTableSize))
370     return Err;
371   if (StringTableSize % 4 != 0)
372     return parseFailed("String table misaligned");
373   Current += sizeof(uint32_t);
374   StringTable = StringRef(Current, StringTableSize);
375 
376   Current += StringTableSize;
377 
378   uint32_t SemanticIndexTableSize = 0;
379   if (Error Err = readInteger(Data, Current, SemanticIndexTableSize))
380     return Err;
381   Current += sizeof(uint32_t);
382 
383   SemanticIndexTable.reserve(SemanticIndexTableSize);
384   for (uint32_t I = 0; I < SemanticIndexTableSize; ++I) {
385     uint32_t Index = 0;
386     if (Error Err = readInteger(Data, Current, Index))
387       return Err;
388     Current += sizeof(uint32_t);
389     SemanticIndexTable.push_back(Index);
390   }
391 
392   uint8_t InputCount = getSigInputCount();
393   uint8_t OutputCount = getSigOutputCount();
394   uint8_t PatchOrPrimCount = getSigPatchOrPrimCount();
395 
396   uint32_t ElementCount = InputCount + OutputCount + PatchOrPrimCount;
397 
398   if (ElementCount > 0) {
399     if (Error Err = readInteger(Data, Current, SigInputElements.Stride))
400       return Err;
401     Current += sizeof(uint32_t);
402     // Assign the stride to all the arrays.
403     SigOutputElements.Stride = SigPatchOrPrimElements.Stride =
404         SigInputElements.Stride;
405 
406     if (Data.end() - Current <
407         (ptrdiff_t)(ElementCount * SigInputElements.Stride))
408       return parseFailed(
409           "Signature elements extend beyond the size of the part");
410 
411     size_t InputSize = SigInputElements.Stride * InputCount;
412     SigInputElements.Data = Data.substr(Current - Data.begin(), InputSize);
413     Current += InputSize;
414 
415     size_t OutputSize = SigOutputElements.Stride * OutputCount;
416     SigOutputElements.Data = Data.substr(Current - Data.begin(), OutputSize);
417     Current += OutputSize;
418 
419     size_t PSize = SigPatchOrPrimElements.Stride * PatchOrPrimCount;
420     SigPatchOrPrimElements.Data = Data.substr(Current - Data.begin(), PSize);
421     Current += PSize;
422   }
423 
424   ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts();
425   uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount();
426   uint8_t InputVectorCount = getInputVectorCount();
427 
428   auto maskDwordSize = [](uint8_t Vector) {
429     return (static_cast<uint32_t>(Vector) + 7) >> 3;
430   };
431 
432   auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) {
433     return maskDwordSize(Y) * X * 4;
434   };
435 
436   if (usesViewID()) {
437     for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) {
438       // The vector mask is one bit per component and 4 components per vector.
439       // We can compute the number of dwords required by rounding up to the next
440       // multiple of 8.
441       uint32_t NumDwords =
442           maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I]));
443       size_t NumBytes = NumDwords * sizeof(uint32_t);
444       OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes);
445       Current += NumBytes;
446     }
447 
448     if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) {
449       uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount);
450       size_t NumBytes = NumDwords * sizeof(uint32_t);
451       PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes);
452       Current += NumBytes;
453     }
454   }
455 
456   // Input/Output mapping table
457   for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) {
458     if (InputVectorCount == 0 || OutputVectorCounts[I] == 0)
459       continue;
460     uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]);
461     size_t NumBytes = NumDwords * sizeof(uint32_t);
462     InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes);
463     Current += NumBytes;
464   }
465 
466   // Hull shader: Input/Patch mapping table
467   if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 &&
468       InputVectorCount > 0) {
469     uint32_t NumDwords =
470         mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount);
471     size_t NumBytes = NumDwords * sizeof(uint32_t);
472     InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes);
473     Current += NumBytes;
474   }
475 
476   // Domain Shader: Patch/Output mapping table
477   if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&
478       OutputVectorCounts[0] > 0) {
479     uint32_t NumDwords =
480         mapTableSize(PatchConstOrPrimVectorCount, OutputVectorCounts[0]);
481     size_t NumBytes = NumDwords * sizeof(uint32_t);
482     PatchOutputMap.Data = Data.substr(Current - Data.begin(), NumBytes);
483     Current += NumBytes;
484   }
485 
486   return Error::success();
487 }
488 
getSigInputCount() const489 uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const {
490   if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
491     return P->SigInputElements;
492   if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
493     return P->SigInputElements;
494   if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
495     return P->SigInputElements;
496   return 0;
497 }
498 
getSigOutputCount() const499 uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const {
500   if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
501     return P->SigOutputElements;
502   if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
503     return P->SigOutputElements;
504   if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
505     return P->SigOutputElements;
506   return 0;
507 }
508 
getSigPatchOrPrimCount() const509 uint8_t DirectX::PSVRuntimeInfo::getSigPatchOrPrimCount() const {
510   if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
511     return P->SigPatchOrPrimElements;
512   if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
513     return P->SigPatchOrPrimElements;
514   if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
515     return P->SigPatchOrPrimElements;
516   return 0;
517 }
518