xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Analysis/DXILResource.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXILResource.h - Representations of DXIL resources -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ANALYSIS_DXILRESOURCE_H
10 #define LLVM_ANALYSIS_DXILRESOURCE_H
11 
12 #include "llvm/ADT/MapVector.h"
13 #include "llvm/ADT/SmallVector.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/IR/DerivedTypes.h"
16 #include "llvm/IR/GlobalVariable.h"
17 #include "llvm/IR/PassManager.h"
18 #include "llvm/Pass.h"
19 #include "llvm/Support/Alignment.h"
20 #include "llvm/Support/Compiler.h"
21 #include "llvm/Support/DXILABI.h"
22 #include <cstdint>
23 
24 namespace llvm {
// Forward declarations so this header does not need the full definitions.
class DXILResourceTypeMap;

class CallInst;
class DataLayout;
class LLVMContext;
class MDTuple;
class Value;
32 
33 namespace dxil {
34 
35 // Returns the resource name from dx_resource_handlefrombinding or
36 // dx_resource_handlefromimplicitbinding call
37 LLVM_ABI StringRef getResourceNameFromBindingCall(CallInst *CI);
38 
39 /// The dx.RawBuffer target extension type
40 ///
41 /// `target("dx.RawBuffer", Type, IsWriteable, IsROV)`
42 class RawBufferExtType : public TargetExtType {
43 public:
44   RawBufferExtType() = delete;
45   RawBufferExtType(const RawBufferExtType &) = delete;
46   RawBufferExtType &operator=(const RawBufferExtType &) = delete;
47 
isStructured()48   bool isStructured() const {
49     // TODO: We need to be more prescriptive here, but since there's some debate
50     // over whether byte address buffer should have a void type or an i8 type,
51     // accept either for now.
52     Type *Ty = getTypeParameter(0);
53     return !Ty->isVoidTy() && !Ty->isIntegerTy(8);
54   }
55 
getResourceType()56   Type *getResourceType() const {
57     return isStructured() ? getTypeParameter(0) : nullptr;
58   }
isWriteable()59   bool isWriteable() const { return getIntParameter(0); }
isROV()60   bool isROV() const { return getIntParameter(1); }
61 
classof(const TargetExtType * T)62   static bool classof(const TargetExtType *T) {
63     return T->getName() == "dx.RawBuffer";
64   }
classof(const Type * T)65   static bool classof(const Type *T) {
66     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
67   }
68 };
69 
70 /// The dx.TypedBuffer target extension type
71 ///
72 /// `target("dx.TypedBuffer", Type, IsWriteable, IsROV, IsSigned)`
73 class TypedBufferExtType : public TargetExtType {
74 public:
75   TypedBufferExtType() = delete;
76   TypedBufferExtType(const TypedBufferExtType &) = delete;
77   TypedBufferExtType &operator=(const TypedBufferExtType &) = delete;
78 
getResourceType()79   Type *getResourceType() const { return getTypeParameter(0); }
isWriteable()80   bool isWriteable() const { return getIntParameter(0); }
isROV()81   bool isROV() const { return getIntParameter(1); }
isSigned()82   bool isSigned() const { return getIntParameter(2); }
83 
classof(const TargetExtType * T)84   static bool classof(const TargetExtType *T) {
85     return T->getName() == "dx.TypedBuffer";
86   }
classof(const Type * T)87   static bool classof(const Type *T) {
88     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
89   }
90 };
91 
92 /// The dx.Texture target extension type
93 ///
94 /// `target("dx.Texture", Type, IsWriteable, IsROV, IsSigned, Dimension)`
95 class TextureExtType : public TargetExtType {
96 public:
97   TextureExtType() = delete;
98   TextureExtType(const TextureExtType &) = delete;
99   TextureExtType &operator=(const TextureExtType &) = delete;
100 
getResourceType()101   Type *getResourceType() const { return getTypeParameter(0); }
isWriteable()102   bool isWriteable() const { return getIntParameter(0); }
isROV()103   bool isROV() const { return getIntParameter(1); }
isSigned()104   bool isSigned() const { return getIntParameter(2); }
getDimension()105   dxil::ResourceKind getDimension() const {
106     return static_cast<dxil::ResourceKind>(getIntParameter(3));
107   }
108 
classof(const TargetExtType * T)109   static bool classof(const TargetExtType *T) {
110     return T->getName() == "dx.Texture";
111   }
classof(const Type * T)112   static bool classof(const Type *T) {
113     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
114   }
115 };
116 
117 /// The dx.MSTexture target extension type
118 ///
119 /// `target("dx.MSTexture", Type, IsWriteable, Samples, IsSigned, Dimension)`
120 class MSTextureExtType : public TargetExtType {
121 public:
122   MSTextureExtType() = delete;
123   MSTextureExtType(const MSTextureExtType &) = delete;
124   MSTextureExtType &operator=(const MSTextureExtType &) = delete;
125 
getResourceType()126   Type *getResourceType() const { return getTypeParameter(0); }
isWriteable()127   bool isWriteable() const { return getIntParameter(0); }
getSampleCount()128   uint32_t getSampleCount() const { return getIntParameter(1); }
isSigned()129   bool isSigned() const { return getIntParameter(2); }
getDimension()130   dxil::ResourceKind getDimension() const {
131     return static_cast<dxil::ResourceKind>(getIntParameter(3));
132   }
133 
classof(const TargetExtType * T)134   static bool classof(const TargetExtType *T) {
135     return T->getName() == "dx.MSTexture";
136   }
classof(const Type * T)137   static bool classof(const Type *T) {
138     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
139   }
140 };
141 
142 /// The dx.FeedbackTexture target extension type
143 ///
144 /// `target("dx.FeedbackTexture", FeedbackType, Dimension)`
145 class FeedbackTextureExtType : public TargetExtType {
146 public:
147   FeedbackTextureExtType() = delete;
148   FeedbackTextureExtType(const FeedbackTextureExtType &) = delete;
149   FeedbackTextureExtType &operator=(const FeedbackTextureExtType &) = delete;
150 
getFeedbackType()151   dxil::SamplerFeedbackType getFeedbackType() const {
152     return static_cast<dxil::SamplerFeedbackType>(getIntParameter(0));
153   }
getDimension()154   dxil::ResourceKind getDimension() const {
155     return static_cast<dxil::ResourceKind>(getIntParameter(1));
156   }
157 
classof(const TargetExtType * T)158   static bool classof(const TargetExtType *T) {
159     return T->getName() == "dx.FeedbackTexture";
160   }
classof(const Type * T)161   static bool classof(const Type *T) {
162     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
163   }
164 };
165 
166 /// The dx.CBuffer target extension type
167 ///
168 /// `target("dx.CBuffer", <Type>, ...)`
169 class CBufferExtType : public TargetExtType {
170 public:
171   CBufferExtType() = delete;
172   CBufferExtType(const CBufferExtType &) = delete;
173   CBufferExtType &operator=(const CBufferExtType &) = delete;
174 
getResourceType()175   Type *getResourceType() const { return getTypeParameter(0); }
176 
classof(const TargetExtType * T)177   static bool classof(const TargetExtType *T) {
178     return T->getName() == "dx.CBuffer";
179   }
classof(const Type * T)180   static bool classof(const Type *T) {
181     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
182   }
183 };
184 
185 /// The dx.Sampler target extension type
186 ///
187 /// `target("dx.Sampler", SamplerType)`
188 class SamplerExtType : public TargetExtType {
189 public:
190   SamplerExtType() = delete;
191   SamplerExtType(const SamplerExtType &) = delete;
192   SamplerExtType &operator=(const SamplerExtType &) = delete;
193 
getSamplerType()194   dxil::SamplerType getSamplerType() const {
195     return static_cast<dxil::SamplerType>(getIntParameter(0));
196   }
197 
classof(const TargetExtType * T)198   static bool classof(const TargetExtType *T) {
199     return T->getName() == "dx.Sampler";
200   }
classof(const Type * T)201   static bool classof(const Type *T) {
202     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
203   }
204 };
205 
206 class AnyResourceExtType : public TargetExtType {
207 public:
208   AnyResourceExtType() = delete;
209   AnyResourceExtType(const AnyResourceExtType &) = delete;
210   AnyResourceExtType &operator=(const AnyResourceExtType &) = delete;
211 
classof(const TargetExtType * T)212   static bool classof(const TargetExtType *T) {
213     return isa<RawBufferExtType>(T) || isa<TypedBufferExtType>(T) ||
214            isa<TextureExtType>(T) || isa<MSTextureExtType>(T) ||
215            isa<FeedbackTextureExtType>(T) || isa<CBufferExtType>(T) ||
216            isa<SamplerExtType>(T);
217   }
218 
classof(const Type * T)219   static bool classof(const Type *T) {
220     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
221   }
222 };
223 
224 /// The dx.Layout target extension type
225 ///
226 /// `target("dx.Layout", <Type>, <size>, [offsets...])`
227 class LayoutExtType : public TargetExtType {
228 public:
229   LayoutExtType() = delete;
230   LayoutExtType(const LayoutExtType &) = delete;
231   LayoutExtType &operator=(const LayoutExtType &) = delete;
232 
getWrappedType()233   Type *getWrappedType() const { return getTypeParameter(0); }
getSize()234   uint32_t getSize() const { return getIntParameter(0); }
getOffsetOfElement(int I)235   uint32_t getOffsetOfElement(int I) const { return getIntParameter(I + 1); }
236 
classof(const TargetExtType * T)237   static bool classof(const TargetExtType *T) {
238     return T->getName() == "dx.Layout";
239   }
classof(const Type * T)240   static bool classof(const Type *T) {
241     return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
242   }
243 };
244 
245 //===----------------------------------------------------------------------===//
246 
247 class ResourceTypeInfo {
248 public:
249   struct UAVInfo {
250     bool IsROV;
251 
252     bool operator==(const UAVInfo &RHS) const { return IsROV == RHS.IsROV; }
253     bool operator!=(const UAVInfo &RHS) const { return !(*this == RHS); }
254     bool operator<(const UAVInfo &RHS) const { return IsROV < RHS.IsROV; }
255   };
256 
257   struct StructInfo {
258     uint32_t Stride;
259     // Note: we store an integer here rather than using `MaybeAlign` because in
260     // GCC 7 MaybeAlign isn't trivial so having one in this union would delete
261     // our move constructor.
262     // See https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/p0602r4.html
263     uint32_t AlignLog2;
264 
265     bool operator==(const StructInfo &RHS) const {
266       return std::tie(Stride, AlignLog2) == std::tie(RHS.Stride, RHS.AlignLog2);
267     }
268     bool operator!=(const StructInfo &RHS) const { return !(*this == RHS); }
269     bool operator<(const StructInfo &RHS) const {
270       return std::tie(Stride, AlignLog2) < std::tie(RHS.Stride, RHS.AlignLog2);
271     }
272   };
273 
274   struct TypedInfo {
275     dxil::ElementType ElementTy;
276     uint32_t ElementCount;
277 
278     bool operator==(const TypedInfo &RHS) const {
279       return std::tie(ElementTy, ElementCount) ==
280              std::tie(RHS.ElementTy, RHS.ElementCount);
281     }
282     bool operator!=(const TypedInfo &RHS) const { return !(*this == RHS); }
283     bool operator<(const TypedInfo &RHS) const {
284       return std::tie(ElementTy, ElementCount) <
285              std::tie(RHS.ElementTy, RHS.ElementCount);
286     }
287   };
288 
289 private:
290   TargetExtType *HandleTy;
291 
292   dxil::ResourceClass RC;
293   dxil::ResourceKind Kind;
294 
295 public:
296   LLVM_ABI ResourceTypeInfo(TargetExtType *HandleTy,
297                             const dxil::ResourceClass RC,
298                             const dxil::ResourceKind Kind);
ResourceTypeInfo(TargetExtType * HandleTy)299   ResourceTypeInfo(TargetExtType *HandleTy)
300       : ResourceTypeInfo(HandleTy, {}, dxil::ResourceKind::Invalid) {}
301 
getHandleTy()302   TargetExtType *getHandleTy() const { return HandleTy; }
303   LLVM_ABI StructType *createElementStruct(StringRef CBufferName = "");
304 
305   // Conditions to check before accessing specific views.
306   LLVM_ABI bool isUAV() const;
307   LLVM_ABI bool isCBuffer() const;
308   LLVM_ABI bool isSampler() const;
309   LLVM_ABI bool isStruct() const;
310   LLVM_ABI bool isTyped() const;
311   LLVM_ABI bool isFeedback() const;
312   LLVM_ABI bool isMultiSample() const;
313 
314   // Views into the type.
315   LLVM_ABI UAVInfo getUAV() const;
316   LLVM_ABI uint32_t getCBufferSize(const DataLayout &DL) const;
317   LLVM_ABI dxil::SamplerType getSamplerType() const;
318   LLVM_ABI StructInfo getStruct(const DataLayout &DL) const;
319   LLVM_ABI TypedInfo getTyped() const;
320   LLVM_ABI dxil::SamplerFeedbackType getFeedbackType() const;
321   LLVM_ABI uint32_t getMultiSampleCount() const;
322 
getResourceClass()323   dxil::ResourceClass getResourceClass() const { return RC; }
getResourceKind()324   dxil::ResourceKind getResourceKind() const { return Kind; }
325 
326   LLVM_ABI bool operator==(const ResourceTypeInfo &RHS) const;
327   bool operator!=(const ResourceTypeInfo &RHS) const { return !(*this == RHS); }
328   LLVM_ABI bool operator<(const ResourceTypeInfo &RHS) const;
329 
330   LLVM_ABI void print(raw_ostream &OS, const DataLayout &DL) const;
331 };
332 
333 //===----------------------------------------------------------------------===//
334 
/// Direction in which a resource's counter is used.
///
/// `Unknown` is the default for ResourceInfo::CounterDirection, and
/// ResourceInfo::hasCounter() treats it as "no counter". NOTE(review):
/// `Invalid` presumably marks conflicting counter usage; confirm against
/// DXILResourceMap::populateCounterDirections.
enum class ResourceCounterDirection : int {
  Increment = 0,
  Decrement = 1,
  Unknown = 2,
  Invalid = 3,
};
341 
342 class ResourceInfo {
343 public:
344   struct ResourceBinding {
345     uint32_t RecordID;
346     uint32_t Space;
347     uint32_t LowerBound;
348     uint32_t Size;
349 
350     bool operator==(const ResourceBinding &RHS) const {
351       return std::tie(RecordID, Space, LowerBound, Size) ==
352              std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
353     }
354     bool operator!=(const ResourceBinding &RHS) const {
355       return !(*this == RHS);
356     }
357     bool operator<(const ResourceBinding &RHS) const {
358       return std::tie(RecordID, Space, LowerBound, Size) <
359              std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
360     }
overlapsWithResourceBinding361     bool overlapsWith(const ResourceBinding &RHS) const {
362       return Space == RHS.Space && LowerBound + Size - 1 >= RHS.LowerBound;
363     }
364   };
365 
366 private:
367   ResourceBinding Binding;
368   TargetExtType *HandleTy;
369   StringRef Name;
370   GlobalVariable *Symbol = nullptr;
371 
372 public:
373   bool GloballyCoherent = false;
374   ResourceCounterDirection CounterDirection = ResourceCounterDirection::Unknown;
375 
376   ResourceInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
377                uint32_t Size, TargetExtType *HandleTy, StringRef Name = "",
378                GlobalVariable *Symbol = nullptr)
379       : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy),
380         Name(Name), Symbol(Symbol) {}
381 
setBindingID(unsigned ID)382   void setBindingID(unsigned ID) { Binding.RecordID = ID; }
383 
hasCounter()384   bool hasCounter() const {
385     return CounterDirection != ResourceCounterDirection::Unknown;
386   }
387 
getBinding()388   const ResourceBinding &getBinding() const { return Binding; }
getHandleTy()389   TargetExtType *getHandleTy() const { return HandleTy; }
getName()390   StringRef getName() const { return Name; }
391 
hasSymbol()392   bool hasSymbol() const { return Symbol; }
393   LLVM_ABI GlobalVariable *createSymbol(Module &M, StructType *Ty);
394   LLVM_ABI MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo &RTI) const;
395 
396   LLVM_ABI std::pair<uint32_t, uint32_t>
397   getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const;
398 
399   bool operator==(const ResourceInfo &RHS) const {
400     return std::tie(Binding, HandleTy, Symbol, Name) ==
401            std::tie(RHS.Binding, RHS.HandleTy, RHS.Symbol, RHS.Name);
402   }
403   bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
404   bool operator<(const ResourceInfo &RHS) const {
405     return Binding < RHS.Binding;
406   }
407 
408   LLVM_ABI void print(raw_ostream &OS, dxil::ResourceTypeInfo &RTI,
409                       const DataLayout &DL) const;
410 };
411 
412 } // namespace dxil
413 
414 //===----------------------------------------------------------------------===//
415 
416 class DXILResourceTypeMap {
417   DenseMap<TargetExtType *, dxil::ResourceTypeInfo> Infos;
418 
419 public:
420   LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
421                            ModuleAnalysisManager::Invalidator &Inv);
422 
423   dxil::ResourceTypeInfo &operator[](TargetExtType *Ty) {
424     auto It = Infos.find(Ty);
425     if (It != Infos.end())
426       return It->second;
427     auto [NewIt, Inserted] = Infos.try_emplace(Ty, Ty);
428     return NewIt->second;
429   }
430 };
431 
432 class DXILResourceTypeAnalysis
433     : public AnalysisInfoMixin<DXILResourceTypeAnalysis> {
434   friend AnalysisInfoMixin<DXILResourceTypeAnalysis>;
435 
436   LLVM_ABI static AnalysisKey Key;
437 
438 public:
439   using Result = DXILResourceTypeMap;
440 
run(Module & M,ModuleAnalysisManager & AM)441   DXILResourceTypeMap run(Module &M, ModuleAnalysisManager &AM) {
442     // Running the pass just generates an empty map, which will be filled when
443     // users of the pass query the results.
444     return Result();
445   }
446 };
447 
448 class LLVM_ABI DXILResourceTypeWrapperPass : public ImmutablePass {
449   DXILResourceTypeMap DRTM;
450 
451   virtual void anchor();
452 
453 public:
454   static char ID;
455   DXILResourceTypeWrapperPass();
456 
getResourceTypeMap()457   DXILResourceTypeMap &getResourceTypeMap() { return DRTM; }
getResourceTypeMap()458   const DXILResourceTypeMap &getResourceTypeMap() const { return DRTM; }
459 };
460 
461 LLVM_ABI ModulePass *createDXILResourceTypeWrapperPassPass();
462 
463 //===----------------------------------------------------------------------===//
464 
465 class DXILResourceMap {
466   using CallMapTy = DenseMap<CallInst *, unsigned>;
467 
468   SmallVector<dxil::ResourceInfo> Infos;
469   CallMapTy CallMap;
470   unsigned FirstUAV = 0;
471   unsigned FirstCBuffer = 0;
472   unsigned FirstSampler = 0;
473   bool HasInvalidDirection = false;
474 
475   /// Populate all the resource instance data.
476   void populate(Module &M, DXILResourceTypeMap &DRTM);
477   /// Populate the map given the resource binding calls in the given module.
478   void populateResourceInfos(Module &M, DXILResourceTypeMap &DRTM);
479   /// Analyze and populate the directions of the resource counters.
480   void populateCounterDirections(Module &M);
481 
482   /// Resolves a resource handle into a vector of ResourceInfos that
483   /// represent the possible unique creations of the handle. Certain cases are
484   /// ambiguous so multiple creation instructions may be returned. The resulting
485   /// ResourceInfo can be used to depuplicate unique handles that
486   /// reference the same resource
487   SmallVector<dxil::ResourceInfo *> findByUse(const Value *Key);
488 
489 public:
490   using iterator = SmallVector<dxil::ResourceInfo>::iterator;
491   using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;
492 
begin()493   iterator begin() { return Infos.begin(); }
begin()494   const_iterator begin() const { return Infos.begin(); }
end()495   iterator end() { return Infos.end(); }
end()496   const_iterator end() const { return Infos.end(); }
497 
empty()498   bool empty() const { return Infos.empty(); }
499 
find(const CallInst * Key)500   iterator find(const CallInst *Key) {
501     auto Pos = CallMap.find(Key);
502     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
503   }
504 
find(const CallInst * Key)505   const_iterator find(const CallInst *Key) const {
506     auto Pos = CallMap.find(Key);
507     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
508   }
509 
srv_begin()510   iterator srv_begin() { return begin(); }
srv_begin()511   const_iterator srv_begin() const { return begin(); }
srv_end()512   iterator srv_end() { return begin() + FirstUAV; }
srv_end()513   const_iterator srv_end() const { return begin() + FirstUAV; }
srvs()514   iterator_range<iterator> srvs() { return make_range(srv_begin(), srv_end()); }
srvs()515   iterator_range<const_iterator> srvs() const {
516     return make_range(srv_begin(), srv_end());
517   }
518 
uav_begin()519   iterator uav_begin() { return begin() + FirstUAV; }
uav_begin()520   const_iterator uav_begin() const { return begin() + FirstUAV; }
uav_end()521   iterator uav_end() { return begin() + FirstCBuffer; }
uav_end()522   const_iterator uav_end() const { return begin() + FirstCBuffer; }
uavs()523   iterator_range<iterator> uavs() { return make_range(uav_begin(), uav_end()); }
uavs()524   iterator_range<const_iterator> uavs() const {
525     return make_range(uav_begin(), uav_end());
526   }
527 
cbuffer_begin()528   iterator cbuffer_begin() { return begin() + FirstCBuffer; }
cbuffer_begin()529   const_iterator cbuffer_begin() const { return begin() + FirstCBuffer; }
cbuffer_end()530   iterator cbuffer_end() { return begin() + FirstSampler; }
cbuffer_end()531   const_iterator cbuffer_end() const { return begin() + FirstSampler; }
cbuffers()532   iterator_range<iterator> cbuffers() {
533     return make_range(cbuffer_begin(), cbuffer_end());
534   }
cbuffers()535   iterator_range<const_iterator> cbuffers() const {
536     return make_range(cbuffer_begin(), cbuffer_end());
537   }
538 
sampler_begin()539   iterator sampler_begin() { return begin() + FirstSampler; }
sampler_begin()540   const_iterator sampler_begin() const { return begin() + FirstSampler; }
sampler_end()541   iterator sampler_end() { return end(); }
sampler_end()542   const_iterator sampler_end() const { return end(); }
samplers()543   iterator_range<iterator> samplers() {
544     return make_range(sampler_begin(), sampler_end());
545   }
samplers()546   iterator_range<const_iterator> samplers() const {
547     return make_range(sampler_begin(), sampler_end());
548   }
549 
550   struct call_iterator
551       : iterator_adaptor_base<call_iterator, CallMapTy::iterator> {
552     call_iterator() = default;
call_iteratorcall_iterator553     call_iterator(CallMapTy::iterator Iter)
554         : call_iterator::iterator_adaptor_base(std::move(Iter)) {}
555 
556     CallInst *operator*() const { return I->first; }
557   };
558 
call_begin()559   call_iterator call_begin() { return call_iterator(CallMap.begin()); }
call_end()560   call_iterator call_end() { return call_iterator(CallMap.end()); }
calls()561   iterator_range<call_iterator> calls() {
562     return make_range(call_begin(), call_end());
563   }
564 
hasInvalidCounterDirection()565   bool hasInvalidCounterDirection() const { return HasInvalidDirection; }
566 
567   LLVM_ABI void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
568                       const DataLayout &DL) const;
569 
570   friend class DXILResourceAnalysis;
571   friend class DXILResourceWrapperPass;
572 };
573 
574 class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
575   friend AnalysisInfoMixin<DXILResourceAnalysis>;
576 
577   LLVM_ABI static AnalysisKey Key;
578 
579 public:
580   using Result = DXILResourceMap;
581 
582   /// Gather resource info for the module \c M.
583   LLVM_ABI DXILResourceMap run(Module &M, ModuleAnalysisManager &AM);
584 };
585 
586 /// Printer pass for the \c DXILResourceAnalysis results.
587 class DXILResourcePrinterPass : public PassInfoMixin<DXILResourcePrinterPass> {
588   raw_ostream &OS;
589 
590 public:
DXILResourcePrinterPass(raw_ostream & OS)591   explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {}
592 
593   LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
594 
isRequired()595   static bool isRequired() { return true; }
596 };
597 
598 class LLVM_ABI DXILResourceWrapperPass : public ModulePass {
599   std::unique_ptr<DXILResourceMap> Map;
600   DXILResourceTypeMap *DRTM;
601 
602 public:
603   static char ID; // Class identification, replacement for typeinfo
604 
605   DXILResourceWrapperPass();
606   ~DXILResourceWrapperPass() override;
607 
getResourceMap()608   const DXILResourceMap &getResourceMap() const { return *Map; }
getResourceMap()609   DXILResourceMap &getResourceMap() { return *Map; }
610 
611   void getAnalysisUsage(AnalysisUsage &AU) const override;
612   bool runOnModule(Module &M) override;
613   void releaseMemory() override;
614 
615   void print(raw_ostream &OS, const Module *M) const override;
616   void dump() const;
617 };
618 
619 LLVM_ABI ModulePass *createDXILResourceWrapperPassPass();
620 
621 //===----------------------------------------------------------------------===//
622 
623 // DXILResourceBindingInfo stores the results of DXILResourceBindingAnalysis
624 // which analyses all llvm.dx.resource.handlefrombinding calls in the module
625 // and puts together lists of used virtual register spaces and available
626 // virtual register slot ranges for each binding type.
627 // It also stores additional information found during the analysis such as
628 // whether the module uses implicit bindings or if any of the bindings overlap.
629 //
630 // This information will be used in DXILResourceImplicitBindings pass to assign
631 // register slots to resources with implicit bindings, and in a
632 // post-optimization validation pass that will raise diagnostic about
633 // overlapping bindings.
634 //
635 // For example for these resource bindings:
636 //
637 // RWBuffer<float> A[10] : register(u3);
638 // RWBuffer<float> B[] : register(u5, space2)
639 //
640 // The analysis result for UAV binding type will look like this:
641 //
642 // UAVSpaces {
643 //   ResClass = ResourceClass::UAV,
644 //   Spaces = {
645 //     { Space = 0, FreeRanges = {{ 0, 2 }, { 13, UINT32_MAX }} },
646 //     { Space = 2, FreeRanges = {{ 0, 4 }} }
647 //   }
648 // }
649 //
650 class DXILResourceBindingInfo {
651 public:
652   struct BindingRange {
653     uint32_t LowerBound;
654     uint32_t UpperBound;
BindingRangeBindingRange655     BindingRange(uint32_t LB, uint32_t UB) : LowerBound(LB), UpperBound(UB) {}
656   };
657 
658   struct RegisterSpace {
659     uint32_t Space;
660     SmallVector<BindingRange> FreeRanges;
RegisterSpaceRegisterSpace661     RegisterSpace(uint32_t Space) : Space(Space) {
662       FreeRanges.emplace_back(0, UINT32_MAX);
663     }
664     // Size == -1 means unbounded array
665     LLVM_ABI std::optional<uint32_t> findAvailableBinding(int32_t Size);
666   };
667 
668   struct BindingSpaces {
669     dxil::ResourceClass RC;
670     llvm::SmallVector<RegisterSpace> Spaces;
BindingSpacesBindingSpaces671     BindingSpaces(dxil::ResourceClass RC) : RC(RC) {}
672     LLVM_ABI RegisterSpace &getOrInsertSpace(uint32_t Space);
673   };
674 
675 private:
676   BindingSpaces SRVSpaces, UAVSpaces, CBufferSpaces, SamplerSpaces;
677   bool ImplicitBinding;
678   bool OverlappingBinding;
679 
680   // Populate the resource binding info given explicit resource binding calls
681   // in the module.
682   void populate(Module &M, DXILResourceTypeMap &DRTM);
683 
684 public:
DXILResourceBindingInfo()685   DXILResourceBindingInfo()
686       : SRVSpaces(dxil::ResourceClass::SRV),
687         UAVSpaces(dxil::ResourceClass::UAV),
688         CBufferSpaces(dxil::ResourceClass::CBuffer),
689         SamplerSpaces(dxil::ResourceClass::Sampler), ImplicitBinding(false),
690         OverlappingBinding(false) {}
691 
hasImplicitBinding()692   bool hasImplicitBinding() const { return ImplicitBinding; }
setHasImplicitBinding(bool Value)693   void setHasImplicitBinding(bool Value) { ImplicitBinding = Value; }
hasOverlappingBinding()694   bool hasOverlappingBinding() const { return OverlappingBinding; }
695 
getBindingSpaces(dxil::ResourceClass RC)696   BindingSpaces &getBindingSpaces(dxil::ResourceClass RC) {
697     switch (RC) {
698     case dxil::ResourceClass::SRV:
699       return SRVSpaces;
700     case dxil::ResourceClass::UAV:
701       return UAVSpaces;
702     case dxil::ResourceClass::CBuffer:
703       return CBufferSpaces;
704     case dxil::ResourceClass::Sampler:
705       return SamplerSpaces;
706     }
707 
708     llvm_unreachable("Invalid resource class");
709   }
710 
711   // Size == -1 means unbounded array
712   LLVM_ABI std::optional<uint32_t>
713   findAvailableBinding(dxil::ResourceClass RC, uint32_t Space, int32_t Size);
714 
715   friend class DXILResourceBindingAnalysis;
716   friend class DXILResourceBindingWrapperPass;
717 };
718 
719 class DXILResourceBindingAnalysis
720     : public AnalysisInfoMixin<DXILResourceBindingAnalysis> {
721   friend AnalysisInfoMixin<DXILResourceBindingAnalysis>;
722 
723   LLVM_ABI static AnalysisKey Key;
724 
725 public:
726   using Result = DXILResourceBindingInfo;
727 
728   LLVM_ABI DXILResourceBindingInfo run(Module &M, ModuleAnalysisManager &AM);
729 };
730 
731 class LLVM_ABI DXILResourceBindingWrapperPass : public ModulePass {
732   std::unique_ptr<DXILResourceBindingInfo> BindingInfo;
733 
734 public:
735   static char ID;
736 
737   DXILResourceBindingWrapperPass();
738   ~DXILResourceBindingWrapperPass() override;
739 
getBindingInfo()740   DXILResourceBindingInfo &getBindingInfo() { return *BindingInfo; }
getBindingInfo()741   const DXILResourceBindingInfo &getBindingInfo() const { return *BindingInfo; }
742 
743   void getAnalysisUsage(AnalysisUsage &AU) const override;
744   bool runOnModule(Module &M) override;
745   void releaseMemory() override;
746 };
747 
748 LLVM_ABI ModulePass *createDXILResourceBindingWrapperPassPass();
749 
750 } // namespace llvm
751 
752 #endif // LLVM_ANALYSIS_DXILRESOURCE_H
753