xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILResource.cpp (revision 19fae0f66023a97a9b464b3beeeabb2081f575b3)
1 //===- DXILResource.cpp - DXIL Resource helper objects --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects for working with DXIL Resources.
10 ///
11 //===----------------------------------------------------------------------===//
12 
#include "DXILResource.h"
#include "CBufferDataLayout.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
21 
22 using namespace llvm;
23 using namespace llvm::dxil;
24 using namespace llvm::hlsl;
25 
26 template <typename T> void ResourceTable<T>::collect(Module &M) {
27   NamedMDNode *Entry = M.getNamedMetadata(MDName);
28   if (!Entry || Entry->getNumOperands() == 0)
29     return;
30 
31   uint32_t Counter = 0;
32   for (auto *Res : Entry->operands()) {
33     Data.push_back(T(Counter++, FrontendResource(cast<MDNode>(Res))));
34   }
35 }
36 
37 template <> void ResourceTable<ConstantBuffer>::collect(Module &M) {
38   NamedMDNode *Entry = M.getNamedMetadata(MDName);
39   if (!Entry || Entry->getNumOperands() == 0)
40     return;
41 
42   uint32_t Counter = 0;
43   for (auto *Res : Entry->operands()) {
44     Data.push_back(
45         ConstantBuffer(Counter++, FrontendResource(cast<MDNode>(Res))));
46   }
47   // FIXME: share CBufferDataLayout with CBuffer load lowering.
48   //   See https://github.com/llvm/llvm-project/issues/58381
49   CBufferDataLayout CBDL(M.getDataLayout(), /*IsLegacy*/ true);
50   for (auto &CB : Data)
51     CB.setSize(CBDL);
52 }
53 
54 void Resources::collect(Module &M) {
55   UAVs.collect(M);
56   CBuffers.collect(M);
57 }
58 
59 ResourceBase::ResourceBase(uint32_t I, FrontendResource R)
60     : ID(I), GV(R.getGlobalVariable()), Name(""), Space(R.getSpace()),
61       LowerBound(R.getResourceIndex()), RangeSize(1) {
62   if (auto *ArrTy = dyn_cast<ArrayType>(GV->getValueType()))
63     RangeSize = ArrTy->getNumElements();
64 }
65 
66 StringRef ResourceBase::getComponentTypeName(ComponentType CompType) {
67   switch (CompType) {
68   case ComponentType::LastEntry:
69   case ComponentType::Invalid:
70     return "invalid";
71   case ComponentType::I1:
72     return "i1";
73   case ComponentType::I16:
74     return "i16";
75   case ComponentType::U16:
76     return "u16";
77   case ComponentType::I32:
78     return "i32";
79   case ComponentType::U32:
80     return "u32";
81   case ComponentType::I64:
82     return "i64";
83   case ComponentType::U64:
84     return "u64";
85   case ComponentType::F16:
86     return "f16";
87   case ComponentType::F32:
88     return "f32";
89   case ComponentType::F64:
90     return "f64";
91   case ComponentType::SNormF16:
92     return "snorm_f16";
93   case ComponentType::UNormF16:
94     return "unorm_f16";
95   case ComponentType::SNormF32:
96     return "snorm_f32";
97   case ComponentType::UNormF32:
98     return "unorm_f32";
99   case ComponentType::SNormF64:
100     return "snorm_f64";
101   case ComponentType::UNormF64:
102     return "unorm_f64";
103   case ComponentType::PackedS8x32:
104     return "p32i8";
105   case ComponentType::PackedU8x32:
106     return "p32u8";
107   }
108 }
109 
110 void ResourceBase::printComponentType(Kinds Kind, ComponentType CompType,
111                                       unsigned Alignment, raw_ostream &OS) {
112   switch (Kind) {
113   default:
114     // TODO: add vector size.
115     OS << right_justify(getComponentTypeName(CompType), Alignment);
116     break;
117   case Kinds::RawBuffer:
118     OS << right_justify("byte", Alignment);
119     break;
120   case Kinds::StructuredBuffer:
121     OS << right_justify("struct", Alignment);
122     break;
123   case Kinds::CBuffer:
124   case Kinds::Sampler:
125     OS << right_justify("NA", Alignment);
126     break;
127   case Kinds::Invalid:
128   case Kinds::NumEntries:
129     break;
130   }
131 }
132 
133 StringRef ResourceBase::getKindName(Kinds Kind) {
134   switch (Kind) {
135   case Kinds::NumEntries:
136   case Kinds::Invalid:
137     return "invalid";
138   case Kinds::Texture1D:
139     return "1d";
140   case Kinds::Texture2D:
141     return "2d";
142   case Kinds::Texture2DMS:
143     return "2dMS";
144   case Kinds::Texture3D:
145     return "3d";
146   case Kinds::TextureCube:
147     return "cube";
148   case Kinds::Texture1DArray:
149     return "1darray";
150   case Kinds::Texture2DArray:
151     return "2darray";
152   case Kinds::Texture2DMSArray:
153     return "2darrayMS";
154   case Kinds::TextureCubeArray:
155     return "cubearray";
156   case Kinds::TypedBuffer:
157     return "buf";
158   case Kinds::RawBuffer:
159     return "rawbuf";
160   case Kinds::StructuredBuffer:
161     return "structbuf";
162   case Kinds::CBuffer:
163     return "cbuffer";
164   case Kinds::Sampler:
165     return "sampler";
166   case Kinds::TBuffer:
167     return "tbuffer";
168   case Kinds::RTAccelerationStructure:
169     return "ras";
170   case Kinds::FeedbackTexture2D:
171     return "fbtex2d";
172   case Kinds::FeedbackTexture2DArray:
173     return "fbtex2darray";
174   }
175 }
176 
177 void ResourceBase::printKind(Kinds Kind, unsigned Alignment, raw_ostream &OS,
178                              bool SRV, bool HasCounter, uint32_t SampleCount) {
179   switch (Kind) {
180   default:
181     OS << right_justify(getKindName(Kind), Alignment);
182     break;
183 
184   case Kinds::RawBuffer:
185   case Kinds::StructuredBuffer:
186     if (SRV)
187       OS << right_justify("r/o", Alignment);
188     else {
189       if (!HasCounter)
190         OS << right_justify("r/w", Alignment);
191       else
192         OS << right_justify("r/w+cnt", Alignment);
193     }
194     break;
195   case Kinds::TypedBuffer:
196     OS << right_justify("buf", Alignment);
197     break;
198   case Kinds::Texture2DMS:
199   case Kinds::Texture2DMSArray: {
200     std::string DimName = getKindName(Kind).str();
201     if (SampleCount)
202       DimName += std::to_string(SampleCount);
203     OS << right_justify(DimName, Alignment);
204   } break;
205   case Kinds::CBuffer:
206   case Kinds::Sampler:
207     OS << right_justify("NA", Alignment);
208     break;
209   case Kinds::Invalid:
210   case Kinds::NumEntries:
211     break;
212   }
213 }
214 
215 void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix,
216                          StringRef BindingPrefix) const {
217   std::string ResID = IDPrefix.str();
218   ResID += std::to_string(ID);
219   OS << right_justify(ResID, 8);
220 
221   std::string Bind = BindingPrefix.str();
222   Bind += std::to_string(LowerBound);
223   if (Space)
224     Bind += ",space" + std::to_string(Space);
225 
226   OS << right_justify(Bind, 15);
227   if (RangeSize != UINT_MAX)
228     OS << right_justify(std::to_string(RangeSize), 6) << "\n";
229   else
230     OS << right_justify("unbounded", 6) << "\n";
231 }
232 
233 UAVResource::UAVResource(uint32_t I, FrontendResource R)
234     : ResourceBase(I, R),
235       Shape(static_cast<ResourceBase::Kinds>(R.getResourceKind())),
236       GloballyCoherent(false), HasCounter(false), IsROV(false), ExtProps() {
237   parseSourceType(R.getSourceType());
238 }
239 
240 void UAVResource::print(raw_ostream &OS) const {
241   OS << "; " << left_justify(Name, 31);
242 
243   OS << right_justify("UAV", 10);
244 
245   printComponentType(
246       Shape, ExtProps.ElementType.value_or(ComponentType::Invalid), 8, OS);
247 
248   // FIXME: support SampleCount.
249   // See https://github.com/llvm/llvm-project/issues/58175
250   printKind(Shape, 12, OS, /*SRV*/ false, HasCounter);
251   // Print the binding part.
252   ResourceBase::print(OS, "U", "u");
253 }
254 
255 // FIXME: Capture this in HLSL source. I would go do this right now, but I want
256 // to get this in first so that I can make sure to capture all the extra
257 // information we need to remove the source type string from here (See issue:
258 // https://github.com/llvm/llvm-project/issues/57991).
259 void UAVResource::parseSourceType(StringRef S) {
260   IsROV = S.startswith("RasterizerOrdered");
261   if (IsROV)
262     S = S.substr(strlen("RasterizerOrdered"));
263   if (S.startswith("RW"))
264     S = S.substr(strlen("RW"));
265 
266   // Note: I'm deliberately not handling any of the Texture buffer types at the
267   // moment. I want to resolve the issue above before adding Texture or Sampler
268   // support.
269   Shape = StringSwitch<ResourceBase::Kinds>(S)
270               .StartsWith("Buffer<", Kinds::TypedBuffer)
271               .StartsWith("ByteAddressBuffer<", Kinds::RawBuffer)
272               .StartsWith("StructuredBuffer<", Kinds::StructuredBuffer)
273               .Default(Kinds::Invalid);
274   assert(Shape != Kinds::Invalid && "Unsupported buffer type");
275 
276   S = S.substr(S.find("<") + 1);
277 
278   constexpr size_t PrefixLen = StringRef("vector<").size();
279   if (S.startswith("vector<"))
280     S = S.substr(PrefixLen, S.find(",") - PrefixLen);
281   else
282     S = S.substr(0, S.find(">"));
283 
284   ComponentType ElTy = StringSwitch<ResourceBase::ComponentType>(S)
285                            .Case("bool", ComponentType::I1)
286                            .Case("int16_t", ComponentType::I16)
287                            .Case("uint16_t", ComponentType::U16)
288                            .Case("int32_t", ComponentType::I32)
289                            .Case("uint32_t", ComponentType::U32)
290                            .Case("int64_t", ComponentType::I64)
291                            .Case("uint64_t", ComponentType::U64)
292                            .Case("half", ComponentType::F16)
293                            .Case("float", ComponentType::F32)
294                            .Case("double", ComponentType::F64)
295                            .Default(ComponentType::Invalid);
296   if (ElTy != ComponentType::Invalid)
297     ExtProps.ElementType = ElTy;
298 }
299 
300 ConstantBuffer::ConstantBuffer(uint32_t I, hlsl::FrontendResource R)
301     : ResourceBase(I, R) {}
302 
303 void ConstantBuffer::setSize(CBufferDataLayout &DL) {
304   CBufferSizeInBytes = DL.getTypeAllocSizeInBytes(GV->getValueType());
305 }
306 
307 void ConstantBuffer::print(raw_ostream &OS) const {
308   OS << "; " << left_justify(Name, 31);
309 
310   OS << right_justify("cbuffer", 10);
311 
312   printComponentType(Kinds::CBuffer, ComponentType::Invalid, 8, OS);
313 
314   printKind(Kinds::CBuffer, 12, OS, /*SRV*/ false, /*HasCounter*/ false);
315   // Print the binding part.
316   ResourceBase::print(OS, "CB", "cb");
317 }
318 
319 template <typename T> void ResourceTable<T>::print(raw_ostream &OS) const {
320   for (auto &Res : Data)
321     Res.print(OS);
322 }
323 
324 MDNode *ResourceBase::ExtendedProperties::write(LLVMContext &Ctx) const {
325   IRBuilder<> B(Ctx);
326   SmallVector<Metadata *> Entries;
327   if (ElementType) {
328     Entries.emplace_back(
329         ConstantAsMetadata::get(B.getInt32(TypedBufferElementType)));
330     Entries.emplace_back(ConstantAsMetadata::get(
331         B.getInt32(static_cast<uint32_t>(*ElementType))));
332   }
333   if (Entries.empty())
334     return nullptr;
335   return MDNode::get(Ctx, Entries);
336 }
337 
338 void ResourceBase::write(LLVMContext &Ctx,
339                          MutableArrayRef<Metadata *> Entries) const {
340   IRBuilder<> B(Ctx);
341   Entries[0] = ConstantAsMetadata::get(B.getInt32(ID));
342   Entries[1] = ConstantAsMetadata::get(GV);
343   Entries[2] = MDString::get(Ctx, Name);
344   Entries[3] = ConstantAsMetadata::get(B.getInt32(Space));
345   Entries[4] = ConstantAsMetadata::get(B.getInt32(LowerBound));
346   Entries[5] = ConstantAsMetadata::get(B.getInt32(RangeSize));
347 }
348 
349 MDNode *UAVResource::write() const {
350   auto &Ctx = GV->getContext();
351   IRBuilder<> B(Ctx);
352   Metadata *Entries[11];
353   ResourceBase::write(Ctx, Entries);
354   Entries[6] =
355       ConstantAsMetadata::get(B.getInt32(static_cast<uint32_t>(Shape)));
356   Entries[7] = ConstantAsMetadata::get(B.getInt1(GloballyCoherent));
357   Entries[8] = ConstantAsMetadata::get(B.getInt1(HasCounter));
358   Entries[9] = ConstantAsMetadata::get(B.getInt1(IsROV));
359   Entries[10] = ExtProps.write(Ctx);
360   return MDNode::get(Ctx, Entries);
361 }
362 
363 MDNode *ConstantBuffer::write() const {
364   auto &Ctx = GV->getContext();
365   IRBuilder<> B(Ctx);
366   Metadata *Entries[7];
367   ResourceBase::write(Ctx, Entries);
368 
369   Entries[6] = ConstantAsMetadata::get(B.getInt32(CBufferSizeInBytes));
370   return MDNode::get(Ctx, Entries);
371 }
372 
373 template <typename T> MDNode *ResourceTable<T>::write(Module &M) const {
374   if (Data.empty())
375     return nullptr;
376   SmallVector<Metadata *> MDs;
377   for (auto &Res : Data)
378     MDs.emplace_back(Res.write());
379 
380   NamedMDNode *Entry = M.getNamedMetadata(MDName);
381   if (Entry)
382     Entry->eraseFromParent();
383 
384   return MDNode::get(M.getContext(), MDs);
385 }
386 
387 void Resources::write(Module &M) const {
388   Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr};
389 
390   ResourceMDs[1] = UAVs.write(M);
391 
392   ResourceMDs[2] = CBuffers.write(M);
393 
394   bool HasResource = ResourceMDs[0] != nullptr || ResourceMDs[1] != nullptr ||
395                      ResourceMDs[2] != nullptr || ResourceMDs[3] != nullptr;
396 
397   if (HasResource) {
398     NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources");
399     DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs));
400   }
401 
402   NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs");
403   if (Entry)
404     Entry->eraseFromParent();
405 }
406 
407 void Resources::print(raw_ostream &O) const {
408   O << ";\n"
409     << "; Resource Bindings:\n"
410     << ";\n"
411     << "; Name                                 Type  Format         Dim      "
412        "ID      HLSL Bind  Count\n"
413     << "; ------------------------------ ---------- ------- ----------- "
414        "------- -------------- ------\n";
415 
416   CBuffers.print(O);
417   UAVs.print(O);
418 }
419 
420 void Resources::dump() const { print(dbgs()); }
421