1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// General infrastructure for keeping track of values that, according to the
// SPIR-V binary layout, must be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 #include "llvm/CodeGen/MachineModuleInfo.h"
24
25 #include <type_traits>
26
27 namespace llvm {
28 namespace SPIRV {
29 // NOTE: using MapVector instead of DenseMap because it helps getting
30 // everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
31 // memory and expensive removals which do not happen anyway.
32 class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
33 SmallVector<DTSortableEntry *, 2> Deps;
34
35 struct FlagsTy {
36 unsigned IsFunc : 1;
37 unsigned IsGV : 1;
38 // NOTE: bit-field default init is a C++20 feature.
FlagsTyFlagsTy39 FlagsTy() : IsFunc(0), IsGV(0) {}
40 };
41 FlagsTy Flags;
42
43 public:
44 // Common hoisting utility doesn't support function, because their hoisting
45 // require hoisting of params as well.
getIsFunc()46 bool getIsFunc() const { return Flags.IsFunc; }
getIsGV()47 bool getIsGV() const { return Flags.IsGV; }
setIsFunc(bool V)48 void setIsFunc(bool V) { Flags.IsFunc = V; }
setIsGV(bool V)49 void setIsGV(bool V) { Flags.IsGV = V; }
50
getDeps()51 const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
addDep(DTSortableEntry * E)52 void addDep(DTSortableEntry *E) { Deps.push_back(E); }
53 };
54
// Discriminator stored as the third element of a SpecialTypeDescriptor; it
// tells which kind of special (non-llvm::Type-keyed) type a descriptor names.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,
  STK_SampledImage = 2,
  STK_Sampler = 3,
  STK_Pipe = 4,
  STK_DeviceEvent = 5,
  STK_Pointer = 6,
  STK_Last = -1
};
65
66 using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
67
// Packs the OpTypeImage attributes into a single unsigned (Val) so they can be
// used as the payload of a SpecialTypeDescriptor. Val is zeroed first so the
// bits not covered by the bit-fields stay 0.
// NOTE(review): arguments wider than their bit-field (e.g. Dim > 7) are
// silently truncated, which could make distinct images share a key.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) {
    Flags.Dim = Dim;
    Flags.Depth = Depth;
    Flags.Arrayed = Arrayed;
    Flags.MS = MS;
    Flags.Sampled = Sampled;
    Flags.ImageFormat = ImageFormat;
    Flags.AQ = AQ;
  }
};
92
93 inline SpecialTypeDescriptor
94 make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
95 unsigned Arrayed, unsigned MS, unsigned Sampled,
96 unsigned ImageFormat, unsigned AQ = 0) {
97 return std::make_tuple(
98 SampledTy,
99 ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
100 SpecialTypeKind::STK_Image);
101 }
102
103 inline SpecialTypeDescriptor
make_descr_sampled_image(const Type * SampledTy,const MachineInstr * ImageTy)104 make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
105 assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
106 return std::make_tuple(
107 SampledTy,
108 ImageAttrs(
109 ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
110 ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
111 ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
112 ImageTy->getOperand(8).getImm())
113 .Val,
114 SpecialTypeKind::STK_SampledImage);
115 }
116
make_descr_sampler()117 inline SpecialTypeDescriptor make_descr_sampler() {
118 return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
119 }
120
make_descr_pipe(uint8_t AQ)121 inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
122 return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
123 }
124
make_descr_event()125 inline SpecialTypeDescriptor make_descr_event() {
126 return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
127 }
128
make_descr_pointee(const Type * ElementType,unsigned AddressSpace)129 inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
130 unsigned AddressSpace) {
131 return std::make_tuple(ElementType, AddressSpace,
132 SpecialTypeKind::STK_Pointer);
133 }
134 } // namespace SPIRV
135
136 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
137 public:
138 // NOTE: using MapVector instead of DenseMap helps getting everything ordered
139 // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
140 // expensive removals which don't happen anyway.
141 using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
142
143 private:
144 StorageTy Storage;
145
146 public:
add(KeyTy V,const MachineFunction * MF,Register R)147 void add(KeyTy V, const MachineFunction *MF, Register R) {
148 if (find(V, MF).isValid())
149 return;
150
151 Storage[V][MF] = R;
152 if (std::is_same<Function,
153 typename std::remove_const<
154 typename std::remove_pointer<KeyTy>::type>::type>() ||
155 std::is_same<Argument,
156 typename std::remove_const<
157 typename std::remove_pointer<KeyTy>::type>::type>())
158 Storage[V].setIsFunc(true);
159 if (std::is_same<GlobalVariable,
160 typename std::remove_const<
161 typename std::remove_pointer<KeyTy>::type>::type>())
162 Storage[V].setIsGV(true);
163 }
164
find(KeyTy V,const MachineFunction * MF)165 Register find(KeyTy V, const MachineFunction *MF) const {
166 auto iter = Storage.find(V);
167 if (iter != Storage.end()) {
168 auto Map = iter->second;
169 auto iter2 = Map.find(MF);
170 if (iter2 != Map.end())
171 return iter2->second;
172 }
173 return Register();
174 }
175
getAllUses()176 const StorageTy &getAllUses() const { return Storage; }
177
178 private:
getAllUses()179 StorageTy &getAllUses() { return Storage; }
180
181 // The friend class needs to have access to the internal storage
182 // to be able to build dependency graph, can't declare only one
183 // function a 'friend' due to the incomplete declaration at this point
184 // and mutual dependency problems.
185 friend class SPIRVGeneralDuplicatesTracker;
186 };
187
188 template <typename T>
189 class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
190
191 template <>
192 class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
193 : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
194
195 class SPIRVGeneralDuplicatesTracker {
196 SPIRVDuplicatesTracker<Type> TT;
197 SPIRVDuplicatesTracker<Constant> CT;
198 SPIRVDuplicatesTracker<GlobalVariable> GT;
199 SPIRVDuplicatesTracker<Function> FT;
200 SPIRVDuplicatesTracker<Argument> AT;
201 SPIRVDuplicatesTracker<MachineInstr> MT;
202 SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
203
204 // NOTE: using MOs instead of regs to get rid of MF dependency to be able
205 // to use flat data structure.
206 // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
207 // but makes LITs more stable, should prefer DenseMap still due to
208 // significant perf difference.
209 using SPIRVReg2EntryTy =
210 MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
211
212 template <typename T>
213 void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
214 SPIRVReg2EntryTy &Reg2Entry);
215
216 public:
217 void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
218 MachineModuleInfo *MMI);
219
add(const Type * Ty,const MachineFunction * MF,Register R)220 void add(const Type *Ty, const MachineFunction *MF, Register R) {
221 TT.add(unifyPtrType(Ty), MF, R);
222 }
223
add(const Type * PointeeTy,unsigned AddressSpace,const MachineFunction * MF,Register R)224 void add(const Type *PointeeTy, unsigned AddressSpace,
225 const MachineFunction *MF, Register R) {
226 ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
227 R);
228 }
229
add(const Constant * C,const MachineFunction * MF,Register R)230 void add(const Constant *C, const MachineFunction *MF, Register R) {
231 CT.add(C, MF, R);
232 }
233
add(const GlobalVariable * GV,const MachineFunction * MF,Register R)234 void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
235 GT.add(GV, MF, R);
236 }
237
add(const Function * F,const MachineFunction * MF,Register R)238 void add(const Function *F, const MachineFunction *MF, Register R) {
239 FT.add(F, MF, R);
240 }
241
add(const Argument * Arg,const MachineFunction * MF,Register R)242 void add(const Argument *Arg, const MachineFunction *MF, Register R) {
243 AT.add(Arg, MF, R);
244 }
245
add(const MachineInstr * MI,const MachineFunction * MF,Register R)246 void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
247 MT.add(MI, MF, R);
248 }
249
add(const SPIRV::SpecialTypeDescriptor & TD,const MachineFunction * MF,Register R)250 void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
251 Register R) {
252 ST.add(TD, MF, R);
253 }
254
find(const Type * Ty,const MachineFunction * MF)255 Register find(const Type *Ty, const MachineFunction *MF) {
256 return TT.find(unifyPtrType(Ty), MF);
257 }
258
find(const Type * PointeeTy,unsigned AddressSpace,const MachineFunction * MF)259 Register find(const Type *PointeeTy, unsigned AddressSpace,
260 const MachineFunction *MF) {
261 return ST.find(
262 SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
263 }
264
find(const Constant * C,const MachineFunction * MF)265 Register find(const Constant *C, const MachineFunction *MF) {
266 return CT.find(const_cast<Constant *>(C), MF);
267 }
268
find(const GlobalVariable * GV,const MachineFunction * MF)269 Register find(const GlobalVariable *GV, const MachineFunction *MF) {
270 return GT.find(const_cast<GlobalVariable *>(GV), MF);
271 }
272
find(const Function * F,const MachineFunction * MF)273 Register find(const Function *F, const MachineFunction *MF) {
274 return FT.find(const_cast<Function *>(F), MF);
275 }
276
find(const Argument * Arg,const MachineFunction * MF)277 Register find(const Argument *Arg, const MachineFunction *MF) {
278 return AT.find(const_cast<Argument *>(Arg), MF);
279 }
280
find(const MachineInstr * MI,const MachineFunction * MF)281 Register find(const MachineInstr *MI, const MachineFunction *MF) {
282 return MT.find(const_cast<MachineInstr *>(MI), MF);
283 }
284
find(const SPIRV::SpecialTypeDescriptor & TD,const MachineFunction * MF)285 Register find(const SPIRV::SpecialTypeDescriptor &TD,
286 const MachineFunction *MF) {
287 return ST.find(TD, MF);
288 }
289
getTypes()290 const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
291 };
292 } // namespace llvm
293 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
294