xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h (revision 9f23cbd6cae82fd77edfad7173432fa8dccd0a95)
1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // General infrastructure for keeping track of the values that according to
10 // the SPIR-V binary layout should be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16 
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22 #include "llvm/CodeGen/MachineModuleInfo.h"
23 
24 #include <type_traits>
25 
26 namespace llvm {
27 namespace SPIRV {
// NOTE: MapVector is used instead of DenseMap because it keeps everything
// ordered in a stable manner, at the price of an extra (NumKeys)*PtrSize of
// memory and expensive removals, which do not happen anyway.
31 class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
32   SmallVector<DTSortableEntry *, 2> Deps;
33 
34   struct FlagsTy {
35     unsigned IsFunc : 1;
36     unsigned IsGV : 1;
37     // NOTE: bit-field default init is a C++20 feature.
38     FlagsTy() : IsFunc(0), IsGV(0) {}
39   };
40   FlagsTy Flags;
41 
42 public:
43   // Common hoisting utility doesn't support function, because their hoisting
44   // require hoisting of params as well.
45   bool getIsFunc() const { return Flags.IsFunc; }
46   bool getIsGV() const { return Flags.IsGV; }
47   void setIsFunc(bool V) { Flags.IsFunc = V; }
48   void setIsGV(bool V) { Flags.IsGV = V; }
49 
50   const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
51   void addDep(DTSortableEntry *E) { Deps.push_back(E); }
52 };
53 
/// Base of the descriptors used as tracker keys for special types.
///
/// A descriptor is identified by its Kind and by a precomputed Hash. The base
/// constructor seeds Hash with the numeric value of the kind; derived
/// descriptors overwrite it with a value that also covers their parameters.
struct SpecialTypeDescriptor {
  /// Discriminator of the concrete descriptor.
  enum SpecialTypeKind {
    STK_Empty = 0,
    STK_Image = 1,
    STK_SampledImage = 2,
    STK_Sampler = 3,
    STK_Pipe = 4,
    STK_DeviceEvent = 5,
    STK_Last = -1
  };

  SpecialTypeKind Kind;

  /// Precomputed hash, returned by getHash().
  unsigned Hash;

  /// A descriptor always needs a kind.
  SpecialTypeDescriptor() = delete;

  /// Creates a descriptor of kind \p K whose hash is the numeric value of
  /// \p K converted to unsigned (so STK_Last yields an all-ones hash).
  SpecialTypeDescriptor(SpecialTypeKind K)
      : Kind(K), Hash(static_cast<unsigned>(K)) {}

  /// \returns the precomputed hash.
  unsigned getHash() const { return Hash; }

  virtual ~SpecialTypeDescriptor() = default;
};
75 
76 struct ImageTypeDescriptor : public SpecialTypeDescriptor {
77   union ImageAttrs {
78     struct BitFlags {
79       unsigned Dim : 3;
80       unsigned Depth : 2;
81       unsigned Arrayed : 1;
82       unsigned MS : 1;
83       unsigned Sampled : 2;
84       unsigned ImageFormat : 6;
85       unsigned AQ : 2;
86     } Flags;
87     unsigned Val;
88   };
89 
90   ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
91                       unsigned Arrayed, unsigned MS, unsigned Sampled,
92                       unsigned ImageFormat, unsigned AQ = 0)
93       : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
94     ImageAttrs Attrs;
95     Attrs.Val = 0;
96     Attrs.Flags.Dim = Dim;
97     Attrs.Flags.Depth = Depth;
98     Attrs.Flags.Arrayed = Arrayed;
99     Attrs.Flags.MS = MS;
100     Attrs.Flags.Sampled = Sampled;
101     Attrs.Flags.ImageFormat = ImageFormat;
102     Attrs.Flags.AQ = AQ;
103     Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
104            ((Attrs.Val << 8) | Kind);
105   }
106 
107   static bool classof(const SpecialTypeDescriptor *TD) {
108     return TD->Kind == SpecialTypeKind::STK_Image;
109   }
110 };
111 
112 struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
113   SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
114       : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
115     assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
116     ImageTypeDescriptor TD(
117         SampledTy, ImageTy->getOperand(2).getImm(),
118         ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
119         ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
120         ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
121     Hash = TD.getHash() ^ Kind;
122   }
123 
124   static bool classof(const SpecialTypeDescriptor *TD) {
125     return TD->Kind == SpecialTypeKind::STK_SampledImage;
126   }
127 };
128 
129 struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
130   SamplerTypeDescriptor()
131       : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
132     Hash = Kind;
133   }
134 
135   static bool classof(const SpecialTypeDescriptor *TD) {
136     return TD->Kind == SpecialTypeKind::STK_Sampler;
137   }
138 };
139 
140 struct PipeTypeDescriptor : public SpecialTypeDescriptor {
141 
142   PipeTypeDescriptor(uint8_t AQ)
143       : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
144     Hash = (AQ << 8) | Kind;
145   }
146 
147   static bool classof(const SpecialTypeDescriptor *TD) {
148     return TD->Kind == SpecialTypeKind::STK_Pipe;
149   }
150 };
151 
152 struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
153 
154   DeviceEventTypeDescriptor()
155       : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
156     Hash = Kind;
157   }
158 
159   static bool classof(const SpecialTypeDescriptor *TD) {
160     return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
161   }
162 };
163 } // namespace SPIRV
164 
165 template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
166   static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
167     return SPIRV::SpecialTypeDescriptor(
168         SPIRV::SpecialTypeDescriptor::STK_Empty);
169   }
170   static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
171     return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
172   }
173   static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
174     return Val.getHash();
175   }
176   static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
177                       SPIRV::SpecialTypeDescriptor RHS) {
178     return getHashValue(LHS) == getHashValue(RHS);
179   }
180 };
181 
182 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
183 public:
184   // NOTE: using MapVector instead of DenseMap helps getting everything ordered
185   // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
186   // expensive removals which don't happen anyway.
187   using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
188 
189 private:
190   StorageTy Storage;
191 
192 public:
193   void add(KeyTy V, const MachineFunction *MF, Register R) {
194     if (find(V, MF).isValid())
195       return;
196 
197     Storage[V][MF] = R;
198     if (std::is_same<Function,
199                      typename std::remove_const<
200                          typename std::remove_pointer<KeyTy>::type>::type>() ||
201         std::is_same<Argument,
202                      typename std::remove_const<
203                          typename std::remove_pointer<KeyTy>::type>::type>())
204       Storage[V].setIsFunc(true);
205     if (std::is_same<GlobalVariable,
206                      typename std::remove_const<
207                          typename std::remove_pointer<KeyTy>::type>::type>())
208       Storage[V].setIsGV(true);
209   }
210 
211   Register find(KeyTy V, const MachineFunction *MF) const {
212     auto iter = Storage.find(V);
213     if (iter != Storage.end()) {
214       auto Map = iter->second;
215       auto iter2 = Map.find(MF);
216       if (iter2 != Map.end())
217         return iter2->second;
218     }
219     return Register();
220   }
221 
222   const StorageTy &getAllUses() const { return Storage; }
223 
224 private:
225   StorageTy &getAllUses() { return Storage; }
226 
227   // The friend class needs to have access to the internal storage
228   // to be able to build dependency graph, can't declare only one
229   // function a 'friend' due to the incomplete declaration at this point
230   // and mutual dependency problems.
231   friend class SPIRVGeneralDuplicatesTracker;
232 };
233 
/// Duplicates tracker for entities of type \p T: it is a
/// SPIRVDuplicatesTrackerBase keyed by `const T *` and adds nothing of its
/// own.
template <typename T>
class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
236 
/// Specialization for SPIRV::SpecialTypeDescriptor: unlike the primary
/// template, the descriptor itself (not a pointer to it) is the key type of
/// the underlying SPIRVDuplicatesTrackerBase.
template <>
class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
    : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
