xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // General infrastructure for keeping track of the values that according to
10 // the SPIR-V binary layout should be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16 
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 #include "llvm/CodeGen/MachineModuleInfo.h"
24 
25 #include <type_traits>
26 
27 namespace llvm {
28 namespace SPIRV {
29 // NOTE: using MapVector instead of DenseMap because it helps getting
30 // everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
31 // memory and expensive removals which do not happen anyway.
32 class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
33   SmallVector<DTSortableEntry *, 2> Deps;
34 
35   struct FlagsTy {
36     unsigned IsFunc : 1;
37     unsigned IsGV : 1;
38     // NOTE: bit-field default init is a C++20 feature.
39     FlagsTy() : IsFunc(0), IsGV(0) {}
40   };
41   FlagsTy Flags;
42 
43 public:
44   // Common hoisting utility doesn't support function, because their hoisting
45   // require hoisting of params as well.
46   bool getIsFunc() const { return Flags.IsFunc; }
47   bool getIsGV() const { return Flags.IsGV; }
48   void setIsFunc(bool V) { Flags.IsFunc = V; }
49   void setIsGV(bool V) { Flags.IsGV = V; }
50 
51   const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
52   void addDep(DTSortableEntry *E) { Deps.push_back(E); }
53 };
54 
// Tag identifying which kind of "special" type a SpecialTypeDescriptor
// describes. The make_descr_* helpers below store one of these values in the
// last slot of the descriptor tuple.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image,
  STK_SampledImage,
  STK_Sampler,
  STK_Pipe,
  STK_DeviceEvent,
  STK_Pointer,
  // NOTE(review): explicitly -1, so it is not "one past the last kind" and it
  // gives the enumeration a negative value — confirm intent before relying on
  // it as a bound.
  STK_Last = -1
};
65 
// Key describing a special type. As filled in by the make_descr_* helpers:
//   <0> a Type pointer (sampled/element type) or nullptr when unused,
//   <1> a kind-specific payload (packed ImageAttrs value, access qualifier,
//       address space, or 0),
//   <2> the SpecialTypeKind tag.
using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
67 
// Packs the image attributes into bit-fields that overlay a single unsigned
// value, so the whole attribute set can be read back through Val.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  // Clears the whole word first so the bits not covered by any field are
  // zero, then stores each attribute into its field. Values wider than a
  // field are truncated to the field width by the bit-field assignment.
  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) {
    Flags.Dim = Dim;
    Flags.Depth = Depth;
    Flags.Arrayed = Arrayed;
    Flags.MS = MS;
    Flags.Sampled = Sampled;
    Flags.ImageFormat = ImageFormat;
    Flags.AQ = AQ;
  }
};
92 
93 inline SpecialTypeDescriptor
94 make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
95                  unsigned Arrayed, unsigned MS, unsigned Sampled,
96                  unsigned ImageFormat, unsigned AQ = 0) {
97   return std::make_tuple(
98       SampledTy,
99       ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
100       SpecialTypeKind::STK_Image);
101 }
102 
103 inline SpecialTypeDescriptor
104 make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
105   assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
106   return std::make_tuple(
107       SampledTy,
108       ImageAttrs(
109           ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
110           ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
111           ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
112           ImageTy->getOperand(8).getImm())
113           .Val,
114       SpecialTypeKind::STK_SampledImage);
115 }
116 
117 inline SpecialTypeDescriptor make_descr_sampler() {
118   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
119 }
120 
121 inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
122   return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
123 }
124 
125 inline SpecialTypeDescriptor make_descr_event() {
126   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
127 }
128 
129 inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
130                                                 unsigned AddressSpace) {
131   return std::make_tuple(ElementType, AddressSpace,
132                          SpecialTypeKind::STK_Pointer);
133 }
134 } // namespace SPIRV
135 
136 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
137 public:
138   // NOTE: using MapVector instead of DenseMap helps getting everything ordered
139   // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
140   // expensive removals which don't happen anyway.
141   using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
142 
143 private:
144   StorageTy Storage;
145 
146 public:
147   void add(KeyTy V, const MachineFunction *MF, Register R) {
148     if (find(V, MF).isValid())
149       return;
150 
151     Storage[V][MF] = R;
152     if (std::is_same<Function,
153                      typename std::remove_const<
154                          typename std::remove_pointer<KeyTy>::type>::type>() ||
155         std::is_same<Argument,
156                      typename std::remove_const<
157                          typename std::remove_pointer<KeyTy>::type>::type>())
158       Storage[V].setIsFunc(true);
159     if (std::is_same<GlobalVariable,
160                      typename std::remove_const<
161                          typename std::remove_pointer<KeyTy>::type>::type>())
162       Storage[V].setIsGV(true);
163   }
164 
165   Register find(KeyTy V, const MachineFunction *MF) const {
166     auto iter = Storage.find(V);
167     if (iter != Storage.end()) {
168       auto Map = iter->second;
169       auto iter2 = Map.find(MF);
170       if (iter2 != Map.end())
171         return iter2->second;
172     }
173     return Register();
174   }
175 
176   const StorageTy &getAllUses() const { return Storage; }
177 
178 private:
179   StorageTy &getAllUses() { return Storage; }
180 
181   // The friend class needs to have access to the internal storage
182   // to be able to build dependency graph, can't declare only one
183   // function a 'friend' due to the incomplete declaration at this point
184   // and mutual dependency problems.
185   friend class SPIRVGeneralDuplicatesTracker;
186 };
187 
// Tracker for entities of type T, keyed by 'const T *'. Adds nothing of its
// own; all functionality comes from SPIRVDuplicatesTrackerBase.
template <typename T>
class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
190 
// Specialization for special type descriptors: these are keyed by the
// descriptor value itself rather than by a pointer to it.
template <>
class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
    : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
