xref: /freebsd/contrib/llvm-project/llvm/lib/ObjectYAML/DXContainerEmitter.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- DXContainerEmitter.cpp - Convert YAML to a DXContainer -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Binary emitter for yaml to DXContainer binary
11 ///
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/MC/DXContainerPSVInfo.h"
#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/ObjectYAML/ObjectYAML.h"
#include "llvm/ObjectYAML/yaml2obj.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstring>
22 
23 using namespace llvm;
24 
25 namespace {
26 class DXContainerWriter {
27 public:
28   DXContainerWriter(DXContainerYAML::Object &ObjectFile)
29       : ObjectFile(ObjectFile) {}
30 
31   Error write(raw_ostream &OS);
32 
33 private:
34   DXContainerYAML::Object &ObjectFile;
35 
36   Error computePartOffsets();
37   Error validatePartOffsets();
38   Error validateSize(uint32_t Computed);
39 
40   void writeHeader(raw_ostream &OS);
41   void writeParts(raw_ostream &OS);
42 };
43 } // namespace
44 
45 Error DXContainerWriter::validateSize(uint32_t Computed) {
46   if (!ObjectFile.Header.FileSize)
47     ObjectFile.Header.FileSize = Computed;
48   else if (*ObjectFile.Header.FileSize < Computed)
49     return createStringError(errc::result_out_of_range,
50                              "File size specified is too small.");
51   return Error::success();
52 }
53 
54 Error DXContainerWriter::validatePartOffsets() {
55   if (ObjectFile.Parts.size() != ObjectFile.Header.PartOffsets->size())
56     return createStringError(
57         errc::invalid_argument,
58         "Mismatch between number of parts and part offsets.");
59   uint32_t RollingOffset =
60       sizeof(dxbc::Header) + (ObjectFile.Header.PartCount * sizeof(uint32_t));
61   for (auto I : llvm::zip(ObjectFile.Parts, *ObjectFile.Header.PartOffsets)) {
62     if (RollingOffset > std::get<1>(I))
63       return createStringError(errc::invalid_argument,
64                                "Offset mismatch, not enough space for data.");
65     RollingOffset =
66         std::get<1>(I) + sizeof(dxbc::PartHeader) + std::get<0>(I).Size;
67   }
68   if (Error Err = validateSize(RollingOffset))
69     return Err;
70 
71   return Error::success();
72 }
73 
74 Error DXContainerWriter::computePartOffsets() {
75   if (ObjectFile.Header.PartOffsets)
76     return validatePartOffsets();
77   uint32_t RollingOffset =
78       sizeof(dxbc::Header) + (ObjectFile.Header.PartCount * sizeof(uint32_t));
79   ObjectFile.Header.PartOffsets = std::vector<uint32_t>();
80   for (const auto &Part : ObjectFile.Parts) {
81     ObjectFile.Header.PartOffsets->push_back(RollingOffset);
82     RollingOffset += sizeof(dxbc::PartHeader) + Part.Size;
83   }
84   if (Error Err = validateSize(RollingOffset))
85     return Err;
86 
87   return Error::success();
88 }
89 
90 void DXContainerWriter::writeHeader(raw_ostream &OS) {
91   dxbc::Header Header;
92   memcpy(Header.Magic, "DXBC", 4);
93   memcpy(Header.FileHash.Digest, ObjectFile.Header.Hash.data(), 16);
94   Header.Version.Major = ObjectFile.Header.Version.Major;
95   Header.Version.Minor = ObjectFile.Header.Version.Minor;
96   Header.FileSize = *ObjectFile.Header.FileSize;
97   Header.PartCount = ObjectFile.Parts.size();
98   if (sys::IsBigEndianHost)
99     Header.swapBytes();
100   OS.write(reinterpret_cast<char *>(&Header), sizeof(Header));
101   SmallVector<uint32_t> Offsets(ObjectFile.Header.PartOffsets->begin(),
102                                 ObjectFile.Header.PartOffsets->end());
103   if (sys::IsBigEndianHost)
104     for (auto &O : Offsets)
105       sys::swapByteOrder(O);
106   OS.write(reinterpret_cast<char *>(Offsets.data()),
107            Offsets.size() * sizeof(uint32_t));
108 }
109 
110 void DXContainerWriter::writeParts(raw_ostream &OS) {
111   uint32_t RollingOffset =
112       sizeof(dxbc::Header) + (ObjectFile.Header.PartCount * sizeof(uint32_t));
113   for (auto I : llvm::zip(ObjectFile.Parts, *ObjectFile.Header.PartOffsets)) {
114     if (RollingOffset < std::get<1>(I)) {
115       uint32_t PadBytes = std::get<1>(I) - RollingOffset;
116       OS.write_zeros(PadBytes);
117     }
118     DXContainerYAML::Part P = std::get<0>(I);
119     RollingOffset = std::get<1>(I) + sizeof(dxbc::PartHeader);
120     uint32_t PartSize = P.Size;
121 
122     OS.write(P.Name.c_str(), 4);
123     if (sys::IsBigEndianHost)
124       sys::swapByteOrder(P.Size);
125     OS.write(reinterpret_cast<const char *>(&P.Size), sizeof(uint32_t));
126 
127     dxbc::PartType PT = dxbc::parsePartType(P.Name);
128 
129     uint64_t DataStart = OS.tell();
130     switch (PT) {
131     case dxbc::PartType::DXIL: {
132       if (!P.Program)
133         continue;
134       dxbc::ProgramHeader Header;
135       Header.Version = dxbc::ProgramHeader::getVersion(P.Program->MajorVersion,
136                                                        P.Program->MinorVersion);
137       Header.Unused = 0;
138       Header.ShaderKind = P.Program->ShaderKind;
139       memcpy(Header.Bitcode.Magic, "DXIL", 4);
140       Header.Bitcode.MajorVersion = P.Program->DXILMajorVersion;
141       Header.Bitcode.MinorVersion = P.Program->DXILMinorVersion;
142       Header.Bitcode.Unused = 0;
143 
144       // Compute the optional fields if needed...
145       if (P.Program->DXILOffset)
146         Header.Bitcode.Offset = *P.Program->DXILOffset;
147       else
148         Header.Bitcode.Offset = sizeof(dxbc::BitcodeHeader);
149 
150       if (P.Program->DXILSize)
151         Header.Bitcode.Size = *P.Program->DXILSize;
152       else
153         Header.Bitcode.Size = P.Program->DXIL ? P.Program->DXIL->size() : 0;
154 
155       if (P.Program->Size)
156         Header.Size = *P.Program->Size;
157       else
158         Header.Size = sizeof(dxbc::ProgramHeader) + Header.Bitcode.Size;
159 
160       uint32_t BitcodeOffset = Header.Bitcode.Offset;
161       if (sys::IsBigEndianHost)
162         Header.swapBytes();
163       OS.write(reinterpret_cast<const char *>(&Header),
164                sizeof(dxbc::ProgramHeader));
165       if (P.Program->DXIL) {
166         if (BitcodeOffset > sizeof(dxbc::BitcodeHeader)) {
167           uint32_t PadBytes = BitcodeOffset - sizeof(dxbc::BitcodeHeader);
168           OS.write_zeros(PadBytes);
169         }
170         OS.write(reinterpret_cast<char *>(P.Program->DXIL->data()),
171                  P.Program->DXIL->size());
172       }
173       break;
174     }
175     case dxbc::PartType::SFI0: {
176       // If we don't have any flags we can continue here and the data will be
177       // zeroed out.
178       if (!P.Flags.has_value())
179         continue;
180       uint64_t Flags = P.Flags->getEncodedFlags();
181       if (sys::IsBigEndianHost)
182         sys::swapByteOrder(Flags);
183       OS.write(reinterpret_cast<char *>(&Flags), sizeof(uint64_t));
184       break;
185     }
186     case dxbc::PartType::HASH: {
187       if (!P.Hash.has_value())
188         continue;
189       dxbc::ShaderHash Hash = {0, {0}};
190       if (P.Hash->IncludesSource)
191         Hash.Flags |= static_cast<uint32_t>(dxbc::HashFlags::IncludesSource);
192       memcpy(&Hash.Digest[0], &P.Hash->Digest[0], 16);
193       if (sys::IsBigEndianHost)
194         Hash.swapBytes();
195       OS.write(reinterpret_cast<char *>(&Hash), sizeof(dxbc::ShaderHash));
196       break;
197     }
198     case dxbc::PartType::PSV0: {
199       if (!P.Info.has_value())
200         continue;
201       mcdxbc::PSVRuntimeInfo PSV;
202       memcpy(&PSV.BaseData, &P.Info->Info, sizeof(dxbc::PSV::v3::RuntimeInfo));
203       PSV.Resources = P.Info->Resources;
204       PSV.EntryName = P.Info->EntryName;
205 
206       for (auto El : P.Info->SigInputElements)
207         PSV.InputElements.push_back(mcdxbc::PSVSignatureElement{
208             El.Name, El.Indices, El.StartRow, El.Cols, El.StartCol,
209             El.Allocated, El.Kind, El.Type, El.Mode, El.DynamicMask,
210             El.Stream});
211 
212       for (auto El : P.Info->SigOutputElements)
213         PSV.OutputElements.push_back(mcdxbc::PSVSignatureElement{
214             El.Name, El.Indices, El.StartRow, El.Cols, El.StartCol,
215             El.Allocated, El.Kind, El.Type, El.Mode, El.DynamicMask,
216             El.Stream});
217 
218       for (auto El : P.Info->SigPatchOrPrimElements)
219         PSV.PatchOrPrimElements.push_back(mcdxbc::PSVSignatureElement{
220             El.Name, El.Indices, El.StartRow, El.Cols, El.StartCol,
221             El.Allocated, El.Kind, El.Type, El.Mode, El.DynamicMask,
222             El.Stream});
223 
224       static_assert(PSV.OutputVectorMasks.size() == PSV.InputOutputMap.size());
225       for (unsigned I = 0; I < PSV.OutputVectorMasks.size(); ++I) {
226         PSV.OutputVectorMasks[I].insert(PSV.OutputVectorMasks[I].begin(),
227                                         P.Info->OutputVectorMasks[I].begin(),
228                                         P.Info->OutputVectorMasks[I].end());
229         PSV.InputOutputMap[I].insert(PSV.InputOutputMap[I].begin(),
230                                      P.Info->InputOutputMap[I].begin(),
231                                      P.Info->InputOutputMap[I].end());
232       }
233 
234       PSV.PatchOrPrimMasks.insert(PSV.PatchOrPrimMasks.begin(),
235                                   P.Info->PatchOrPrimMasks.begin(),
236                                   P.Info->PatchOrPrimMasks.end());
237       PSV.InputPatchMap.insert(PSV.InputPatchMap.begin(),
238                                P.Info->InputPatchMap.begin(),
239                                P.Info->InputPatchMap.end());
240       PSV.PatchOutputMap.insert(PSV.PatchOutputMap.begin(),
241                                 P.Info->PatchOutputMap.begin(),
242                                 P.Info->PatchOutputMap.end());
243 
244       PSV.finalize(static_cast<Triple::EnvironmentType>(
245           Triple::Pixel + P.Info->Info.ShaderStage));
246       PSV.write(OS, P.Info->Version);
247       break;
248     }
249     case dxbc::PartType::ISG1:
250     case dxbc::PartType::OSG1:
251     case dxbc::PartType::PSG1: {
252       mcdxbc::Signature Sig;
253       if (P.Signature.has_value()) {
254         for (const auto &Param : P.Signature->Parameters) {
255           Sig.addParam(Param.Stream, Param.Name, Param.Index, Param.SystemValue,
256                        Param.CompType, Param.Register, Param.Mask,
257                        Param.ExclusiveMask, Param.MinPrecision);
258         }
259       }
260       Sig.write(OS);
261       break;
262     }
263     case dxbc::PartType::Unknown:
264       break; // Skip any handling for unrecognized parts.
265     case dxbc::PartType::RTS0:
266       if (!P.RootSignature.has_value())
267         continue;
268 
269       mcdxbc::RootSignatureDesc RS;
270       RS.Flags = P.RootSignature->getEncodedFlags();
271       RS.Version = P.RootSignature->Version;
272       RS.RootParameterOffset = P.RootSignature->RootParametersOffset;
273       RS.NumStaticSamplers = P.RootSignature->NumStaticSamplers;
274       RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
275 
276       for (DXContainerYAML::RootParameterLocationYaml &L :
277            P.RootSignature->Parameters.Locations) {
278         dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
279                                          L.Header.Offset};
280 
281         switch (L.Header.Type) {
282         case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
283           const DXContainerYAML::RootConstantsYaml &ConstantYaml =
284               P.RootSignature->Parameters.getOrInsertConstants(L);
285           dxbc::RTS0::v1::RootConstants Constants;
286           Constants.Num32BitValues = ConstantYaml.Num32BitValues;
287           Constants.RegisterSpace = ConstantYaml.RegisterSpace;
288           Constants.ShaderRegister = ConstantYaml.ShaderRegister;
289           RS.ParametersContainer.addParameter(Header, Constants);
290           break;
291         }
292         case llvm::to_underlying(dxbc::RootParameterType::CBV):
293         case llvm::to_underlying(dxbc::RootParameterType::SRV):
294         case llvm::to_underlying(dxbc::RootParameterType::UAV): {
295           const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
296               P.RootSignature->Parameters.getOrInsertDescriptor(L);
297 
298           dxbc::RTS0::v2::RootDescriptor Descriptor;
299           Descriptor.RegisterSpace = DescriptorYaml.RegisterSpace;
300           Descriptor.ShaderRegister = DescriptorYaml.ShaderRegister;
301           if (RS.Version > 1)
302             Descriptor.Flags = DescriptorYaml.getEncodedFlags();
303           RS.ParametersContainer.addParameter(Header, Descriptor);
304           break;
305         }
306         case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
307           const DXContainerYAML::DescriptorTableYaml &TableYaml =
308               P.RootSignature->Parameters.getOrInsertTable(L);
309           mcdxbc::DescriptorTable Table;
310           for (const auto &R : TableYaml.Ranges) {
311 
312             dxbc::RTS0::v2::DescriptorRange Range;
313             Range.RangeType = R.RangeType;
314             Range.NumDescriptors = R.NumDescriptors;
315             Range.BaseShaderRegister = R.BaseShaderRegister;
316             Range.RegisterSpace = R.RegisterSpace;
317             Range.OffsetInDescriptorsFromTableStart =
318                 R.OffsetInDescriptorsFromTableStart;
319             if (RS.Version > 1)
320               Range.Flags = R.getEncodedFlags();
321             Table.Ranges.push_back(Range);
322           }
323           RS.ParametersContainer.addParameter(Header, Table);
324           break;
325         }
326         default:
327           // Handling invalid parameter type edge case. We intentionally let
328           // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
329           // for that to be used as a testing tool more effectively.
330           RS.ParametersContainer.addInvalidParameter(Header);
331         }
332       }
333 
334       for (const auto &Param : P.RootSignature->samplers()) {
335         dxbc::RTS0::v1::StaticSampler NewSampler;
336         NewSampler.Filter = Param.Filter;
337         NewSampler.AddressU = Param.AddressU;
338         NewSampler.AddressV = Param.AddressV;
339         NewSampler.AddressW = Param.AddressW;
340         NewSampler.MipLODBias = Param.MipLODBias;
341         NewSampler.MaxAnisotropy = Param.MaxAnisotropy;
342         NewSampler.ComparisonFunc = Param.ComparisonFunc;
343         NewSampler.BorderColor = Param.BorderColor;
344         NewSampler.MinLOD = Param.MinLOD;
345         NewSampler.MaxLOD = Param.MaxLOD;
346         NewSampler.ShaderRegister = Param.ShaderRegister;
347         NewSampler.RegisterSpace = Param.RegisterSpace;
348         NewSampler.ShaderVisibility = Param.ShaderVisibility;
349 
350         RS.StaticSamplers.push_back(NewSampler);
351       }
352 
353       RS.write(OS);
354       break;
355     }
356     uint64_t BytesWritten = OS.tell() - DataStart;
357     RollingOffset += BytesWritten;
358     if (BytesWritten < PartSize)
359       OS.write_zeros(PartSize - BytesWritten);
360     RollingOffset += PartSize;
361   }
362 }
363 
364 Error DXContainerWriter::write(raw_ostream &OS) {
365   if (Error Err = computePartOffsets())
366     return Err;
367   writeHeader(OS);
368   writeParts(OS);
369   return Error::success();
370 }
371 
372 namespace llvm {
373 namespace yaml {
374 
375 bool yaml2dxcontainer(DXContainerYAML::Object &Doc, raw_ostream &Out,
376                       ErrorHandler EH) {
377   DXContainerWriter Writer(Doc);
378   if (Error Err = Writer.write(Out)) {
379     handleAllErrors(std::move(Err),
380                     [&](const ErrorInfoBase &Err) { EH(Err.message()); });
381     return false;
382   }
383   return true;
384 }
385 
386 } // namespace yaml
387 } // namespace llvm
388