240 
241 class SPIRVGeneralDuplicatesTracker {
242   SPIRVDuplicatesTracker<Type> TT;
243   SPIRVDuplicatesTracker<Constant> CT;
244   SPIRVDuplicatesTracker<GlobalVariable> GT;
245   SPIRVDuplicatesTracker<Function> FT;
246   SPIRVDuplicatesTracker<Argument> AT;
247   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
248 
249   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
250   // to use flat data structure.
251   // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
252   // but makes LITs more stable, should prefer DenseMap still due to
253   // significant perf difference.
254   using SPIRVReg2EntryTy =
255       MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
256 
257   template <typename T>
258   void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
259                          SPIRVReg2EntryTy &Reg2Entry);
260 
261 public:
262   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
263                       MachineModuleInfo *MMI);
264 
265   void add(const Type *T, const MachineFunction *MF, Register R) {
266     TT.add(T, MF, R);
267   }
268 
269   void add(const Constant *C, const MachineFunction *MF, Register R) {
270     CT.add(C, MF, R);
271   }
272 
273   void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
274     GT.add(GV, MF, R);
275   }
276 
277   void add(const Function *F, const MachineFunction *MF, Register R) {
278     FT.add(F, MF, R);
279   }
280 
281   void add(const Argument *Arg, const MachineFunction *MF, Register R) {
282     AT.add(Arg, MF, R);
283   }
284 
285   void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
286            Register R) {
287     ST.add(TD, MF, R);
288   }
289 
290   Register find(const Type *T, const MachineFunction *MF) {
291     return TT.find(const_cast<Type *>(T), MF);
292   }
293 
294   Register find(const Constant *C, const MachineFunction *MF) {
295     return CT.find(const_cast<Constant *>(C), MF);
296   }
297 
298   Register find(const GlobalVariable *GV, const MachineFunction *MF) {
299     return GT.find(const_cast<GlobalVariable *>(GV), MF);
300   }
301 
302   Register find(const Function *F, const MachineFunction *MF) {
303     return FT.find(const_cast<Function *>(F), MF);
304   }
305 
306   Register find(const Argument *Arg, const MachineFunction *MF) {
307     return AT.find(const_cast<Argument *>(Arg), MF);
308   }
309 
310   Register find(const SPIRV::SpecialTypeDescriptor &TD,
311                 const MachineFunction *MF) {
312     return ST.find(TD, MF);
313   }
314 
315   const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
316 };
317 } // namespace llvm
318 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
319