194 
195 class SPIRVGeneralDuplicatesTracker {
196   SPIRVDuplicatesTracker<Type> TT;
197   SPIRVDuplicatesTracker<Constant> CT;
198   SPIRVDuplicatesTracker<GlobalVariable> GT;
199   SPIRVDuplicatesTracker<Function> FT;
200   SPIRVDuplicatesTracker<Argument> AT;
201   SPIRVDuplicatesTracker<MachineInstr> MT;
202   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
203 
204   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
205   // to use flat data structure.
206   // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
207   // but makes LITs more stable, should prefer DenseMap still due to
208   // significant perf difference.
209   using SPIRVReg2EntryTy =
210       MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
211 
212   template <typename T>
213   void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
214                          SPIRVReg2EntryTy &Reg2Entry);
215 
216 public:
217   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
218                       MachineModuleInfo *MMI);
219 
220   void add(const Type *Ty, const MachineFunction *MF, Register R) {
221     TT.add(unifyPtrType(Ty), MF, R);
222   }
223 
224   void add(const Type *PointeeTy, unsigned AddressSpace,
225            const MachineFunction *MF, Register R) {
226     ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
227            R);
228   }
229 
230   void add(const Constant *C, const MachineFunction *MF, Register R) {
231     CT.add(C, MF, R);
232   }
233 
234   void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
235     GT.add(GV, MF, R);
236   }
237 
238   void add(const Function *F, const MachineFunction *MF, Register R) {
239     FT.add(F, MF, R);
240   }
241 
242   void add(const Argument *Arg, const MachineFunction *MF, Register R) {
243     AT.add(Arg, MF, R);
244   }
245 
246   void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
247     MT.add(MI, MF, R);
248   }
249 
250   void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
251            Register R) {
252     ST.add(TD, MF, R);
253   }
254 
255   Register find(const Type *Ty, const MachineFunction *MF) {
256     return TT.find(unifyPtrType(Ty), MF);
257   }
258 
259   Register find(const Type *PointeeTy, unsigned AddressSpace,
260                 const MachineFunction *MF) {
261     return ST.find(
262         SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
263   }
264 
265   Register find(const Constant *C, const MachineFunction *MF) {
266     return CT.find(const_cast<Constant *>(C), MF);
267   }
268 
269   Register find(const GlobalVariable *GV, const MachineFunction *MF) {
270     return GT.find(const_cast<GlobalVariable *>(GV), MF);
271   }
272 
273   Register find(const Function *F, const MachineFunction *MF) {
274     return FT.find(const_cast<Function *>(F), MF);
275   }
276 
277   Register find(const Argument *Arg, const MachineFunction *MF) {
278     return AT.find(const_cast<Argument *>(Arg), MF);
279   }
280 
281   Register find(const MachineInstr *MI, const MachineFunction *MF) {
282     return MT.find(const_cast<MachineInstr *>(MI), MF);
283   }
284 
285   Register find(const SPIRV::SpecialTypeDescriptor &TD,
286                 const MachineFunction *MF) {
287     return ST.find(TD, MF);
288   }
289 
290   const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
291 };
292 } // namespace llvm
293 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
294