xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h (revision 2f9966ff63d65bd474478888c9088eeae3f9c669)
1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // General infrastructure for keeping track of the values that according to
10 // the SPIR-V binary layout should be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16 
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22 #include "llvm/CodeGen/MachineModuleInfo.h"
23 
24 #include <type_traits>
25 
26 namespace llvm {
27 namespace SPIRV {
28 // NOTE: using MapVector instead of DenseMap because it helps getting
29 // everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
30 // memory and expensive removals which do not happen anyway.
31 class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
32   SmallVector<DTSortableEntry *, 2> Deps;
33 
34   struct FlagsTy {
35     unsigned IsFunc : 1;
36     unsigned IsGV : 1;
37     // NOTE: bit-field default init is a C++20 feature.
38     FlagsTy() : IsFunc(0), IsGV(0) {}
39   };
40   FlagsTy Flags;
41 
42 public:
43   // Common hoisting utility doesn't support function, because their hoisting
44   // require hoisting of params as well.
45   bool getIsFunc() const { return Flags.IsFunc; }
46   bool getIsGV() const { return Flags.IsGV; }
47   void setIsFunc(bool V) { Flags.IsFunc = V; }
48   void setIsGV(bool V) { Flags.IsGV = V; }
49 
50   const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
51   void addDep(DTSortableEntry *E) { Deps.push_back(E); }
52 };
53 
// Base descriptor for the special types handled by the duplicates tracker.
// It carries a kind tag and a hash value; derived descriptors overwrite Hash
// in their constructors.
struct SpecialTypeDescriptor {
  // STK_Empty and STK_Last are used as the DenseMap empty and tombstone keys
  // respectively; the remaining values name the concrete descriptor kinds.
  enum SpecialTypeKind {
    STK_Empty = 0,
    STK_Image,
    STK_SampledImage,
    STK_Sampler,
    STK_Pipe,
    STK_DeviceEvent,
    STK_Pointer,
    STK_Last = -1
  };

  // Kind tag used by the derived classes' classof() implementations.
  SpecialTypeKind Kind;

  // Hash of the descriptor; defaults to the numeric value of Kind.
  unsigned Hash;

  // A descriptor always needs a kind.
  SpecialTypeDescriptor() = delete;

  // Intentionally non-explicit, as in the original interface.
  SpecialTypeDescriptor(SpecialTypeKind K)
      : Kind(K), Hash(static_cast<unsigned>(K)) {}

  virtual ~SpecialTypeDescriptor() = default;

  // Returns the hash stored by the (possibly derived) constructor.
  unsigned getHash() const { return Hash; }
};
76 
77 struct ImageTypeDescriptor : public SpecialTypeDescriptor {
78   union ImageAttrs {
79     struct BitFlags {
80       unsigned Dim : 3;
81       unsigned Depth : 2;
82       unsigned Arrayed : 1;
83       unsigned MS : 1;
84       unsigned Sampled : 2;
85       unsigned ImageFormat : 6;
86       unsigned AQ : 2;
87     } Flags;
88     unsigned Val;
89   };
90 
91   ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
92                       unsigned Arrayed, unsigned MS, unsigned Sampled,
93                       unsigned ImageFormat, unsigned AQ = 0)
94       : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
95     ImageAttrs Attrs;
96     Attrs.Val = 0;
97     Attrs.Flags.Dim = Dim;
98     Attrs.Flags.Depth = Depth;
99     Attrs.Flags.Arrayed = Arrayed;
100     Attrs.Flags.MS = MS;
101     Attrs.Flags.Sampled = Sampled;
102     Attrs.Flags.ImageFormat = ImageFormat;
103     Attrs.Flags.AQ = AQ;
104     Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
105            ((Attrs.Val << 8) | Kind);
106   }
107 
108   static bool classof(const SpecialTypeDescriptor *TD) {
109     return TD->Kind == SpecialTypeKind::STK_Image;
110   }
111 };
112 
113 struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
114   SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
115       : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
116     assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
117     ImageTypeDescriptor TD(
118         SampledTy, ImageTy->getOperand(2).getImm(),
119         ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
120         ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
121         ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
122     Hash = TD.getHash() ^ Kind;
123   }
124 
125   static bool classof(const SpecialTypeDescriptor *TD) {
126     return TD->Kind == SpecialTypeKind::STK_SampledImage;
127   }
128 };
129 
130 struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
131   SamplerTypeDescriptor()
132       : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
133     Hash = Kind;
134   }
135 
136   static bool classof(const SpecialTypeDescriptor *TD) {
137     return TD->Kind == SpecialTypeKind::STK_Sampler;
138   }
139 };
140 
141 struct PipeTypeDescriptor : public SpecialTypeDescriptor {
142 
143   PipeTypeDescriptor(uint8_t AQ)
144       : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
145     Hash = (AQ << 8) | Kind;
146   }
147 
148   static bool classof(const SpecialTypeDescriptor *TD) {
149     return TD->Kind == SpecialTypeKind::STK_Pipe;
150   }
151 };
152 
153 struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
154 
155   DeviceEventTypeDescriptor()
156       : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
157     Hash = Kind;
158   }
159 
160   static bool classof(const SpecialTypeDescriptor *TD) {
161     return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
162   }
163 };
164 
165 struct PointerTypeDescriptor : public SpecialTypeDescriptor {
166   const Type *ElementType;
167   unsigned AddressSpace;
168 
169   PointerTypeDescriptor() = delete;
170   PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
171       : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
172         ElementType(ElementType), AddressSpace(AddressSpace) {
173     Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
174            ((AddressSpace << 8) | Kind);
175   }
176 
177   static bool classof(const SpecialTypeDescriptor *TD) {
178     return TD->Kind == SpecialTypeKind::STK_Pointer;
179   }
180 };
181 } // namespace SPIRV
182 
183 template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
184   static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
185     return SPIRV::SpecialTypeDescriptor(
186         SPIRV::SpecialTypeDescriptor::STK_Empty);
187   }
188   static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
189     return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
190   }
191   static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
192     return Val.getHash();
193   }
194   static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
195                       SPIRV::SpecialTypeDescriptor RHS) {
196     return getHashValue(LHS) == getHashValue(RHS);
197   }
198 };
199 
200 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
201 public:
202   // NOTE: using MapVector instead of DenseMap helps getting everything ordered
203   // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
204   // expensive removals which don't happen anyway.
205   using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
206 
207 private:
208   StorageTy Storage;
209 
210 public:
211   void add(KeyTy V, const MachineFunction *MF, Register R) {
212     if (find(V, MF).isValid())
213       return;
214 
215     Storage[V][MF] = R;
216     if (std::is_same<Function,
217                      typename std::remove_const<
218                          typename std::remove_pointer<KeyTy>::type>::type>() ||
219         std::is_same<Argument,
220                      typename std::remove_const<
221                          typename std::remove_pointer<KeyTy>::type>::type>())
222       Storage[V].setIsFunc(true);
223     if (std::is_same<GlobalVariable,
224                      typename std::remove_const<
225                          typename std::remove_pointer<KeyTy>::type>::type>())
226       Storage[V].setIsGV(true);
227   }
228 
229   Register find(KeyTy V, const MachineFunction *MF) const {
230     auto iter = Storage.find(V);
231     if (iter != Storage.end()) {
232       auto Map = iter->second;
233       auto iter2 = Map.find(MF);
234       if (iter2 != Map.end())
235         return iter2->second;
236     }
237     return Register();
238   }
239 
240   const StorageTy &getAllUses() const { return Storage; }
241 
242 private:
243   StorageTy &getAllUses() { return Storage; }
244 
245   // The friend class needs to have access to the internal storage
246   // to be able to build dependency graph, can't declare only one
247   // function a 'friend' due to the incomplete declaration at this point
248   // and mutual dependency problems.
249   friend class SPIRVGeneralDuplicatesTracker;
250 };
251 
252 template <typename T>
253 class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
254 
255 template <>
256 class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
257     : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
258 
259 class SPIRVGeneralDuplicatesTracker {
260   SPIRVDuplicatesTracker<Type> TT;
261   SPIRVDuplicatesTracker<Constant> CT;
262   SPIRVDuplicatesTracker<GlobalVariable> GT;
263   SPIRVDuplicatesTracker<Function> FT;
264   SPIRVDuplicatesTracker<Argument> AT;
265   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
266 
267   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
268   // to use flat data structure.
269   // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
270   // but makes LITs more stable, should prefer DenseMap still due to
271   // significant perf difference.
272   using SPIRVReg2EntryTy =
273       MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
274 
275   template <typename T>
276   void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
277                          SPIRVReg2EntryTy &Reg2Entry);
278 
279 public:
280   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
281                       MachineModuleInfo *MMI);
282 
283   void add(const Type *Ty, const MachineFunction *MF, Register R) {
284     TT.add(Ty, MF, R);
285   }
286 
287   void add(const Type *PointerElementType, unsigned AddressSpace,
288            const MachineFunction *MF, Register R) {
289     ST.add(SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF,
290            R);
291   }
292 
293   void add(const Constant *C, const MachineFunction *MF, Register R) {
294     CT.add(C, MF, R);
295   }
296 
297   void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
298     GT.add(GV, MF, R);
299   }
300 
301   void add(const Function *F, const MachineFunction *MF, Register R) {
302     FT.add(F, MF, R);
303   }
304 
305   void add(const Argument *Arg, const MachineFunction *MF, Register R) {
306     AT.add(Arg, MF, R);
307   }
308 
309   void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
310            Register R) {
311     ST.add(TD, MF, R);
312   }
313 
314   Register find(const Type *Ty, const MachineFunction *MF) {
315     return TT.find(const_cast<Type *>(Ty), MF);
316   }
317 
318   Register find(const Type *PointerElementType, unsigned AddressSpace,
319                 const MachineFunction *MF) {
320     return ST.find(
321         SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF);
322   }
323 
324   Register find(const Constant *C, const MachineFunction *MF) {
325     return CT.find(const_cast<Constant *>(C), MF);
326   }
327 
328   Register find(const GlobalVariable *GV, const MachineFunction *MF) {
329     return GT.find(const_cast<GlobalVariable *>(GV), MF);
330   }
331 
332   Register find(const Function *F, const MachineFunction *MF) {
333     return FT.find(const_cast<Function *>(F), MF);
334   }
335 
336   Register find(const Argument *Arg, const MachineFunction *MF) {
337     return AT.find(const_cast<Argument *>(Arg), MF);
338   }
339 
340   Register find(const SPIRV::SpecialTypeDescriptor &TD,
341                 const MachineFunction *MF) {
342     return ST.find(TD, MF);
343   }
344 
345   const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
346 };
347 } // namespace llvm
348 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
349