xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILMetadata.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- DXILMetadata.cpp - DXIL Metadata helper objects --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects for working with DXIL metadata.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILMetadata.h"
14 #include "llvm/IR/Constants.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/Metadata.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Support/VersionTuple.h"
19 #include "llvm/TargetParser/Triple.h"
20 
21 using namespace llvm;
22 using namespace llvm::dxil;
23 
24 ValidatorVersionMD::ValidatorVersionMD(Module &M)
25     : Entry(M.getOrInsertNamedMetadata("dx.valver")) {}
26 
27 void ValidatorVersionMD::update(VersionTuple ValidatorVer) {
28   auto &Ctx = Entry->getParent()->getContext();
29   IRBuilder<> B(Ctx);
30   Metadata *MDVals[2];
31   MDVals[0] = ConstantAsMetadata::get(B.getInt32(ValidatorVer.getMajor()));
32   MDVals[1] =
33       ConstantAsMetadata::get(B.getInt32(ValidatorVer.getMinor().value_or(0)));
34 
35   if (isEmpty())
36     Entry->addOperand(MDNode::get(Ctx, MDVals));
37   else
38     Entry->setOperand(0, MDNode::get(Ctx, MDVals));
39 }
40 
41 bool ValidatorVersionMD::isEmpty() { return Entry->getNumOperands() == 0; }
42 
43 VersionTuple ValidatorVersionMD::getAsVersionTuple() {
44   if (isEmpty())
45     return VersionTuple(1, 0);
46   auto *ValVerMD = cast<MDNode>(Entry->getOperand(0));
47   auto *MajorMD = mdconst::extract<ConstantInt>(ValVerMD->getOperand(0));
48   auto *MinorMD = mdconst::extract<ConstantInt>(ValVerMD->getOperand(1));
49   return VersionTuple(MajorMD->getZExtValue(), MinorMD->getZExtValue());
50 }
51 
52 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
53   switch (Env) {
54   case Triple::Pixel:
55     return "ps";
56   case Triple::Vertex:
57     return "vs";
58   case Triple::Geometry:
59     return "gs";
60   case Triple::Hull:
61     return "hs";
62   case Triple::Domain:
63     return "ds";
64   case Triple::Compute:
65     return "cs";
66   case Triple::Library:
67     return "lib";
68   case Triple::Mesh:
69     return "ms";
70   case Triple::Amplification:
71     return "as";
72   default:
73     break;
74   }
75   llvm_unreachable("Unsupported environment for DXIL generation.");
76   return "";
77 }
78 
79 void dxil::createShaderModelMD(Module &M) {
80   NamedMDNode *Entry = M.getOrInsertNamedMetadata("dx.shaderModel");
81   Triple TT(M.getTargetTriple());
82   VersionTuple Ver = TT.getOSVersion();
83   LLVMContext &Ctx = M.getContext();
84   IRBuilder<> B(Ctx);
85 
86   Metadata *Vals[3];
87   Vals[0] = MDString::get(Ctx, getShortShaderStage(TT.getEnvironment()));
88   Vals[1] = ConstantAsMetadata::get(B.getInt32(Ver.getMajor()));
89   Vals[2] = ConstantAsMetadata::get(B.getInt32(Ver.getMinor().value_or(0)));
90   Entry->addOperand(MDNode::get(Ctx, Vals));
91 }
92 
93 void dxil::createDXILVersionMD(Module &M) {
94   Triple TT(Triple::normalize(M.getTargetTriple()));
95   VersionTuple Ver = TT.getDXILVersion();
96   LLVMContext &Ctx = M.getContext();
97   IRBuilder<> B(Ctx);
98   NamedMDNode *Entry = M.getOrInsertNamedMetadata("dx.version");
99   Metadata *Vals[2];
100   Vals[0] = ConstantAsMetadata::get(B.getInt32(Ver.getMajor()));
101   Vals[1] = ConstantAsMetadata::get(B.getInt32(Ver.getMinor().value_or(0)));
102   Entry->addOperand(MDNode::get(Ctx, Vals));
103 }
104 
105 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
106   return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
107 }
108 
109 namespace {
110 
111 struct EntryProps {
112   Triple::EnvironmentType ShaderKind;
113   // FIXME: support more shader profiles.
114   // See https://github.com/llvm/llvm-project/issues/57927.
115   struct {
116     unsigned NumThreads[3];
117   } CS;
118 
119   EntryProps(Function &F, Triple::EnvironmentType ModuleShaderKind)
120       : ShaderKind(ModuleShaderKind) {
121 
122     if (ShaderKind == Triple::EnvironmentType::Library) {
123       Attribute EntryAttr = F.getFnAttribute("hlsl.shader");
124       StringRef EntryProfile = EntryAttr.getValueAsString();
125       Triple T("", "", "", EntryProfile);
126       ShaderKind = T.getEnvironment();
127     }
128 
129     if (ShaderKind == Triple::EnvironmentType::Compute) {
130       auto NumThreadsStr =
131           F.getFnAttribute("hlsl.numthreads").getValueAsString();
132       SmallVector<StringRef> NumThreads;
133       NumThreadsStr.split(NumThreads, ',');
134       assert(NumThreads.size() == 3 && "invalid numthreads");
135       auto Zip =
136           llvm::zip(NumThreads, MutableArrayRef<unsigned>(CS.NumThreads));
137       for (auto It : Zip) {
138         StringRef Str = std::get<0>(It);
139         APInt V;
140         [[maybe_unused]] bool Result = Str.getAsInteger(10, V);
141         assert(!Result && "Failed to parse numthreads");
142 
143         unsigned &Num = std::get<1>(It);
144         Num = V.getLimitedValue();
145       }
146     }
147   }
148 
149   MDTuple *emitDXILEntryProps(uint64_t RawShaderFlag, LLVMContext &Ctx,
150                               bool IsLib) {
151     std::vector<Metadata *> MDVals;
152 
153     if (RawShaderFlag != 0)
154       appendShaderFlags(MDVals, RawShaderFlag, Ctx);
155 
156     // Add shader kind for lib entrys.
157     if (IsLib && ShaderKind != Triple::EnvironmentType::Library)
158       appendShaderKind(MDVals, Ctx);
159 
160     if (ShaderKind == Triple::EnvironmentType::Compute)
161       appendNumThreads(MDVals, Ctx);
162     // FIXME: support more props.
163     // See https://github.com/llvm/llvm-project/issues/57948.
164     return MDNode::get(Ctx, MDVals);
165   }
166 
167   static MDTuple *emitEntryPropsForEmptyEntry(uint64_t RawShaderFlag,
168                                               LLVMContext &Ctx) {
169     if (RawShaderFlag == 0)
170       return nullptr;
171 
172     std::vector<Metadata *> MDVals;
173 
174     appendShaderFlags(MDVals, RawShaderFlag, Ctx);
175     // FIXME: support more props.
176     // See https://github.com/llvm/llvm-project/issues/57948.
177     return MDNode::get(Ctx, MDVals);
178   }
179 
180 private:
181   enum EntryPropsTag {
182     ShaderFlagsTag = 0,
183     GSStateTag,
184     DSStateTag,
185     HSStateTag,
186     NumThreadsTag,
187     AutoBindingSpaceTag,
188     RayPayloadSizeTag,
189     RayAttribSizeTag,
190     ShaderKindTag,
191     MSStateTag,
192     ASStateTag,
193     WaveSizeTag,
194     EntryRootSigTag,
195   };
196 
197   void appendNumThreads(std::vector<Metadata *> &MDVals, LLVMContext &Ctx) {
198     MDVals.emplace_back(ConstantAsMetadata::get(
199         ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
200 
201     std::vector<Metadata *> NumThreadVals;
202     for (auto Num : ArrayRef<unsigned>(CS.NumThreads))
203       NumThreadVals.emplace_back(ConstantAsMetadata::get(
204           ConstantInt::get(Type::getInt32Ty(Ctx), Num)));
205     MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
206   }
207 
208   static void appendShaderFlags(std::vector<Metadata *> &MDVals,
209                                 uint64_t RawShaderFlag, LLVMContext &Ctx) {
210     MDVals.emplace_back(ConstantAsMetadata::get(
211         ConstantInt::get(Type::getInt32Ty(Ctx), ShaderFlagsTag)));
212     MDVals.emplace_back(ConstantAsMetadata::get(
213         ConstantInt::get(Type::getInt64Ty(Ctx), RawShaderFlag)));
214   }
215 
216   void appendShaderKind(std::vector<Metadata *> &MDVals, LLVMContext &Ctx) {
217     MDVals.emplace_back(ConstantAsMetadata::get(
218         ConstantInt::get(Type::getInt32Ty(Ctx), ShaderKindTag)));
219     MDVals.emplace_back(ConstantAsMetadata::get(
220         ConstantInt::get(Type::getInt32Ty(Ctx), getShaderStage(ShaderKind))));
221   }
222 };
223 
224 class EntryMD {
225   Function &F;
226   LLVMContext &Ctx;
227   EntryProps Props;
228 
229 public:
230   EntryMD(Function &F, Triple::EnvironmentType ModuleShaderKind)
231       : F(F), Ctx(F.getContext()), Props(F, ModuleShaderKind) {}
232 
233   MDTuple *emitEntryTuple(MDTuple *Resources, uint64_t RawShaderFlag) {
234     // FIXME: add signature for profile other than CS.
235     // See https://github.com/llvm/llvm-project/issues/57928.
236     MDTuple *Signatures = nullptr;
237     return emitDXILEntryPointTuple(
238         &F, F.getName().str(), Signatures, Resources,
239         Props.emitDXILEntryProps(RawShaderFlag, Ctx, /*IsLib*/ false), Ctx);
240   }
241 
242   MDTuple *emitEntryTupleForLib(uint64_t RawShaderFlag) {
243     // FIXME: add signature for profile other than CS.
244     // See https://github.com/llvm/llvm-project/issues/57928.
245     MDTuple *Signatures = nullptr;
246     return emitDXILEntryPointTuple(
247         &F, F.getName().str(), Signatures,
248         /*entry in lib doesn't need resources metadata*/ nullptr,
249         Props.emitDXILEntryProps(RawShaderFlag, Ctx, /*IsLib*/ true), Ctx);
250   }
251 
252   // Library will have empty entry metadata which only store the resource table
253   // metadata.
254   static MDTuple *emitEmptyEntryForLib(MDTuple *Resources,
255                                        uint64_t RawShaderFlag,
256                                        LLVMContext &Ctx) {
257     return emitDXILEntryPointTuple(
258         nullptr, "", nullptr, Resources,
259         EntryProps::emitEntryPropsForEmptyEntry(RawShaderFlag, Ctx), Ctx);
260   }
261 
262 private:
263   static MDTuple *emitDXILEntryPointTuple(Function *Fn, const std::string &Name,
264                                           MDTuple *Signatures,
265                                           MDTuple *Resources,
266                                           MDTuple *Properties,
267                                           LLVMContext &Ctx) {
268     Metadata *MDVals[5];
269     MDVals[0] = Fn ? ValueAsMetadata::get(Fn) : nullptr;
270     MDVals[1] = MDString::get(Ctx, Name.c_str());
271     MDVals[2] = Signatures;
272     MDVals[3] = Resources;
273     MDVals[4] = Properties;
274     return MDNode::get(Ctx, MDVals);
275   }
276 };
277 } // namespace
278 
279 void dxil::createEntryMD(Module &M, const uint64_t ShaderFlags) {
280   SmallVector<Function *> EntryList;
281   for (auto &F : M.functions()) {
282     if (!F.hasFnAttribute("hlsl.shader"))
283       continue;
284     EntryList.emplace_back(&F);
285   }
286 
287   auto &Ctx = M.getContext();
288   // FIXME: generate metadata for resource.
289   // See https://github.com/llvm/llvm-project/issues/57926.
290   MDTuple *MDResources = nullptr;
291   if (auto *NamedResources = M.getNamedMetadata("dx.resources"))
292     MDResources = dyn_cast<MDTuple>(NamedResources->getOperand(0));
293 
294   std::vector<MDNode *> Entries;
295   Triple T = Triple(M.getTargetTriple());
296   switch (T.getEnvironment()) {
297   case Triple::EnvironmentType::Library: {
298     // Add empty entry to put resource metadata.
299     MDTuple *EmptyEntry =
300         EntryMD::emitEmptyEntryForLib(MDResources, ShaderFlags, Ctx);
301     Entries.emplace_back(EmptyEntry);
302 
303     for (Function *Entry : EntryList) {
304       EntryMD MD(*Entry, T.getEnvironment());
305       Entries.emplace_back(MD.emitEntryTupleForLib(0));
306     }
307   } break;
308   case Triple::EnvironmentType::Compute:
309   case Triple::EnvironmentType::Amplification:
310   case Triple::EnvironmentType::Mesh:
311   case Triple::EnvironmentType::Vertex:
312   case Triple::EnvironmentType::Hull:
313   case Triple::EnvironmentType::Domain:
314   case Triple::EnvironmentType::Geometry:
315   case Triple::EnvironmentType::Pixel: {
316     assert(EntryList.size() == 1 &&
317            "non-lib profiles should only have one entry");
318     EntryMD MD(*EntryList.front(), T.getEnvironment());
319     Entries.emplace_back(MD.emitEntryTuple(MDResources, ShaderFlags));
320   } break;
321   default:
322     assert(0 && "invalid profile");
323     break;
324   }
325 
326   NamedMDNode *EntryPointsNamedMD =
327       M.getOrInsertNamedMetadata("dx.entryPoints");
328   for (auto *Entry : Entries)
329     EntryPointsNamedMD->addOperand(Entry);
330 }
331