xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVIRMapping.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===------------ SPIRVIRMapping.h - SPIR-V IR Mapping ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// General infrastructure for keeping track of the values that, according to
// the SPIR-V binary layout, should be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
16 
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 #include "llvm/CodeGen/MachineModuleInfo.h"
24 
25 #include <type_traits>
26 
27 namespace llvm {
28 namespace SPIRV {
29 
to_hash(const MachineInstr * MI)30 inline size_t to_hash(const MachineInstr *MI) {
31   hash_code H = llvm::hash_combine(MI->getOpcode(), MI->getNumOperands());
32   for (unsigned I = MI->getNumDefs(); I < MI->getNumOperands(); ++I) {
33     const MachineOperand &MO = MI->getOperand(I);
34     if (MO.getType() == MachineOperand::MO_CImmediate)
35       H = llvm::hash_combine(H, MO.getType(), MO.getCImm());
36     else if (MO.getType() == MachineOperand::MO_FPImmediate)
37       H = llvm::hash_combine(H, MO.getType(), MO.getFPImm());
38     else
39       H = llvm::hash_combine(H, MO.getType());
40   }
41   return H;
42 }
43 
44 using MIHandle = std::tuple<const MachineInstr *, Register, size_t>;
45 
getMIKey(const MachineInstr * MI)46 inline MIHandle getMIKey(const MachineInstr *MI) {
47   return std::make_tuple(MI, MI->getOperand(0).getReg(), SPIRV::to_hash(MI));
48 }
49 
50 using IRHandle = std::tuple<const void *, unsigned, unsigned>;
51 using IRHandleMF = std::pair<IRHandle, const MachineFunction *>;
52 
getIRHandleMF(IRHandle Handle,const MachineFunction * MF)53 inline IRHandleMF getIRHandleMF(IRHandle Handle, const MachineFunction *MF) {
54   return std::make_pair(Handle, MF);
55 }
56 
/// Discriminator stored in the third slot of an IRHandle. It tells apart the
/// different kinds of keys that the irhandle_* / handle() helpers build.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,
  STK_SampledImage = 2,
  STK_Sampler = 3,
  STK_Pipe = 4,
  STK_DeviceEvent = 5,
  STK_ElementPointer = 6,
  STK_Type = 7,
  STK_Value = 8,
  STK_MachineInstr = 9,
  STK_VkBuffer = 10,
  STK_ExplictLayoutType = 11,
  STK_Last = -1
};
72 
/// Packs the OpTypeImage attributes into a single unsigned value (Val) so
/// they can be used as the integer payload of an IRHandle. The bit widths
/// below bound the values that can be represented for each attribute.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) { // zero every bit first so unused bits do not affect Val
    BitFlags &Bits = Flags;
    Bits.Dim = Dim;
    Bits.Depth = Depth;
    Bits.Arrayed = Arrayed;
    Bits.MS = MS;
    Bits.Sampled = Sampled;
    Bits.ImageFormat = ImageFormat;
    Bits.AQ = AQ;
  }
};
97 
98 inline IRHandle irhandle_image(const Type *SampledTy, unsigned Dim,
99                                unsigned Depth, unsigned Arrayed, unsigned MS,
100                                unsigned Sampled, unsigned ImageFormat,
101                                unsigned AQ = 0) {
102   return std::make_tuple(
103       SampledTy,
104       ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
105       SpecialTypeKind::STK_Image);
106 }
107 
irhandle_sampled_image(const Type * SampledTy,const MachineInstr * ImageTy)108 inline IRHandle irhandle_sampled_image(const Type *SampledTy,
109                                        const MachineInstr *ImageTy) {
110   assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
111   unsigned AC = AccessQualifier::AccessQualifier::None;
112   if (ImageTy->getNumOperands() > 8)
113     AC = ImageTy->getOperand(8).getImm();
114   return std::make_tuple(
115       SampledTy,
116       ImageAttrs(
117           ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
118           ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
119           ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(), AC)
120           .Val,
121       SpecialTypeKind::STK_SampledImage);
122 }
123 
irhandle_sampler()124 inline IRHandle irhandle_sampler() {
125   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
126 }
127 
irhandle_pipe(uint8_t AQ)128 inline IRHandle irhandle_pipe(uint8_t AQ) {
129   return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
130 }
131 
irhandle_event()132 inline IRHandle irhandle_event() {
133   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
134 }
135 
irhandle_pointee(const Type * ElementType,unsigned AddressSpace)136 inline IRHandle irhandle_pointee(const Type *ElementType,
137                                  unsigned AddressSpace) {
138   return std::make_tuple(unifyPtrType(ElementType), AddressSpace,
139                          SpecialTypeKind::STK_ElementPointer);
140 }
141 
irhandle_ptr(const void * Ptr,unsigned Arg,enum SpecialTypeKind STK)142 inline IRHandle irhandle_ptr(const void *Ptr, unsigned Arg,
143                              enum SpecialTypeKind STK) {
144   return std::make_tuple(Ptr, Arg, STK);
145 }
146 
irhandle_vkbuffer(const Type * ElementType,StorageClass::StorageClass SC,bool IsWriteable)147 inline IRHandle irhandle_vkbuffer(const Type *ElementType,
148                                   StorageClass::StorageClass SC,
149                                   bool IsWriteable) {
150   return std::make_tuple(ElementType, (SC << 1) | IsWriteable,
151                          SpecialTypeKind::STK_VkBuffer);
152 }
153 
irhandle_explict_layout_type(const Type * Ty)154 inline IRHandle irhandle_explict_layout_type(const Type *Ty) {
155   const Type *WrpTy = unifyPtrType(Ty);
156   return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_ExplictLayoutType);
157 }
158 
handle(const Type * Ty)159 inline IRHandle handle(const Type *Ty) {
160   const Type *WrpTy = unifyPtrType(Ty);
161   return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type);
162 }
163 
handle(const Value * V)164 inline IRHandle handle(const Value *V) {
165   return irhandle_ptr(V, V->getValueID(), STK_Value);
166 }
167 
handle(const MachineInstr * KeyMI)168 inline IRHandle handle(const MachineInstr *KeyMI) {
169   return irhandle_ptr(KeyMI, SPIRV::to_hash(KeyMI), STK_MachineInstr);
170 }
171 
type_has_layout_decoration(const Type * T)172 inline bool type_has_layout_decoration(const Type *T) {
173   return (isa<StructType>(T) || isa<ArrayType>(T));
174 }
175 
176 } // namespace SPIRV
177 
178 // Bi-directional mappings between LLVM entities and (v-reg, machine function)
179 // pairs support management of unique SPIR-V definitions per machine function
180 // per an LLVM/GlobalISel entity (e.g., Type, Constant, Machine Instruction).
181 class SPIRVIRMapping {
182   DenseMap<SPIRV::IRHandleMF, SPIRV::MIHandle> Vregs;
183   DenseMap<const MachineInstr *, SPIRV::IRHandleMF> Defs;
184 
185 public:
add(SPIRV::IRHandle Handle,const MachineInstr * MI)186   bool add(SPIRV::IRHandle Handle, const MachineInstr *MI) {
187     if (auto DefIt = Defs.find(MI); DefIt != Defs.end()) {
188       auto [ExistHandle, ExistMF] = DefIt->second;
189       if (Handle == ExistHandle && MI->getMF() == ExistMF)
190         return false; // already exists
191       // invalidate the record
192       Vregs.erase(DefIt->second);
193       Defs.erase(DefIt);
194     }
195     SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MI->getMF());
196     SPIRV::MIHandle MIKey = SPIRV::getMIKey(MI);
197     auto It1 = Vregs.try_emplace(HandleMF, MIKey);
198     if (!It1.second) {
199       // there is an expired record that we need to invalidate
200       Defs.erase(std::get<0>(It1.first->second));
201       // update the record
202       It1.first->second = MIKey;
203     }
204     [[maybe_unused]] auto It2 = Defs.try_emplace(MI, HandleMF);
205     assert(It2.second);
206     return true;
207   }
erase(const MachineInstr * MI)208   bool erase(const MachineInstr *MI) {
209     bool Res = false;
210     if (auto It = Defs.find(MI); It != Defs.end()) {
211       Res = Vregs.erase(It->second);
212       Defs.erase(It);
213     }
214     return Res;
215   }
findMI(SPIRV::IRHandle Handle,const MachineFunction * MF)216   const MachineInstr *findMI(SPIRV::IRHandle Handle,
217                              const MachineFunction *MF) {
218     SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF);
219     auto It = Vregs.find(HandleMF);
220     if (It == Vregs.end())
221       return nullptr;
222     auto [MI, Reg, Hash] = It->second;
223     const MachineInstr *Def = MF->getRegInfo().getVRegDef(Reg);
224     if (!Def || Def != MI || SPIRV::to_hash(MI) != Hash) {
225       // there is an expired record that we need to invalidate
226       erase(MI);
227       return nullptr;
228     }
229     assert(Defs.contains(MI) && Defs.find(MI)->second == HandleMF);
230     return MI;
231   }
find(SPIRV::IRHandle Handle,const MachineFunction * MF)232   Register find(SPIRV::IRHandle Handle, const MachineFunction *MF) {
233     const MachineInstr *MI = findMI(Handle, MF);
234     return MI ? MI->getOperand(0).getReg() : Register();
235   }
236 
237   // helpers
add(const Type * PointeeTy,unsigned AddressSpace,const MachineInstr * MI)238   bool add(const Type *PointeeTy, unsigned AddressSpace,
239            const MachineInstr *MI) {
240     return add(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MI);
241   }
find(const Type * PointeeTy,unsigned AddressSpace,const MachineFunction * MF)242   Register find(const Type *PointeeTy, unsigned AddressSpace,
243                 const MachineFunction *MF) {
244     return find(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
245   }
findMI(const Type * PointeeTy,unsigned AddressSpace,const MachineFunction * MF)246   const MachineInstr *findMI(const Type *PointeeTy, unsigned AddressSpace,
247                              const MachineFunction *MF) {
248     return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
249   }
250 
add(const Value * V,const MachineInstr * MI)251   bool add(const Value *V, const MachineInstr *MI) {
252     return add(SPIRV::handle(V), MI);
253   }
254 
add(const Type * T,bool RequiresExplicitLayout,const MachineInstr * MI)255   bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) {
256     if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) {
257       return add(SPIRV::irhandle_explict_layout_type(T), MI);
258     }
259     return add(SPIRV::handle(T), MI);
260   }
261 
add(const MachineInstr * Obj,const MachineInstr * MI)262   bool add(const MachineInstr *Obj, const MachineInstr *MI) {
263     return add(SPIRV::handle(Obj), MI);
264   }
265 
find(const Value * V,const MachineFunction * MF)266   Register find(const Value *V, const MachineFunction *MF) {
267     return find(SPIRV::handle(V), MF);
268   }
269 
find(const Type * T,bool RequiresExplicitLayout,const MachineFunction * MF)270   Register find(const Type *T, bool RequiresExplicitLayout,
271                 const MachineFunction *MF) {
272     if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
273       return find(SPIRV::irhandle_explict_layout_type(T), MF);
274     return find(SPIRV::handle(T), MF);
275   }
276 
find(const MachineInstr * MI,const MachineFunction * MF)277   Register find(const MachineInstr *MI, const MachineFunction *MF) {
278     return find(SPIRV::handle(MI), MF);
279   }
280 
findMI(const Value * Obj,const MachineFunction * MF)281   const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) {
282     return findMI(SPIRV::handle(Obj), MF);
283   }
284 
findMI(const Type * T,bool RequiresExplicitLayout,const MachineFunction * MF)285   const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout,
286                              const MachineFunction *MF) {
287     if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
288       return findMI(SPIRV::irhandle_explict_layout_type(T), MF);
289     return findMI(SPIRV::handle(T), MF);
290   }
291 
findMI(const MachineInstr * Obj,const MachineFunction * MF)292   const MachineInstr *findMI(const MachineInstr *Obj,
293                              const MachineFunction *MF) {
294     return findMI(SPIRV::handle(Obj), MF);
295   }
296 };
297 } // namespace llvm
298 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
299