1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "DXILTranslateMetadata.h"
10 #include "DXILShaderFlags.h"
11 #include "DirectX.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Analysis/DXILMetadataAnalysis.h"
15 #include "llvm/Analysis/DXILResource.h"
16 #include "llvm/IR/BasicBlock.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/DiagnosticInfo.h"
19 #include "llvm/IR/DiagnosticPrinter.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/MDBuilder.h"
24 #include "llvm/IR/Metadata.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/VersionTuple.h"
30 #include "llvm/TargetParser/Triple.h"
31 #include <cstdint>
32
33 using namespace llvm;
34 using namespace llvm::dxil;
35
36 namespace {
37 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
38 /// for TranslateMetadata pass
39 class DiagnosticInfoTranslateMD : public DiagnosticInfo {
40 private:
41 const Twine &Msg;
42 const Module &Mod;
43
44 public:
45 /// \p M is the module for which the diagnostic is being emitted. \p Msg is
46 /// the message to show. Note that this class does not copy this message, so
47 /// this reference must be valid for the whole life time of the diagnostic.
DiagnosticInfoTranslateMD(const Module & M,const Twine & Msg LLVM_LIFETIME_BOUND,DiagnosticSeverity Severity=DS_Error)48 DiagnosticInfoTranslateMD(const Module &M,
49 const Twine &Msg LLVM_LIFETIME_BOUND,
50 DiagnosticSeverity Severity = DS_Error)
51 : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
52
print(DiagnosticPrinter & DP) const53 void print(DiagnosticPrinter &DP) const override {
54 DP << Mod.getName() << ": " << Msg << '\n';
55 }
56 };
57
58 enum class EntryPropsTag {
59 ShaderFlags = 0,
60 GSState,
61 DSState,
62 HSState,
63 NumThreads,
64 AutoBindingSpace,
65 RayPayloadSize,
66 RayAttribSize,
67 ShaderKind,
68 MSState,
69 ASStateTag,
70 WaveSize,
71 EntryRootSig,
72 };
73
74 } // namespace
75
emitResourceMetadata(Module & M,DXILResourceMap & DRM,DXILResourceTypeMap & DRTM)76 static NamedMDNode *emitResourceMetadata(Module &M, DXILResourceMap &DRM,
77 DXILResourceTypeMap &DRTM) {
78 LLVMContext &Context = M.getContext();
79
80 for (ResourceInfo &RI : DRM)
81 if (!RI.hasSymbol())
82 RI.createSymbol(M,
83 DRTM[RI.getHandleTy()].createElementStruct(RI.getName()));
84
85 SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
86 for (const ResourceInfo &RI : DRM.srvs())
87 SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
88 for (const ResourceInfo &RI : DRM.uavs())
89 UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
90 for (const ResourceInfo &RI : DRM.cbuffers())
91 CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
92 for (const ResourceInfo &RI : DRM.samplers())
93 Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
94
95 Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs);
96 Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs);
97 Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs);
98 Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps);
99
100 if (DRM.empty())
101 return nullptr;
102
103 NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
104 ResourceMD->addOperand(
105 MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));
106
107 return ResourceMD;
108 }
109
getShortShaderStage(Triple::EnvironmentType Env)110 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
111 switch (Env) {
112 case Triple::Pixel:
113 return "ps";
114 case Triple::Vertex:
115 return "vs";
116 case Triple::Geometry:
117 return "gs";
118 case Triple::Hull:
119 return "hs";
120 case Triple::Domain:
121 return "ds";
122 case Triple::Compute:
123 return "cs";
124 case Triple::Library:
125 return "lib";
126 case Triple::Mesh:
127 return "ms";
128 case Triple::Amplification:
129 return "as";
130 default:
131 break;
132 }
133 llvm_unreachable("Unsupported environment for DXIL generation.");
134 }
135
getShaderStage(Triple::EnvironmentType Env)136 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
137 return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
138 }
139
140 static SmallVector<Metadata *>
getTagValueAsMetadata(EntryPropsTag Tag,uint64_t Value,LLVMContext & Ctx)141 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
142 SmallVector<Metadata *> MDVals;
143 MDVals.emplace_back(ConstantAsMetadata::get(
144 ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
145 switch (Tag) {
146 case EntryPropsTag::ShaderFlags:
147 MDVals.emplace_back(ConstantAsMetadata::get(
148 ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
149 break;
150 case EntryPropsTag::ShaderKind:
151 MDVals.emplace_back(ConstantAsMetadata::get(
152 ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
153 break;
154 case EntryPropsTag::GSState:
155 case EntryPropsTag::DSState:
156 case EntryPropsTag::HSState:
157 case EntryPropsTag::NumThreads:
158 case EntryPropsTag::AutoBindingSpace:
159 case EntryPropsTag::RayPayloadSize:
160 case EntryPropsTag::RayAttribSize:
161 case EntryPropsTag::MSState:
162 case EntryPropsTag::ASStateTag:
163 case EntryPropsTag::WaveSize:
164 case EntryPropsTag::EntryRootSig:
165 llvm_unreachable("NYI: Unhandled entry property tag");
166 }
167 return MDVals;
168 }
169
170 static MDTuple *
getEntryPropAsMetadata(const EntryProperties & EP,uint64_t EntryShaderFlags,const Triple::EnvironmentType ShaderProfile)171 getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
172 const Triple::EnvironmentType ShaderProfile) {
173 SmallVector<Metadata *> MDVals;
174 LLVMContext &Ctx = EP.Entry->getContext();
175 if (EntryShaderFlags != 0)
176 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
177 EntryShaderFlags, Ctx));
178
179 if (EP.Entry != nullptr) {
180 // FIXME: support more props.
181 // See https://github.com/llvm/llvm-project/issues/57948.
182 // Add shader kind for lib entries.
183 if (ShaderProfile == Triple::EnvironmentType::Library &&
184 EP.ShaderStage != Triple::EnvironmentType::Library)
185 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
186 getShaderStage(EP.ShaderStage), Ctx));
187
188 if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
189 MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
190 Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
191 Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
192 Type::getInt32Ty(Ctx), EP.NumThreadsX)),
193 ConstantAsMetadata::get(ConstantInt::get(
194 Type::getInt32Ty(Ctx), EP.NumThreadsY)),
195 ConstantAsMetadata::get(ConstantInt::get(
196 Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
197 MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
198 }
199 }
200 if (MDVals.empty())
201 return nullptr;
202 return MDNode::get(Ctx, MDVals);
203 }
204
constructEntryMetadata(const Function * EntryFn,MDTuple * Signatures,MDNode * Resources,MDTuple * Properties,LLVMContext & Ctx)205 MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
206 MDNode *Resources, MDTuple *Properties,
207 LLVMContext &Ctx) {
208 // Each entry point metadata record specifies:
209 // * reference to the entry point function global symbol
210 // * unmangled name
211 // * list of signatures
212 // * list of resources
213 // * list of tag-value pairs of shader capabilities and other properties
214 Metadata *MDVals[5];
215 MDVals[0] =
216 EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr;
217 MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : "");
218 MDVals[2] = Signatures;
219 MDVals[3] = Resources;
220 MDVals[4] = Properties;
221 return MDNode::get(Ctx, MDVals);
222 }
223
emitEntryMD(const EntryProperties & EP,MDTuple * Signatures,MDNode * MDResources,const uint64_t EntryShaderFlags,const Triple::EnvironmentType ShaderProfile)224 static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
225 MDNode *MDResources,
226 const uint64_t EntryShaderFlags,
227 const Triple::EnvironmentType ShaderProfile) {
228 MDTuple *Properties =
229 getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
230 return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
231 EP.Entry->getContext());
232 }
233
emitValidatorVersionMD(Module & M,const ModuleMetadataInfo & MMDI)234 static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
235 if (MMDI.ValidatorVersion.empty())
236 return;
237
238 LLVMContext &Ctx = M.getContext();
239 IRBuilder<> IRB(Ctx);
240 Metadata *MDVals[2];
241 MDVals[0] =
242 ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor()));
243 MDVals[1] = ConstantAsMetadata::get(
244 IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0)));
245 NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
246 // Set validator version obtained from DXIL Metadata Analysis pass
247 ValVerNode->clearOperands();
248 ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
249 }
250
emitShaderModelVersionMD(Module & M,const ModuleMetadataInfo & MMDI)251 static void emitShaderModelVersionMD(Module &M,
252 const ModuleMetadataInfo &MMDI) {
253 LLVMContext &Ctx = M.getContext();
254 IRBuilder<> IRB(Ctx);
255 Metadata *SMVals[3];
256 VersionTuple SM = MMDI.ShaderModelVersion;
257 SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile));
258 SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
259 SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
260 NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
261 SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
262 }
263
emitDXILVersionTupleMD(Module & M,const ModuleMetadataInfo & MMDI)264 static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
265 LLVMContext &Ctx = M.getContext();
266 IRBuilder<> IRB(Ctx);
267 VersionTuple DXILVer = MMDI.DXILVersion;
268 Metadata *DXILVals[2];
269 DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
270 DXILVals[1] =
271 ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
272 NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
273 DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
274 }
275
emitTopLevelLibraryNode(Module & M,MDNode * RMD,uint64_t ShaderFlags)276 static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
277 uint64_t ShaderFlags) {
278 LLVMContext &Ctx = M.getContext();
279 MDTuple *Properties = nullptr;
280 if (ShaderFlags != 0) {
281 SmallVector<Metadata *> MDVals;
282 MDVals.append(
283 getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
284 Properties = MDNode::get(Ctx, MDVals);
285 }
286 // Library has an entry metadata with resource table metadata and all other
287 // MDNodes as null.
288 return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
289 }
290
291 // TODO: We might need to refactor this to be more generic,
292 // in case we need more metadata to be replaced.
translateBranchMetadata(Module & M)293 static void translateBranchMetadata(Module &M) {
294 for (Function &F : M) {
295 for (BasicBlock &BB : F) {
296 Instruction *BBTerminatorInst = BB.getTerminator();
297
298 MDNode *HlslControlFlowMD =
299 BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
300
301 if (!HlslControlFlowMD)
302 continue;
303
304 assert(HlslControlFlowMD->getNumOperands() == 2 &&
305 "invalid operands for hlsl.controlflow.hint");
306
307 MDBuilder MDHelper(M.getContext());
308 ConstantInt *Op1 =
309 mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
310
311 SmallVector<llvm::Metadata *, 2> Vals(
312 ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
313 MDHelper.createConstant(Op1)});
314
315 MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
316
317 BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
318 BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
319 }
320 }
321 }
322
translateMetadata(Module & M,DXILResourceMap & DRM,DXILResourceTypeMap & DRTM,const ModuleShaderFlags & ShaderFlags,const ModuleMetadataInfo & MMDI)323 static void translateMetadata(Module &M, DXILResourceMap &DRM,
324 DXILResourceTypeMap &DRTM,
325 const ModuleShaderFlags &ShaderFlags,
326 const ModuleMetadataInfo &MMDI) {
327 LLVMContext &Ctx = M.getContext();
328 IRBuilder<> IRB(Ctx);
329 SmallVector<MDNode *> EntryFnMDNodes;
330
331 emitValidatorVersionMD(M, MMDI);
332 emitShaderModelVersionMD(M, MMDI);
333 emitDXILVersionTupleMD(M, MMDI);
334 NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, DRTM);
335 auto *ResourceMD =
336 (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
337 // FIXME: Add support to construct Signatures
338 // See https://github.com/llvm/llvm-project/issues/57928
339 MDTuple *Signatures = nullptr;
340
341 if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
342 // Get the combined shader flag mask of all functions in the library to be
343 // used as shader flags mask value associated with top-level library entry
344 // metadata.
345 uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
346 EntryFnMDNodes.emplace_back(
347 emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
348 } else if (MMDI.EntryPropertyVec.size() > 1) {
349 M.getContext().diagnose(DiagnosticInfoTranslateMD(
350 M, "Non-library shader: One and only one entry expected"));
351 }
352
353 for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
354 const ComputedShaderFlags &EntrySFMask =
355 ShaderFlags.getFunctionFlags(EntryProp.Entry);
356
357 // If ShaderProfile is Library, mask is already consolidated in the
358 // top-level library node. Hence it is not emitted.
359 uint64_t EntryShaderFlags = 0;
360 if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
361 EntryShaderFlags = EntrySFMask;
362 if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
363 M.getContext().diagnose(DiagnosticInfoTranslateMD(
364 M,
365 "Shader stage '" +
366 Twine(getShortShaderStage(EntryProp.ShaderStage) +
367 "' for entry '" + Twine(EntryProp.Entry->getName()) +
368 "' different from specified target profile '" +
369 Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
370 "'"))));
371 }
372 }
373 EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
374 EntryShaderFlags,
375 MMDI.ShaderProfile));
376 }
377
378 NamedMDNode *EntryPointsNamedMD =
379 M.getOrInsertNamedMetadata("dx.entryPoints");
380 for (auto *Entry : EntryFnMDNodes)
381 EntryPointsNamedMD->addOperand(Entry);
382 }
383
run(Module & M,ModuleAnalysisManager & MAM)384 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
385 ModuleAnalysisManager &MAM) {
386 DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
387 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
388 const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M);
389 const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
390
391 translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
392 translateBranchMetadata(M);
393
394 return PreservedAnalyses::all();
395 }
396
397 namespace {
398 class DXILTranslateMetadataLegacy : public ModulePass {
399 public:
400 static char ID; // Pass identification, replacement for typeid
DXILTranslateMetadataLegacy()401 explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
402
getPassName() const403 StringRef getPassName() const override { return "DXIL Translate Metadata"; }
404
getAnalysisUsage(AnalysisUsage & AU) const405 void getAnalysisUsage(AnalysisUsage &AU) const override {
406 AU.addRequired<DXILResourceTypeWrapperPass>();
407 AU.addRequired<DXILResourceWrapperPass>();
408 AU.addRequired<ShaderFlagsAnalysisWrapper>();
409 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
410 AU.addPreserved<DXILResourceWrapperPass>();
411 AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
412 AU.addPreserved<ShaderFlagsAnalysisWrapper>();
413 AU.addPreserved<DXILResourceBindingWrapperPass>();
414 }
415
runOnModule(Module & M)416 bool runOnModule(Module &M) override {
417 DXILResourceMap &DRM =
418 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
419 DXILResourceTypeMap &DRTM =
420 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
421 const ModuleShaderFlags &ShaderFlags =
422 getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
423 dxil::ModuleMetadataInfo MMDI =
424 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
425
426 translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
427 translateBranchMetadata(M);
428 return true;
429 }
430 };
431
432 } // namespace
433
434 char DXILTranslateMetadataLegacy::ID = 0;
435
createDXILTranslateMetadataLegacyPass()436 ModulePass *llvm::createDXILTranslateMetadataLegacyPass() {
437 return new DXILTranslateMetadataLegacy();
438 }
439
440 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
441 "DXIL Translate Metadata", false, false)
442 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
443 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
444 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
445 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
446 "DXIL Translate Metadata", false, false)
447