1 //===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DXContainerGlobalsPass implementation.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "DXILRootSignature.h"
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/MC/DXContainerPSVInfo.h"
#include "llvm/Pass.h"
#include "llvm/Support/MD5.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <cstring>
#include <limits>
#include <optional>
32
33 using namespace llvm;
34 using namespace llvm::dxil;
35 using namespace llvm::mcdxbc;
36
37 namespace {
38 class DXContainerGlobals : public llvm::ModulePass {
39
40 GlobalVariable *buildContainerGlobal(Module &M, Constant *Content,
41 StringRef Name, StringRef SectionName);
42 GlobalVariable *getFeatureFlags(Module &M);
43 GlobalVariable *computeShaderHash(Module &M);
44 GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
45 StringRef SectionName);
46 void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
47 void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
48 void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
49 void addPipelineStateValidationInfo(Module &M,
50 SmallVector<GlobalValue *> &Globals);
51
52 public:
53 static char ID; // Pass identification, replacement for typeid
DXContainerGlobals()54 DXContainerGlobals() : ModulePass(ID) {}
55
getPassName() const56 StringRef getPassName() const override {
57 return "DXContainer Global Emitter";
58 }
59
60 bool runOnModule(Module &M) override;
61
getAnalysisUsage(AnalysisUsage & AU) const62 void getAnalysisUsage(AnalysisUsage &AU) const override {
63 AU.setPreservesAll();
64 AU.addRequired<ShaderFlagsAnalysisWrapper>();
65 AU.addRequired<RootSignatureAnalysisWrapper>();
66 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
67 AU.addRequired<DXILResourceTypeWrapperPass>();
68 AU.addRequired<DXILResourceWrapperPass>();
69 }
70 };
71
72 } // namespace
73
runOnModule(Module & M)74 bool DXContainerGlobals::runOnModule(Module &M) {
75 llvm::SmallVector<GlobalValue *> Globals;
76 Globals.push_back(getFeatureFlags(M));
77 Globals.push_back(computeShaderHash(M));
78 addSignature(M, Globals);
79 addRootSignature(M, Globals);
80 addPipelineStateValidationInfo(M, Globals);
81 appendToCompilerUsed(M, Globals);
82 return true;
83 }
84
getFeatureFlags(Module & M)85 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
86 uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
87 .getShaderFlags()
88 .getCombinedFlags()
89 .getFeatureFlags();
90
91 Constant *FeatureFlagsConstant =
92 ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
93 return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
94 }
95
computeShaderHash(Module & M)96 GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) {
97 auto *DXILConstant =
98 cast<ConstantDataArray>(M.getNamedGlobal("dx.dxil")->getInitializer());
99 MD5 Digest;
100 Digest.update(DXILConstant->getRawDataValues());
101 MD5::MD5Result Result = Digest.final();
102
103 dxbc::ShaderHash HashData = {0, {0}};
104 // The Hash's IncludesSource flag gets set whenever the hashed shader includes
105 // debug information.
106 if (M.debug_compile_units_begin() != M.debug_compile_units_end())
107 HashData.Flags = static_cast<uint32_t>(dxbc::HashFlags::IncludesSource);
108
109 memcpy(reinterpret_cast<void *>(&HashData.Digest), Result.data(), 16);
110 if (sys::IsBigEndianHost)
111 HashData.swapBytes();
112 StringRef Data(reinterpret_cast<char *>(&HashData), sizeof(dxbc::ShaderHash));
113
114 Constant *ModuleConstant =
115 ConstantDataArray::get(M.getContext(), arrayRefFromStringRef(Data));
116 return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH");
117 }
118
buildContainerGlobal(Module & M,Constant * Content,StringRef Name,StringRef SectionName)119 GlobalVariable *DXContainerGlobals::buildContainerGlobal(
120 Module &M, Constant *Content, StringRef Name, StringRef SectionName) {
121 auto *GV = new llvm::GlobalVariable(
122 M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name);
123 GV->setSection(SectionName);
124 GV->setAlignment(Align(4));
125 return GV;
126 }
127
buildSignature(Module & M,Signature & Sig,StringRef Name,StringRef SectionName)128 GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig,
129 StringRef Name,
130 StringRef SectionName) {
131 SmallString<256> Data;
132 raw_svector_ostream OS(Data);
133 Sig.write(OS);
134 Constant *Constant =
135 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
136 return buildContainerGlobal(M, Constant, Name, SectionName);
137 }
138
addSignature(Module & M,SmallVector<GlobalValue * > & Globals)139 void DXContainerGlobals::addSignature(Module &M,
140 SmallVector<GlobalValue *> &Globals) {
141 // FIXME: support graphics shader.
142 // see issue https://github.com/llvm/llvm-project/issues/90504.
143
144 Signature InputSig;
145 Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1"));
146
147 Signature OutputSig;
148 Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
149 }
150
addRootSignature(Module & M,SmallVector<GlobalValue * > & Globals)151 void DXContainerGlobals::addRootSignature(Module &M,
152 SmallVector<GlobalValue *> &Globals) {
153
154 dxil::ModuleMetadataInfo &MMI =
155 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
156
157 // Root Signature in Library don't compile to DXContainer.
158 if (MMI.ShaderProfile == llvm::Triple::Library)
159 return;
160
161 assert(MMI.EntryPropertyVec.size() == 1);
162
163 auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
164 const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
165 const std::optional<mcdxbc::RootSignatureDesc> &RS =
166 RSA.getDescForFunction(EntryFunction);
167
168 if (!RS)
169 return;
170
171 SmallString<256> Data;
172 raw_svector_ostream OS(Data);
173
174 RS->write(OS);
175
176 Constant *Constant =
177 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
178 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0"));
179 }
180
addResourcesForPSV(Module & M,PSVRuntimeInfo & PSV)181 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
182 const DXILResourceMap &DRM =
183 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
184 DXILResourceTypeMap &DRTM =
185 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
186
187 auto MakeBinding =
188 [](const dxil::ResourceInfo::ResourceBinding &Binding,
189 const dxbc::PSV::ResourceType Type, const dxil::ResourceKind Kind,
190 const dxbc::PSV::ResourceFlags Flags = dxbc::PSV::ResourceFlags()) {
191 dxbc::PSV::v2::ResourceBindInfo BindInfo;
192 BindInfo.Type = Type;
193 BindInfo.LowerBound = Binding.LowerBound;
194 BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1;
195 BindInfo.Space = Binding.Space;
196 BindInfo.Kind = static_cast<dxbc::PSV::ResourceKind>(Kind);
197 BindInfo.Flags = Flags;
198 return BindInfo;
199 };
200
201 for (const dxil::ResourceInfo &RI : DRM.cbuffers()) {
202 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
203 PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::CBV,
204 dxil::ResourceKind::CBuffer));
205 }
206 for (const dxil::ResourceInfo &RI : DRM.samplers()) {
207 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
208 PSV.Resources.push_back(MakeBinding(Binding,
209 dxbc::PSV::ResourceType::Sampler,
210 dxil::ResourceKind::Sampler));
211 }
212 for (const dxil::ResourceInfo &RI : DRM.srvs()) {
213 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
214
215 dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()];
216 dxbc::PSV::ResourceType ResType;
217 if (TypeInfo.isStruct())
218 ResType = dxbc::PSV::ResourceType::SRVStructured;
219 else if (TypeInfo.isTyped())
220 ResType = dxbc::PSV::ResourceType::SRVTyped;
221 else
222 ResType = dxbc::PSV::ResourceType::SRVRaw;
223
224 PSV.Resources.push_back(
225 MakeBinding(Binding, ResType, TypeInfo.getResourceKind()));
226 }
227 for (const dxil::ResourceInfo &RI : DRM.uavs()) {
228 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
229
230 dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()];
231 dxbc::PSV::ResourceType ResType;
232 if (RI.hasCounter())
233 ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter;
234 else if (TypeInfo.isStruct())
235 ResType = dxbc::PSV::ResourceType::UAVStructured;
236 else if (TypeInfo.isTyped())
237 ResType = dxbc::PSV::ResourceType::UAVTyped;
238 else
239 ResType = dxbc::PSV::ResourceType::UAVRaw;
240
241 dxbc::PSV::ResourceFlags Flags;
242 // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking
243 // with https://github.com/llvm/llvm-project/issues/104392
244 Flags.Flags = 0u;
245
246 PSV.Resources.push_back(
247 MakeBinding(Binding, ResType, TypeInfo.getResourceKind(), Flags));
248 }
249 }
250
addPipelineStateValidationInfo(Module & M,SmallVector<GlobalValue * > & Globals)251 void DXContainerGlobals::addPipelineStateValidationInfo(
252 Module &M, SmallVector<GlobalValue *> &Globals) {
253 SmallString<256> Data;
254 raw_svector_ostream OS(Data);
255 PSVRuntimeInfo PSV;
256 PSV.BaseData.MinimumWaveLaneCount = 0;
257 PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
258
259 dxil::ModuleMetadataInfo &MMI =
260 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
261 assert(MMI.EntryPropertyVec.size() == 1 ||
262 MMI.ShaderProfile == Triple::Library);
263 PSV.BaseData.ShaderStage =
264 static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);
265
266 addResourcesForPSV(M, PSV);
267
268 // Hardcoded values here to unblock loading the shader into D3D.
269 //
270 // TODO: Lots more stuff to do here!
271 //
272 // See issue https://github.com/llvm/llvm-project/issues/96674.
273 switch (MMI.ShaderProfile) {
274 case Triple::Compute:
275 PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
276 PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
277 PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
278 break;
279 default:
280 break;
281 }
282
283 if (MMI.ShaderProfile != Triple::Library)
284 PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
285
286 PSV.finalize(MMI.ShaderProfile);
287 PSV.write(OS);
288 Constant *Constant =
289 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
290 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0"));
291 }
292
293 char DXContainerGlobals::ID = 0;
294 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
295 "DXContainer Global Emitter", false, true)
INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)296 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
297 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
298 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
299 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
300 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
301 "DXContainer Global Emitter", false, true)
302
303 ModulePass *llvm::createDXContainerGlobalsPass() {
304 return new DXContainerGlobals();
305 }
306