1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // General infrastructure for keeping track of the values that according to 10 // the SPIR-V binary layout should be global to the whole module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H 15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H 16 17 #include "MCTargetDesc/SPIRVBaseInfo.h" 18 #include "MCTargetDesc/SPIRVMCTargetDesc.h" 19 #include "llvm/ADT/DenseMap.h" 20 #include "llvm/ADT/MapVector.h" 21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 22 #include "llvm/CodeGen/MachineModuleInfo.h" 23 24 #include <type_traits> 25 26 namespace llvm { 27 namespace SPIRV { 28 // NOTE: using MapVector instead of DenseMap because it helps getting 29 // everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize 30 // memory and expensive removals which do not happen anyway. 31 class DTSortableEntry : public MapVector<const MachineFunction *, Register> { 32 SmallVector<DTSortableEntry *, 2> Deps; 33 34 struct FlagsTy { 35 unsigned IsFunc : 1; 36 unsigned IsGV : 1; 37 // NOTE: bit-field default init is a C++20 feature. 38 FlagsTy() : IsFunc(0), IsGV(0) {} 39 }; 40 FlagsTy Flags; 41 42 public: 43 // Common hoisting utility doesn't support function, because their hoisting 44 // require hoisting of params as well. 45 bool getIsFunc() const { return Flags.IsFunc; } 46 bool getIsGV() const { return Flags.IsGV; } 47 void setIsFunc(bool V) { Flags.IsFunc = V; } 48 void setIsGV(bool V) { Flags.IsGV = V; } 49 50 const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; } 51 void addDep(DTSortableEntry *E) { Deps.push_back(E); } 52 }; 53 54 struct SpecialTypeDescriptor { 55 enum SpecialTypeKind { 56 STK_Empty = 0, 57 STK_Image, 58 STK_SampledImage, 59 STK_Sampler, 60 STK_Pipe, 61 STK_DeviceEvent, 62 STK_Last = -1 63 }; 64 SpecialTypeKind Kind; 65 66 unsigned Hash; 67 68 SpecialTypeDescriptor() = delete; 69 SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; } 70 71 unsigned getHash() const { return Hash; } 72 73 virtual ~SpecialTypeDescriptor() {} 74 }; 75 76 struct ImageTypeDescriptor : public SpecialTypeDescriptor { 77 union ImageAttrs { 78 struct BitFlags { 79 unsigned Dim : 3; 80 unsigned Depth : 2; 81 unsigned Arrayed : 1; 82 unsigned MS : 1; 83 unsigned Sampled : 2; 84 unsigned ImageFormat : 6; 85 unsigned AQ : 2; 86 } Flags; 87 unsigned Val; 88 }; 89 90 ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth, 91 unsigned Arrayed, unsigned MS, unsigned Sampled, 92 unsigned ImageFormat, unsigned AQ = 0) 93 : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) { 94 ImageAttrs Attrs; 95 Attrs.Val = 0; 96 Attrs.Flags.Dim = Dim; 97 Attrs.Flags.Depth = Depth; 98 Attrs.Flags.Arrayed = Arrayed; 99 Attrs.Flags.MS = MS; 100 Attrs.Flags.Sampled = Sampled; 101 Attrs.Flags.ImageFormat = ImageFormat; 102 Attrs.Flags.AQ = AQ; 103 Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^ 104 ((Attrs.Val << 8) | Kind); 105 } 106 107 static bool classof(const SpecialTypeDescriptor *TD) { 108 return TD->Kind == SpecialTypeKind::STK_Image; 109 } 110 }; 111 112 struct SampledImageTypeDescriptor : public SpecialTypeDescriptor { 113 SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy) 114 : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) { 115 assert(ImageTy->getOpcode() == SPIRV::OpTypeImage); 116 ImageTypeDescriptor TD( 117 SampledTy, ImageTy->getOperand(2).getImm(), 118 ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(), 119 ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(), 120 ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm()); 121 Hash = TD.getHash() ^ Kind; 122 } 123 124 static bool classof(const SpecialTypeDescriptor *TD) { 125 return TD->Kind == SpecialTypeKind::STK_SampledImage; 126 } 127 }; 128 129 struct SamplerTypeDescriptor : public SpecialTypeDescriptor { 130 SamplerTypeDescriptor() 131 : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) { 132 Hash = Kind; 133 } 134 135 static bool classof(const SpecialTypeDescriptor *TD) { 136 return TD->Kind == SpecialTypeKind::STK_Sampler; 137 } 138 }; 139 140 struct PipeTypeDescriptor : public SpecialTypeDescriptor { 141 142 PipeTypeDescriptor(uint8_t AQ) 143 : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) { 144 Hash = (AQ << 8) | Kind; 145 } 146 147 static bool classof(const SpecialTypeDescriptor *TD) { 148 return TD->Kind == SpecialTypeKind::STK_Pipe; 149 } 150 }; 151 152 struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor { 153 154 DeviceEventTypeDescriptor() 155 : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) { 156 Hash = Kind; 157 } 158 159 static bool classof(const SpecialTypeDescriptor *TD) { 160 return TD->Kind == SpecialTypeKind::STK_DeviceEvent; 161 } 162 }; 163 } // namespace SPIRV 164 165 template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> { 166 static inline SPIRV::SpecialTypeDescriptor getEmptyKey() { 167 return SPIRV::SpecialTypeDescriptor( 168 SPIRV::SpecialTypeDescriptor::STK_Empty); 169 } 170 static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() { 171 return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last); 172 } 173 static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) { 174 return Val.getHash(); 175 } 176 static bool isEqual(SPIRV::SpecialTypeDescriptor LHS, 177 SPIRV::SpecialTypeDescriptor RHS) { 178 return getHashValue(LHS) == getHashValue(RHS); 179 } 180 }; 181 182 template <typename KeyTy> class SPIRVDuplicatesTrackerBase { 183 public: 184 // NOTE: using MapVector instead of DenseMap helps getting everything ordered 185 // in a stable manner for a price of extra (NumKeys)*PtrSize memory and 186 // expensive removals which don't happen anyway. 187 using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>; 188 189 private: 190 StorageTy Storage; 191 192 public: 193 void add(KeyTy V, const MachineFunction *MF, Register R) { 194 if (find(V, MF).isValid()) 195 return; 196 197 Storage[V][MF] = R; 198 if (std::is_same<Function, 199 typename std::remove_const< 200 typename std::remove_pointer<KeyTy>::type>::type>() || 201 std::is_same<Argument, 202 typename std::remove_const< 203 typename std::remove_pointer<KeyTy>::type>::type>()) 204 Storage[V].setIsFunc(true); 205 if (std::is_same<GlobalVariable, 206 typename std::remove_const< 207 typename std::remove_pointer<KeyTy>::type>::type>()) 208 Storage[V].setIsGV(true); 209 } 210 211 Register find(KeyTy V, const MachineFunction *MF) const { 212 auto iter = Storage.find(V); 213 if (iter != Storage.end()) { 214 auto Map = iter->second; 215 auto iter2 = Map.find(MF); 216 if (iter2 != Map.end()) 217 return iter2->second; 218 } 219 return Register(); 220 } 221 222 const StorageTy &getAllUses() const { return Storage; } 223 224 private: 225 StorageTy &getAllUses() { return Storage; } 226 227 // The friend class needs to have access to the internal storage 228 // to be able to build dependency graph, can't declare only one 229 // function a 'friend' due to the incomplete declaration at this point 230 // and mutual dependency problems. 231 friend class SPIRVGeneralDuplicatesTracker; 232 }; 233 234 template <typename T> 235 class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {}; 236 237 template <> 238 class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> 239 : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {}; 240 241 class SPIRVGeneralDuplicatesTracker { 242 SPIRVDuplicatesTracker<Type> TT; 243 SPIRVDuplicatesTracker<Constant> CT; 244 SPIRVDuplicatesTracker<GlobalVariable> GT; 245 SPIRVDuplicatesTracker<Function> FT; 246 SPIRVDuplicatesTracker<Argument> AT; 247 SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST; 248 249 // NOTE: using MOs instead of regs to get rid of MF dependency to be able 250 // to use flat data structure. 251 // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness 252 // but makes LITs more stable, should prefer DenseMap still due to 253 // significant perf difference. 254 using SPIRVReg2EntryTy = 255 MapVector<MachineOperand *, SPIRV::DTSortableEntry *>; 256 257 template <typename T> 258 void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT, 259 SPIRVReg2EntryTy &Reg2Entry); 260 261 public: 262 void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph, 263 MachineModuleInfo *MMI); 264 265 void add(const Type *T, const MachineFunction *MF, Register R) { 266 TT.add(T, MF, R); 267 } 268 269 void add(const Constant *C, const MachineFunction *MF, Register R) { 270 CT.add(C, MF, R); 271 } 272 273 void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) { 274 GT.add(GV, MF, R); 275 } 276 277 void add(const Function *F, const MachineFunction *MF, Register R) { 278 FT.add(F, MF, R); 279 } 280 281 void add(const Argument *Arg, const MachineFunction *MF, Register R) { 282 AT.add(Arg, MF, R); 283 } 284 285 void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF, 286 Register R) { 287 ST.add(TD, MF, R); 288 } 289 290 Register find(const Type *T, const MachineFunction *MF) { 291 return TT.find(const_cast<Type *>(T), MF); 292 } 293 294 Register find(const Constant *C, const MachineFunction *MF) { 295 return CT.find(const_cast<Constant *>(C), MF); 296 } 297 298 Register find(const GlobalVariable *GV, const MachineFunction *MF) { 299 return GT.find(const_cast<GlobalVariable *>(GV), MF); 300 } 301 302 Register find(const Function *F, const MachineFunction *MF) { 303 return FT.find(const_cast<Function *>(F), MF); 304 } 305 306 Register find(const Argument *Arg, const MachineFunction *MF) { 307 return AT.find(const_cast<Argument *>(Arg), MF); 308 } 309 310 Register find(const SPIRV::SpecialTypeDescriptor &TD, 311 const MachineFunction *MF) { 312 return ST.find(TD, MF); 313 } 314 315 const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; } 316 }; 317 } // namespace llvm 318 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H 319