xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXContainerGlobals.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DXContainerGlobalsPass implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILRootSignature.h"
14 #include "DXILShaderFlags.h"
15 #include "DirectX.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/DXILMetadataAnalysis.h"
20 #include "llvm/Analysis/DXILResource.h"
21 #include "llvm/BinaryFormat/DXContainer.h"
22 #include "llvm/CodeGen/Passes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/MC/DXContainerPSVInfo.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/MD5.h"
29 #include "llvm/TargetParser/Triple.h"
30 #include "llvm/Transforms/Utils/ModuleUtils.h"
31 #include <optional>
32 
33 using namespace llvm;
34 using namespace llvm::dxil;
35 using namespace llvm::mcdxbc;
36 
37 namespace {
38 class DXContainerGlobals : public llvm::ModulePass {
39 
40   GlobalVariable *buildContainerGlobal(Module &M, Constant *Content,
41                                        StringRef Name, StringRef SectionName);
42   GlobalVariable *getFeatureFlags(Module &M);
43   GlobalVariable *computeShaderHash(Module &M);
44   GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
45                                  StringRef SectionName);
46   void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
47   void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
48   void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
49   void addPipelineStateValidationInfo(Module &M,
50                                       SmallVector<GlobalValue *> &Globals);
51 
52 public:
53   static char ID; // Pass identification, replacement for typeid
DXContainerGlobals()54   DXContainerGlobals() : ModulePass(ID) {}
55 
getPassName() const56   StringRef getPassName() const override {
57     return "DXContainer Global Emitter";
58   }
59 
60   bool runOnModule(Module &M) override;
61 
getAnalysisUsage(AnalysisUsage & AU) const62   void getAnalysisUsage(AnalysisUsage &AU) const override {
63     AU.setPreservesAll();
64     AU.addRequired<ShaderFlagsAnalysisWrapper>();
65     AU.addRequired<RootSignatureAnalysisWrapper>();
66     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
67     AU.addRequired<DXILResourceTypeWrapperPass>();
68     AU.addRequired<DXILResourceWrapperPass>();
69   }
70 };
71 
72 } // namespace
73 
runOnModule(Module & M)74 bool DXContainerGlobals::runOnModule(Module &M) {
75   llvm::SmallVector<GlobalValue *> Globals;
76   Globals.push_back(getFeatureFlags(M));
77   Globals.push_back(computeShaderHash(M));
78   addSignature(M, Globals);
79   addRootSignature(M, Globals);
80   addPipelineStateValidationInfo(M, Globals);
81   appendToCompilerUsed(M, Globals);
82   return true;
83 }
84 
getFeatureFlags(Module & M)85 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
86   uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
87                                       .getShaderFlags()
88                                       .getCombinedFlags()
89                                       .getFeatureFlags();
90 
91   Constant *FeatureFlagsConstant =
92       ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
93   return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
94 }
95 
computeShaderHash(Module & M)96 GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) {
97   auto *DXILConstant =
98       cast<ConstantDataArray>(M.getNamedGlobal("dx.dxil")->getInitializer());
99   MD5 Digest;
100   Digest.update(DXILConstant->getRawDataValues());
101   MD5::MD5Result Result = Digest.final();
102 
103   dxbc::ShaderHash HashData = {0, {0}};
104   // The Hash's IncludesSource flag gets set whenever the hashed shader includes
105   // debug information.
106   if (M.debug_compile_units_begin() != M.debug_compile_units_end())
107     HashData.Flags = static_cast<uint32_t>(dxbc::HashFlags::IncludesSource);
108 
109   memcpy(reinterpret_cast<void *>(&HashData.Digest), Result.data(), 16);
110   if (sys::IsBigEndianHost)
111     HashData.swapBytes();
112   StringRef Data(reinterpret_cast<char *>(&HashData), sizeof(dxbc::ShaderHash));
113 
114   Constant *ModuleConstant =
115       ConstantDataArray::get(M.getContext(), arrayRefFromStringRef(Data));
116   return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH");
117 }
118 
buildContainerGlobal(Module & M,Constant * Content,StringRef Name,StringRef SectionName)119 GlobalVariable *DXContainerGlobals::buildContainerGlobal(
120     Module &M, Constant *Content, StringRef Name, StringRef SectionName) {
121   auto *GV = new llvm::GlobalVariable(
122       M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name);
123   GV->setSection(SectionName);
124   GV->setAlignment(Align(4));
125   return GV;
126 }
127 
buildSignature(Module & M,Signature & Sig,StringRef Name,StringRef SectionName)128 GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig,
129                                                    StringRef Name,
130                                                    StringRef SectionName) {
131   SmallString<256> Data;
132   raw_svector_ostream OS(Data);
133   Sig.write(OS);
134   Constant *Constant =
135       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
136   return buildContainerGlobal(M, Constant, Name, SectionName);
137 }
138 
addSignature(Module & M,SmallVector<GlobalValue * > & Globals)139 void DXContainerGlobals::addSignature(Module &M,
140                                       SmallVector<GlobalValue *> &Globals) {
141   // FIXME: support graphics shader.
142   //  see issue https://github.com/llvm/llvm-project/issues/90504.
143 
144   Signature InputSig;
145   Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1"));
146 
147   Signature OutputSig;
148   Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
149 }
150 
addRootSignature(Module & M,SmallVector<GlobalValue * > & Globals)151 void DXContainerGlobals::addRootSignature(Module &M,
152                                           SmallVector<GlobalValue *> &Globals) {
153 
154   dxil::ModuleMetadataInfo &MMI =
155       getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
156 
157   // Root Signature in Library don't compile to DXContainer.
158   if (MMI.ShaderProfile == llvm::Triple::Library)
159     return;
160 
161   assert(MMI.EntryPropertyVec.size() == 1);
162 
163   auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
164   const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
165   const std::optional<mcdxbc::RootSignatureDesc> &RS =
166       RSA.getDescForFunction(EntryFunction);
167 
168   if (!RS)
169     return;
170 
171   SmallString<256> Data;
172   raw_svector_ostream OS(Data);
173 
174   RS->write(OS);
175 
176   Constant *Constant =
177       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
178   Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0"));
179 }
180 
addResourcesForPSV(Module & M,PSVRuntimeInfo & PSV)181 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
182   const DXILResourceMap &DRM =
183       getAnalysis<DXILResourceWrapperPass>().getResourceMap();
184   DXILResourceTypeMap &DRTM =
185       getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
186 
187   auto MakeBinding =
188       [](const dxil::ResourceInfo::ResourceBinding &Binding,
189          const dxbc::PSV::ResourceType Type, const dxil::ResourceKind Kind,
190          const dxbc::PSV::ResourceFlags Flags = dxbc::PSV::ResourceFlags()) {
191         dxbc::PSV::v2::ResourceBindInfo BindInfo;
192         BindInfo.Type = Type;
193         BindInfo.LowerBound = Binding.LowerBound;
194         BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1;
195         BindInfo.Space = Binding.Space;
196         BindInfo.Kind = static_cast<dxbc::PSV::ResourceKind>(Kind);
197         BindInfo.Flags = Flags;
198         return BindInfo;
199       };
200 
201   for (const dxil::ResourceInfo &RI : DRM.cbuffers()) {
202     const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
203     PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::CBV,
204                                         dxil::ResourceKind::CBuffer));
205   }
206   for (const dxil::ResourceInfo &RI : DRM.samplers()) {
207     const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
208     PSV.Resources.push_back(MakeBinding(Binding,
209                                         dxbc::PSV::ResourceType::Sampler,
210                                         dxil::ResourceKind::Sampler));
211   }
212   for (const dxil::ResourceInfo &RI : DRM.srvs()) {
213     const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
214 
215     dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()];
216     dxbc::PSV::ResourceType ResType;
217     if (TypeInfo.isStruct())
218       ResType = dxbc::PSV::ResourceType::SRVStructured;
219     else if (TypeInfo.isTyped())
220       ResType = dxbc::PSV::ResourceType::SRVTyped;
221     else
222       ResType = dxbc::PSV::ResourceType::SRVRaw;
223 
224     PSV.Resources.push_back(
225         MakeBinding(Binding, ResType, TypeInfo.getResourceKind()));
226   }
227   for (const dxil::ResourceInfo &RI : DRM.uavs()) {
228     const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
229 
230     dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()];
231     dxbc::PSV::ResourceType ResType;
232     if (RI.hasCounter())
233       ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter;
234     else if (TypeInfo.isStruct())
235       ResType = dxbc::PSV::ResourceType::UAVStructured;
236     else if (TypeInfo.isTyped())
237       ResType = dxbc::PSV::ResourceType::UAVTyped;
238     else
239       ResType = dxbc::PSV::ResourceType::UAVRaw;
240 
241     dxbc::PSV::ResourceFlags Flags;
242     // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking
243     // with https://github.com/llvm/llvm-project/issues/104392
244     Flags.Flags = 0u;
245 
246     PSV.Resources.push_back(
247         MakeBinding(Binding, ResType, TypeInfo.getResourceKind(), Flags));
248   }
249 }
250 
addPipelineStateValidationInfo(Module & M,SmallVector<GlobalValue * > & Globals)251 void DXContainerGlobals::addPipelineStateValidationInfo(
252     Module &M, SmallVector<GlobalValue *> &Globals) {
253   SmallString<256> Data;
254   raw_svector_ostream OS(Data);
255   PSVRuntimeInfo PSV;
256   PSV.BaseData.MinimumWaveLaneCount = 0;
257   PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
258 
259   dxil::ModuleMetadataInfo &MMI =
260       getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
261   assert(MMI.EntryPropertyVec.size() == 1 ||
262          MMI.ShaderProfile == Triple::Library);
263   PSV.BaseData.ShaderStage =
264       static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);
265 
266   addResourcesForPSV(M, PSV);
267 
268   // Hardcoded values here to unblock loading the shader into D3D.
269   //
270   // TODO: Lots more stuff to do here!
271   //
272   // See issue https://github.com/llvm/llvm-project/issues/96674.
273   switch (MMI.ShaderProfile) {
274   case Triple::Compute:
275     PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
276     PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
277     PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
278     break;
279   default:
280     break;
281   }
282 
283   if (MMI.ShaderProfile != Triple::Library)
284     PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
285 
286   PSV.finalize(MMI.ShaderProfile);
287   PSV.write(OS);
288   Constant *Constant =
289       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
290   Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0"));
291 }
292 
293 char DXContainerGlobals::ID = 0;
294 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
295                       "DXContainer Global Emitter", false, true)
INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)296 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
297 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
298 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
299 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
300 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
301                     "DXContainer Global Emitter", false, true)
302 
303 ModulePass *llvm::createDXContainerGlobalsPass() {
304   return new DXContainerGlobals();
305 }
306