xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILTranslateMetadata.h"
10 #include "DXILShaderFlags.h"
11 #include "DirectX.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Analysis/DXILMetadataAnalysis.h"
15 #include "llvm/Analysis/DXILResource.h"
16 #include "llvm/IR/BasicBlock.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/DiagnosticInfo.h"
19 #include "llvm/IR/DiagnosticPrinter.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/MDBuilder.h"
24 #include "llvm/IR/Metadata.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/VersionTuple.h"
30 #include "llvm/TargetParser/Triple.h"
31 #include <cstdint>
32 
33 using namespace llvm;
34 using namespace llvm::dxil;
35 
36 namespace {
37 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
38 /// for TranslateMetadata pass
39 class DiagnosticInfoTranslateMD : public DiagnosticInfo {
40 private:
41   const Twine &Msg;
42   const Module &Mod;
43 
44 public:
45   /// \p M is the module for which the diagnostic is being emitted. \p Msg is
46   /// the message to show. Note that this class does not copy this message, so
47   /// this reference must be valid for the whole life time of the diagnostic.
48   DiagnosticInfoTranslateMD(const Module &M,
49                             const Twine &Msg LLVM_LIFETIME_BOUND,
50                             DiagnosticSeverity Severity = DS_Error)
51       : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
52 
53   void print(DiagnosticPrinter &DP) const override {
54     DP << Mod.getName() << ": " << Msg << '\n';
55   }
56 };
57 
58 enum class EntryPropsTag {
59   ShaderFlags = 0,
60   GSState,
61   DSState,
62   HSState,
63   NumThreads,
64   AutoBindingSpace,
65   RayPayloadSize,
66   RayAttribSize,
67   ShaderKind,
68   MSState,
69   ASStateTag,
70   WaveSize,
71   EntryRootSig,
72 };
73 
74 } // namespace
75 
76 static NamedMDNode *emitResourceMetadata(Module &M, DXILResourceMap &DRM,
77                                          DXILResourceTypeMap &DRTM) {
78   LLVMContext &Context = M.getContext();
79 
80   for (ResourceInfo &RI : DRM)
81     if (!RI.hasSymbol())
82       RI.createSymbol(M,
83                       DRTM[RI.getHandleTy()].createElementStruct(RI.getName()));
84 
85   SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
86   for (const ResourceInfo &RI : DRM.srvs())
87     SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
88   for (const ResourceInfo &RI : DRM.uavs())
89     UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
90   for (const ResourceInfo &RI : DRM.cbuffers())
91     CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
92   for (const ResourceInfo &RI : DRM.samplers())
93     Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
94 
95   Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs);
96   Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs);
97   Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs);
98   Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps);
99 
100   if (DRM.empty())
101     return nullptr;
102 
103   NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
104   ResourceMD->addOperand(
105       MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));
106 
107   return ResourceMD;
108 }
109 
110 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
111   switch (Env) {
112   case Triple::Pixel:
113     return "ps";
114   case Triple::Vertex:
115     return "vs";
116   case Triple::Geometry:
117     return "gs";
118   case Triple::Hull:
119     return "hs";
120   case Triple::Domain:
121     return "ds";
122   case Triple::Compute:
123     return "cs";
124   case Triple::Library:
125     return "lib";
126   case Triple::Mesh:
127     return "ms";
128   case Triple::Amplification:
129     return "as";
130   default:
131     break;
132   }
133   llvm_unreachable("Unsupported environment for DXIL generation.");
134 }
135 
136 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
137   return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
138 }
139 
140 static SmallVector<Metadata *>
141 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
142   SmallVector<Metadata *> MDVals;
143   MDVals.emplace_back(ConstantAsMetadata::get(
144       ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
145   switch (Tag) {
146   case EntryPropsTag::ShaderFlags:
147     MDVals.emplace_back(ConstantAsMetadata::get(
148         ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
149     break;
150   case EntryPropsTag::ShaderKind:
151     MDVals.emplace_back(ConstantAsMetadata::get(
152         ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
153     break;
154   case EntryPropsTag::GSState:
155   case EntryPropsTag::DSState:
156   case EntryPropsTag::HSState:
157   case EntryPropsTag::NumThreads:
158   case EntryPropsTag::AutoBindingSpace:
159   case EntryPropsTag::RayPayloadSize:
160   case EntryPropsTag::RayAttribSize:
161   case EntryPropsTag::MSState:
162   case EntryPropsTag::ASStateTag:
163   case EntryPropsTag::WaveSize:
164   case EntryPropsTag::EntryRootSig:
165     llvm_unreachable("NYI: Unhandled entry property tag");
166   }
167   return MDVals;
168 }
169 
170 static MDTuple *
171 getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
172                        const Triple::EnvironmentType ShaderProfile) {
173   SmallVector<Metadata *> MDVals;
174   LLVMContext &Ctx = EP.Entry->getContext();
175   if (EntryShaderFlags != 0)
176     MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
177                                         EntryShaderFlags, Ctx));
178 
179   if (EP.Entry != nullptr) {
180     // FIXME: support more props.
181     // See https://github.com/llvm/llvm-project/issues/57948.
182     // Add shader kind for lib entries.
183     if (ShaderProfile == Triple::EnvironmentType::Library &&
184         EP.ShaderStage != Triple::EnvironmentType::Library)
185       MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
186                                           getShaderStage(EP.ShaderStage), Ctx));
187 
188     if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
189       MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
190           Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
191       Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
192                                        Type::getInt32Ty(Ctx), EP.NumThreadsX)),
193                                    ConstantAsMetadata::get(ConstantInt::get(
194                                        Type::getInt32Ty(Ctx), EP.NumThreadsY)),
195                                    ConstantAsMetadata::get(ConstantInt::get(
196                                        Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
197       MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
198     }
199   }
200   if (MDVals.empty())
201     return nullptr;
202   return MDNode::get(Ctx, MDVals);
203 }
204 
205 MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
206                                 MDNode *Resources, MDTuple *Properties,
207                                 LLVMContext &Ctx) {
208   // Each entry point metadata record specifies:
209   //  * reference to the entry point function global symbol
210   //  * unmangled name
211   //  * list of signatures
212   //  * list of resources
213   //  * list of tag-value pairs of shader capabilities and other properties
214   Metadata *MDVals[5];
215   MDVals[0] =
216       EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr;
217   MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : "");
218   MDVals[2] = Signatures;
219   MDVals[3] = Resources;
220   MDVals[4] = Properties;
221   return MDNode::get(Ctx, MDVals);
222 }
223 
224 static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
225                             MDNode *MDResources,
226                             const uint64_t EntryShaderFlags,
227                             const Triple::EnvironmentType ShaderProfile) {
228   MDTuple *Properties =
229       getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
230   return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
231                                 EP.Entry->getContext());
232 }
233 
234 static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
235   if (MMDI.ValidatorVersion.empty())
236     return;
237 
238   LLVMContext &Ctx = M.getContext();
239   IRBuilder<> IRB(Ctx);
240   Metadata *MDVals[2];
241   MDVals[0] =
242       ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor()));
243   MDVals[1] = ConstantAsMetadata::get(
244       IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0)));
245   NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
246   // Set validator version obtained from DXIL Metadata Analysis pass
247   ValVerNode->clearOperands();
248   ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
249 }
250 
251 static void emitShaderModelVersionMD(Module &M,
252                                      const ModuleMetadataInfo &MMDI) {
253   LLVMContext &Ctx = M.getContext();
254   IRBuilder<> IRB(Ctx);
255   Metadata *SMVals[3];
256   VersionTuple SM = MMDI.ShaderModelVersion;
257   SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile));
258   SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
259   SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
260   NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
261   SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
262 }
263 
264 static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
265   LLVMContext &Ctx = M.getContext();
266   IRBuilder<> IRB(Ctx);
267   VersionTuple DXILVer = MMDI.DXILVersion;
268   Metadata *DXILVals[2];
269   DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
270   DXILVals[1] =
271       ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
272   NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
273   DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
274 }
275 
276 static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
277                                         uint64_t ShaderFlags) {
278   LLVMContext &Ctx = M.getContext();
279   MDTuple *Properties = nullptr;
280   if (ShaderFlags != 0) {
281     SmallVector<Metadata *> MDVals;
282     MDVals.append(
283         getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
284     Properties = MDNode::get(Ctx, MDVals);
285   }
286   // Library has an entry metadata with resource table metadata and all other
287   // MDNodes as null.
288   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
289 }
290 
291 // TODO: We might need to refactor this to be more generic,
292 // in case we need more metadata to be replaced.
293 static void translateBranchMetadata(Module &M) {
294   for (Function &F : M) {
295     for (BasicBlock &BB : F) {
296       Instruction *BBTerminatorInst = BB.getTerminator();
297 
298       MDNode *HlslControlFlowMD =
299           BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
300 
301       if (!HlslControlFlowMD)
302         continue;
303 
304       assert(HlslControlFlowMD->getNumOperands() == 2 &&
305              "invalid operands for hlsl.controlflow.hint");
306 
307       MDBuilder MDHelper(M.getContext());
308       ConstantInt *Op1 =
309           mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
310 
311       SmallVector<llvm::Metadata *, 2> Vals(
312           ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
313                                MDHelper.createConstant(Op1)});
314 
315       MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
316 
317       BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
318       BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
319     }
320   }
321 }
322 
323 static void translateMetadata(Module &M, DXILResourceMap &DRM,
324                               DXILResourceTypeMap &DRTM,
325                               const ModuleShaderFlags &ShaderFlags,
326                               const ModuleMetadataInfo &MMDI) {
327   LLVMContext &Ctx = M.getContext();
328   IRBuilder<> IRB(Ctx);
329   SmallVector<MDNode *> EntryFnMDNodes;
330 
331   emitValidatorVersionMD(M, MMDI);
332   emitShaderModelVersionMD(M, MMDI);
333   emitDXILVersionTupleMD(M, MMDI);
334   NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, DRTM);
335   auto *ResourceMD =
336       (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
337   // FIXME: Add support to construct Signatures
338   // See https://github.com/llvm/llvm-project/issues/57928
339   MDTuple *Signatures = nullptr;
340 
341   if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
342     // Get the combined shader flag mask of all functions in the library to be
343     // used as shader flags mask value associated with top-level library entry
344     // metadata.
345     uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
346     EntryFnMDNodes.emplace_back(
347         emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
348   } else if (MMDI.EntryPropertyVec.size() > 1) {
349     M.getContext().diagnose(DiagnosticInfoTranslateMD(
350         M, "Non-library shader: One and only one entry expected"));
351   }
352 
353   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
354     const ComputedShaderFlags &EntrySFMask =
355         ShaderFlags.getFunctionFlags(EntryProp.Entry);
356 
357     // If ShaderProfile is Library, mask is already consolidated in the
358     // top-level library node. Hence it is not emitted.
359     uint64_t EntryShaderFlags = 0;
360     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
361       EntryShaderFlags = EntrySFMask;
362       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
363         M.getContext().diagnose(DiagnosticInfoTranslateMD(
364             M,
365             "Shader stage '" +
366                 Twine(getShortShaderStage(EntryProp.ShaderStage) +
367                       "' for entry '" + Twine(EntryProp.Entry->getName()) +
368                       "' different from specified target profile '" +
369                       Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
370                             "'"))));
371       }
372     }
373     EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
374                                             EntryShaderFlags,
375                                             MMDI.ShaderProfile));
376   }
377 
378   NamedMDNode *EntryPointsNamedMD =
379       M.getOrInsertNamedMetadata("dx.entryPoints");
380   for (auto *Entry : EntryFnMDNodes)
381     EntryPointsNamedMD->addOperand(Entry);
382 }
383 
384 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
385                                              ModuleAnalysisManager &MAM) {
386   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
387   DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
388   const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M);
389   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
390 
391   translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
392   translateBranchMetadata(M);
393 
394   return PreservedAnalyses::all();
395 }
396 
397 namespace {
398 class DXILTranslateMetadataLegacy : public ModulePass {
399 public:
400   static char ID; // Pass identification, replacement for typeid
401   explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
402 
403   StringRef getPassName() const override { return "DXIL Translate Metadata"; }
404 
405   void getAnalysisUsage(AnalysisUsage &AU) const override {
406     AU.addRequired<DXILResourceTypeWrapperPass>();
407     AU.addRequired<DXILResourceWrapperPass>();
408     AU.addRequired<ShaderFlagsAnalysisWrapper>();
409     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
410     AU.addPreserved<DXILResourceWrapperPass>();
411     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
412     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
413     AU.addPreserved<DXILResourceBindingWrapperPass>();
414   }
415 
416   bool runOnModule(Module &M) override {
417     DXILResourceMap &DRM =
418         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
419     DXILResourceTypeMap &DRTM =
420         getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
421     const ModuleShaderFlags &ShaderFlags =
422         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
423     dxil::ModuleMetadataInfo MMDI =
424         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
425 
426     translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
427     translateBranchMetadata(M);
428     return true;
429   }
430 };
431 
432 } // namespace
433 
434 char DXILTranslateMetadataLegacy::ID = 0;
435 
436 ModulePass *llvm::createDXILTranslateMetadataLegacyPass() {
437   return new DXILTranslateMetadataLegacy();
438 }
439 
440 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
441                       "DXIL Translate Metadata", false, false)
442 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
443 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
444 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
445 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
446                     "DXIL Translate Metadata", false, false)
447