xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (revision 3f0efe05432b1633991114ca4ca330102a561959)
1 //===-- SPIRVGlobalRegistry.h - SPIR-V Global Registry ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // SPIRVGlobalRegistry is used to maintain rich type information required for
10 // SPIR-V even after lowering from LLVM IR to GMIR. It can convert an llvm::Type
11 // into an OpTypeXXX instruction, and map it to a virtual register. Also it
12 // builds and supports consistency of constants and global variables.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
17 #define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
18 
19 #include "MCTargetDesc/SPIRVBaseInfo.h"
20 #include "SPIRVDuplicatesTracker.h"
21 #include "SPIRVInstrInfo.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 
24 namespace llvm {
25 using SPIRVType = const MachineInstr;
26 
27 class SPIRVGlobalRegistry {
28   // Registers holding values which have types associated with them.
29   // Initialized upon VReg definition in IRTranslator.
30   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
31   // where Reg = OpType...
32   // while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e. not
33   // type-declaring ones).
34   DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
35       VRegToTypeMap;
36 
37   SPIRVGeneralDuplicatesTracker DT;
38 
39   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
40 
41   // Look for an equivalent of the newType in the map. Return the equivalent
42   // if it's found, otherwise insert newType to the map and return the type.
43   const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
44                                         MachineIRBuilder &MIRBuilder);
45 
46   SmallPtrSet<const Type *, 4> TypesInProcessing;
47   DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
48 
49   // Number of bits pointers and size_t integers require.
50   const unsigned PointerSize;
51 
52   // Add a new OpTypeXXX instruction without checking for duplicates.
53   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
54                              SPIRV::AccessQualifier::AccessQualifier AQ =
55                                  SPIRV::AccessQualifier::ReadWrite,
56                              bool EmitIR = true);
57   SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
58                            SPIRV::AccessQualifier::AccessQualifier accessQual =
59                                SPIRV::AccessQualifier::ReadWrite,
60                            bool EmitIR = true);
61   SPIRVType *
62   restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
63                         SPIRV::AccessQualifier::AccessQualifier AccessQual,
64                         bool EmitIR);
65 
66 public:
67   SPIRVGlobalRegistry(unsigned PointerSize);
68 
69   MachineFunction *CurMF;
70 
71   void add(const Constant *C, MachineFunction *MF, Register R) {
72     DT.add(C, MF, R);
73   }
74 
75   void add(const GlobalVariable *GV, MachineFunction *MF, Register R) {
76     DT.add(GV, MF, R);
77   }
78 
79   void add(const Function *F, MachineFunction *MF, Register R) {
80     DT.add(F, MF, R);
81   }
82 
83   void add(const Argument *Arg, MachineFunction *MF, Register R) {
84     DT.add(Arg, MF, R);
85   }
86 
87   Register find(const Constant *C, MachineFunction *MF) {
88     return DT.find(C, MF);
89   }
90 
91   Register find(const GlobalVariable *GV, MachineFunction *MF) {
92     return DT.find(GV, MF);
93   }
94 
95   Register find(const Function *F, MachineFunction *MF) {
96     return DT.find(F, MF);
97   }
98 
99   void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
100                       MachineModuleInfo *MMI = nullptr) {
101     DT.buildDepsGraph(Graph, MMI);
102   }
103 
104   // Get or create a SPIR-V type corresponding the given LLVM IR type,
105   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
106   SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
107                               MachineIRBuilder &MIRBuilder,
108                               SPIRV::AccessQualifier::AccessQualifier AQ =
109                                   SPIRV::AccessQualifier::ReadWrite,
110                               bool EmitIR = true);
111   SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
112                                  MachineInstr &I, const SPIRVInstrInfo &TII);
113   SPIRVType *assignVectTypeToVReg(SPIRVType *BaseType, unsigned NumElements,
114                                   Register VReg, MachineInstr &I,
115                                   const SPIRVInstrInfo &TII);
116 
117   // In cases where the SPIR-V type is already known, this function can be
118   // used to map it to the given VReg via an ASSIGN_TYPE instruction.
119   void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg,
120                              MachineFunction &MF);
121 
122   // Either generate a new OpTypeXXX instruction or return an existing one
123   // corresponding to the given LLVM IR type.
124   // EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes)
125   // because this method may be called from InstructionSelector and we don't
126   // want to emit extra IR instructions there.
127   SPIRVType *getOrCreateSPIRVType(const Type *Type,
128                                   MachineIRBuilder &MIRBuilder,
129                                   SPIRV::AccessQualifier::AccessQualifier AQ =
130                                       SPIRV::AccessQualifier::ReadWrite,
131                                   bool EmitIR = true);
132 
133   const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
134     auto Res = SPIRVToLLVMType.find(Ty);
135     assert(Res != SPIRVToLLVMType.end());
136     return Res->second;
137   }
138 
139   // Either generate a new OpTypeXXX instruction or return an existing one
140   // corresponding to the given string containing the name of the builtin type.
141   SPIRVType *getOrCreateSPIRVTypeByName(
142       StringRef TypeStr, MachineIRBuilder &MIRBuilder,
143       SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
144       SPIRV::AccessQualifier::AccessQualifier AQ =
145           SPIRV::AccessQualifier::ReadWrite);
146 
147   // Return the SPIR-V type instruction corresponding to the given VReg, or
148   // nullptr if no such type instruction exists.
149   SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
150 
151   // Whether the given VReg has a SPIR-V type mapped to it yet.
152   bool hasSPIRVTypeForVReg(Register VReg) const {
153     return getSPIRVTypeForVReg(VReg) != nullptr;
154   }
155 
156   // Return the VReg holding the result of the given OpTypeXXX instruction.
157   Register getSPIRVTypeID(const SPIRVType *SpirvType) const;
158 
159   void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; }
160 
161   // Whether the given VReg has an OpTypeXXX instruction mapped to it with the
162   // given opcode (e.g. OpTypeFloat).
163   bool isScalarOfType(Register VReg, unsigned TypeOpcode) const;
164 
165   // Return true if the given VReg's assigned SPIR-V type is either a scalar
166   // matching the given opcode, or a vector with an element type matching that
167   // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
168   bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
169 
170   // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
171   unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
172 
173   // For integer vectors or scalars, return whether the integers are signed.
174   bool isScalarOrVectorSigned(const SPIRVType *Type) const;
175 
176   // Gets the storage class of the pointer type assigned to this vreg.
177   SPIRV::StorageClass::StorageClass getPointerStorageClass(Register VReg) const;
178 
179   // Return the number of bits SPIR-V pointers and size_t variables require.
180   unsigned getPointerSize() const { return PointerSize; }
181 
182 private:
183   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
184 
185   SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
186                           bool IsSigned = false);
187 
188   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
189 
190   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
191 
192   SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
193                              MachineIRBuilder &MIRBuilder);
194 
195   SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
196                             MachineIRBuilder &MIRBuilder, bool EmitIR = true);
197 
198   SPIRVType *getOpTypeOpaque(const StructType *Ty,
199                              MachineIRBuilder &MIRBuilder);
200 
201   SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
202                              bool EmitIR = true);
203 
204   SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
205                               SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
206                               Register Reg);
207 
208   SPIRVType *getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,
209                                      MachineIRBuilder &MIRBuilder);
210 
211   SPIRVType *getOpTypeFunction(SPIRVType *RetType,
212                                const SmallVectorImpl<SPIRVType *> &ArgTypes,
213                                MachineIRBuilder &MIRBuilder);
214 
215   SPIRVType *
216   getOrCreateSpecialType(const Type *Ty, MachineIRBuilder &MIRBuilder,
217                          SPIRV::AccessQualifier::AccessQualifier AccQual);
218 
219   std::tuple<Register, ConstantInt *, bool> getOrCreateConstIntReg(
220       uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
221       MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
222   SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
223   Register getOrCreateIntCompositeOrNull(uint64_t Val, MachineInstr &I,
224                                          SPIRVType *SpvType,
225                                          const SPIRVInstrInfo &TII,
226                                          Constant *CA, unsigned BitWidth,
227                                          unsigned ElemCnt);
228   Register getOrCreateIntCompositeOrNull(uint64_t Val,
229                                          MachineIRBuilder &MIRBuilder,
230                                          SPIRVType *SpvType, bool EmitIR,
231                                          Constant *CA, unsigned BitWidth,
232                                          unsigned ElemCnt);
233 
234 public:
235   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
236                             SPIRVType *SpvType = nullptr, bool EmitIR = true);
237   Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
238                                SPIRVType *SpvType, const SPIRVInstrInfo &TII);
239   Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
240                            SPIRVType *SpvType = nullptr);
241   Register getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
242                                     SPIRVType *SpvType,
243                                     const SPIRVInstrInfo &TII);
244   Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
245                                    SPIRVType *SpvType,
246                                    const SPIRVInstrInfo &TII);
247   Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
248                                     SPIRVType *SpvType, bool EmitIR = true);
249   Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
250                                    SPIRVType *SpvType, bool EmitIR = true);
251   Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
252                                    SPIRVType *SpvType);
253   Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
254                                 unsigned FilerMode,
255                                 MachineIRBuilder &MIRBuilder,
256                                 SPIRVType *SpvType);
257   Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType,
258                             const SPIRVInstrInfo &TII);
259   Register buildGlobalVariable(Register Reg, SPIRVType *BaseType,
260                                StringRef Name, const GlobalValue *GV,
261                                SPIRV::StorageClass::StorageClass Storage,
262                                const MachineInstr *Init, bool IsConst,
263                                bool HasLinkageTy,
264                                SPIRV::LinkageType::LinkageType LinkageType,
265                                MachineIRBuilder &MIRBuilder,
266                                bool IsInstSelector);
267 
268   // Convenient helpers for getting types with check for duplicates.
269   SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth,
270                                          MachineIRBuilder &MIRBuilder);
271   SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I,
272                                          const SPIRVInstrInfo &TII);
273   SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
274   SPIRVType *getOrCreateSPIRVBoolType(MachineInstr &I,
275                                       const SPIRVInstrInfo &TII);
276   SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
277                                         unsigned NumElements,
278                                         MachineIRBuilder &MIRBuilder);
279   SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
280                                         unsigned NumElements, MachineInstr &I,
281                                         const SPIRVInstrInfo &TII);
282   SPIRVType *getOrCreateSPIRVArrayType(SPIRVType *BaseType,
283                                        unsigned NumElements, MachineInstr &I,
284                                        const SPIRVInstrInfo &TII);
285 
286   SPIRVType *getOrCreateSPIRVPointerType(
287       SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
288       SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
289   SPIRVType *getOrCreateSPIRVPointerType(
290       SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
291       SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
292 
293   SPIRVType *
294   getOrCreateOpTypeImage(MachineIRBuilder &MIRBuilder, SPIRVType *SampledType,
295                          SPIRV::Dim::Dim Dim, uint32_t Depth, uint32_t Arrayed,
296                          uint32_t Multisampled, uint32_t Sampled,
297                          SPIRV::ImageFormat::ImageFormat ImageFormat,
298                          SPIRV::AccessQualifier::AccessQualifier AccQual);
299 
300   SPIRVType *getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder);
301 
302   SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
303                                            MachineIRBuilder &MIRBuilder);
304 
305   SPIRVType *
306   getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
307                         SPIRV::AccessQualifier::AccessQualifier AccQual);
308   SPIRVType *getOrCreateOpTypeDeviceEvent(MachineIRBuilder &MIRBuilder);
309   SPIRVType *getOrCreateOpTypeFunctionWithArgs(
310       const Type *Ty, SPIRVType *RetType,
311       const SmallVectorImpl<SPIRVType *> &ArgTypes,
312       MachineIRBuilder &MIRBuilder);
313   SPIRVType *getOrCreateOpTypeByOpcode(const Type *Ty,
314                                        MachineIRBuilder &MIRBuilder,
315                                        unsigned Opcode);
316 };
317 } // end namespace llvm
318 #endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
319