xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 #include "SPIRVInstrInfo.h"
16 #include "SPIRVRegisterBankInfo.h"
17 #include "SPIRVRegisterInfo.h"
18 #include "SPIRVSubtarget.h"
19 #include "llvm/CodeGen/MachineInstrBuilder.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/IR/IntrinsicsSPIRV.h"
22 
23 #define DEBUG_TYPE "spirv-lower"
24 
25 using namespace llvm;
26 
27 // Returns true of the types logically match, as defined in
28 // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
typesLogicallyMatch(const SPIRVType * Ty1,const SPIRVType * Ty2,SPIRVGlobalRegistry & GR)29 static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
30                                 SPIRVGlobalRegistry &GR) {
31   if (Ty1->getOpcode() != Ty2->getOpcode())
32     return false;
33 
34   if (Ty1->getNumOperands() != Ty2->getNumOperands())
35     return false;
36 
37   if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
38     // Array must have the same size.
39     if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
40       return false;
41 
42     SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg());
43     SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg());
44     return ElemType1 == ElemType2 ||
45            typesLogicallyMatch(ElemType1, ElemType2, GR);
46   }
47 
48   if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
49     for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
50       SPIRVType *ElemType1 =
51           GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg());
52       SPIRVType *ElemType2 =
53           GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg());
54       if (ElemType1 != ElemType2 &&
55           !typesLogicallyMatch(ElemType1, ElemType2, GR))
56         return false;
57     }
58     return true;
59   }
60   return false;
61 }
62 
getNumRegistersForCallingConv(LLVMContext & Context,CallingConv::ID CC,EVT VT) const63 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
64     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
65   // This code avoids CallLowering fail inside getVectorTypeBreakdown
66   // on v3i1 arguments. Maybe we need to return 1 for all types.
67   // TODO: remove it once this case is supported by the default implementation.
68   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
69       (VT.getVectorElementType() == MVT::i1 ||
70        VT.getVectorElementType() == MVT::i8))
71     return 1;
72   if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
73     return 1;
74   return getNumRegisters(Context, VT);
75 }
76 
getRegisterTypeForCallingConv(LLVMContext & Context,CallingConv::ID CC,EVT VT) const77 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
78                                                        CallingConv::ID CC,
79                                                        EVT VT) const {
80   // This code avoids CallLowering fail inside getVectorTypeBreakdown
81   // on v3i1 arguments. Maybe we need to return i32 for all types.
82   // TODO: remove it once this case is supported by the default implementation.
83   if (VT.isVector() && VT.getVectorNumElements() == 3) {
84     if (VT.getVectorElementType() == MVT::i1)
85       return MVT::v4i1;
86     else if (VT.getVectorElementType() == MVT::i8)
87       return MVT::v4i8;
88   }
89   return getRegisterType(Context, VT);
90 }
91 
getTgtMemIntrinsic(IntrinsicInfo & Info,const CallInst & I,MachineFunction & MF,unsigned Intrinsic) const92 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
93                                              const CallInst &I,
94                                              MachineFunction &MF,
95                                              unsigned Intrinsic) const {
96   unsigned AlignIdx = 3;
97   switch (Intrinsic) {
98   case Intrinsic::spv_load:
99     AlignIdx = 2;
100     [[fallthrough]];
101   case Intrinsic::spv_store: {
102     if (I.getNumOperands() >= AlignIdx + 1) {
103       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
104       Info.align = Align(AlignOp->getZExtValue());
105     }
106     Info.flags = static_cast<MachineMemOperand::Flags>(
107         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
108     Info.memVT = MVT::i64;
109     // TODO: take into account opaque pointers (don't use getElementType).
110     // MVT::getVT(PtrTy->getElementType());
111     return true;
112     break;
113   }
114   default:
115     break;
116   }
117   return false;
118 }
119 
120 std::pair<unsigned, const TargetRegisterClass *>
getRegForInlineAsmConstraint(const TargetRegisterInfo * TRI,StringRef Constraint,MVT VT) const121 SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
122                                                   StringRef Constraint,
123                                                   MVT VT) const {
124   const TargetRegisterClass *RC = nullptr;
125   if (Constraint.starts_with("{"))
126     return std::make_pair(0u, RC);
127 
128   if (VT.isFloatingPoint())
129     RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
130   else if (VT.isInteger())
131     RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
132   else
133     RC = &SPIRV::iIDRegClass;
134 
135   return std::make_pair(0u, RC);
136 }
137 
getTypeReg(MachineRegisterInfo * MRI,Register OpReg)138 inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
139   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
140   return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
141              ? TypeInst->getOperand(1).getReg()
142              : OpReg;
143 }
144 
doInsertBitcast(const SPIRVSubtarget & STI,MachineRegisterInfo * MRI,SPIRVGlobalRegistry & GR,MachineInstr & I,Register OpReg,unsigned OpIdx,SPIRVType * NewPtrType)145 static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
146                             SPIRVGlobalRegistry &GR, MachineInstr &I,
147                             Register OpReg, unsigned OpIdx,
148                             SPIRVType *NewPtrType) {
149   MachineIRBuilder MIB(I);
150   Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
151   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
152                  .addDef(NewReg)
153                  .addUse(GR.getSPIRVTypeID(NewPtrType))
154                  .addUse(OpReg)
155                  .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
156                                    *STI.getRegBankInfo());
157   if (!Res)
158     report_fatal_error("insert validation bitcast: cannot constrain all uses");
159   I.getOperand(OpIdx).setReg(NewReg);
160 }
161 
createNewPtrType(SPIRVGlobalRegistry & GR,MachineInstr & I,SPIRVType * OpType,bool ReuseType,SPIRVType * ResType,const Type * ResTy)162 static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
163                                    SPIRVType *OpType, bool ReuseType,
164                                    SPIRVType *ResType, const Type *ResTy) {
165   SPIRV::StorageClass::StorageClass SC =
166       static_cast<SPIRV::StorageClass::StorageClass>(
167           OpType->getOperand(1).getImm());
168   MachineIRBuilder MIB(I);
169   SPIRVType *NewBaseType =
170       ReuseType ? ResType
171                 : GR.getOrCreateSPIRVType(
172                       ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
173   return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
174 }
175 
176 // Insert a bitcast before the instruction to keep SPIR-V code valid
177 // when there is a type mismatch between results and operand types.
validatePtrTypes(const SPIRVSubtarget & STI,MachineRegisterInfo * MRI,SPIRVGlobalRegistry & GR,MachineInstr & I,unsigned OpIdx,SPIRVType * ResType,const Type * ResTy=nullptr)178 static void validatePtrTypes(const SPIRVSubtarget &STI,
179                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
180                              MachineInstr &I, unsigned OpIdx,
181                              SPIRVType *ResType, const Type *ResTy = nullptr) {
182   // Get operand type
183   MachineFunction *MF = I.getParent()->getParent();
184   Register OpReg = I.getOperand(OpIdx).getReg();
185   Register OpTypeReg = getTypeReg(MRI, OpReg);
186   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
187   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
188     return;
189   // Get operand's pointee type
190   Register ElemTypeReg = OpType->getOperand(2).getReg();
191   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
192   if (!ElemType)
193     return;
194   // Check if we need a bitcast to make a statement valid
195   bool IsSameMF = MF == ResType->getParent()->getParent();
196   bool IsEqualTypes = IsSameMF ? ElemType == ResType
197                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
198   if (IsEqualTypes)
199     return;
200   // There is a type mismatch between results and operand types
201   // and we insert a bitcast before the instruction to keep SPIR-V code valid
202   SPIRVType *NewPtrType =
203       createNewPtrType(GR, I, OpType, IsSameMF, ResType, ResTy);
204   if (!GR.isBitcastCompatible(NewPtrType, OpType))
205     report_fatal_error(
206         "insert validation bitcast: incompatible result and operand types");
207   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
208 }
209 
210 // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
211 // that doesn't point to OpTypeEvent.
validateGroupWaitEventsPtr(const SPIRVSubtarget & STI,MachineRegisterInfo * MRI,SPIRVGlobalRegistry & GR,MachineInstr & I)212 static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
213                                        MachineRegisterInfo *MRI,
214                                        SPIRVGlobalRegistry &GR,
215                                        MachineInstr &I) {
216   constexpr unsigned OpIdx = 2;
217   MachineFunction *MF = I.getParent()->getParent();
218   Register OpReg = I.getOperand(OpIdx).getReg();
219   Register OpTypeReg = getTypeReg(MRI, OpReg);
220   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
221   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
222     return;
223   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
224   if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
225     return;
226   // Insert a bitcast before the instruction to keep SPIR-V code valid.
227   LLVMContext &Context = MF->getFunction().getContext();
228   SPIRVType *NewPtrType =
229       createNewPtrType(GR, I, OpType, false, nullptr,
230                        TargetExtType::get(Context, "spirv.Event"));
231   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
232 }
233 
validateLifetimeStart(const SPIRVSubtarget & STI,MachineRegisterInfo * MRI,SPIRVGlobalRegistry & GR,MachineInstr & I)234 static void validateLifetimeStart(const SPIRVSubtarget &STI,
235                                   MachineRegisterInfo *MRI,
236                                   SPIRVGlobalRegistry &GR, MachineInstr &I) {
237   Register PtrReg = I.getOperand(0).getReg();
238   MachineFunction *MF = I.getParent()->getParent();
239   Register PtrTypeReg = getTypeReg(MRI, PtrReg);
240   SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
241   SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
242   if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
243       (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
244        PonteeElemType->getOperand(1).getImm() == 8))
245     return;
246   // To keep the code valid a bitcast must be inserted
247   SPIRV::StorageClass::StorageClass SC =
248       static_cast<SPIRV::StorageClass::StorageClass>(
249           PtrType->getOperand(1).getImm());
250   MachineIRBuilder MIB(I);
251   LLVMContext &Context = MF->getFunction().getContext();
252   SPIRVType *NewPtrType =
253       GR.getOrCreateSPIRVPointerType(IntegerType::getInt8Ty(Context), MIB, SC);
254   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
255 }
256 
validatePtrUnwrapStructField(const SPIRVSubtarget & STI,MachineRegisterInfo * MRI,SPIRVGlobalRegistry & GR,MachineInstr & I,unsigned OpIdx)257 static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
258                                          MachineRegisterInfo *MRI,
259                                          SPIRVGlobalRegistry &GR,
260                                          MachineInstr &I, unsigned OpIdx) {
261   MachineFunction *MF = I.getParent()->getParent();
262   Register OpReg = I.getOperand(OpIdx).getReg();
263   Register OpTypeReg = getTypeReg(MRI, OpReg);
264   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
265   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
266     return;
267   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
268   if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
269       ElemType->getNumOperands() != 2)
270     return;
271   // It's a structure-wrapper around another type with a single member field.
272   SPIRVType *MemberType =
273       GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
274   if (!MemberType)
275     return;
276   unsigned MemberTypeOp = MemberType->getOpcode();
277   if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
278       MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
279     return;
280   // It's a structure-wrapper around a valid type. Insert a bitcast before the
281   // instruction to keep SPIR-V code valid.
282   SPIRV::StorageClass::StorageClass SC =
283       static_cast<SPIRV::StorageClass::StorageClass>(
284           OpType->getOperand(1).getImm());
285   MachineIRBuilder MIB(I);
286   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
287   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
288 }
289 
290 // Insert a bitcast before the function call instruction to keep SPIR-V code
291 // valid when there is a type mismatch between actual and expected types of an
292 // argument:
293 // %formal = OpFunctionParameter %formal_type
294 // ...
295 // %res = OpFunctionCall %ty %fun %actual ...
296 // implies that %actual is of %formal_type, and in case of opaque pointers.
297 // We may need to insert a bitcast to ensure this.
validateFunCallMachineDef(const SPIRVSubtarget & STI,MachineRegisterInfo * DefMRI,MachineRegisterInfo * CallMRI,SPIRVGlobalRegistry & GR,MachineInstr & FunCall,MachineInstr * FunDef)298 void validateFunCallMachineDef(const SPIRVSubtarget &STI,
299                                MachineRegisterInfo *DefMRI,
300                                MachineRegisterInfo *CallMRI,
301                                SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
302                                MachineInstr *FunDef) {
303   if (FunDef->getOpcode() != SPIRV::OpFunction)
304     return;
305   unsigned OpIdx = 3;
306   for (FunDef = FunDef->getNextNode();
307        FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
308        OpIdx < FunCall.getNumOperands();
309        FunDef = FunDef->getNextNode(), OpIdx++) {
310     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
311     SPIRVType *DefElemType =
312         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
313             ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
314                                      DefPtrType->getParent()->getParent())
315             : nullptr;
316     if (DefElemType) {
317       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
318       // validatePtrTypes() works in the context if the call site
319       // When we process historical records about forward calls
320       // we need to switch context to the (forward) call site and
321       // then restore it back to the current machine function.
322       MachineFunction *CurMF =
323           GR.setCurrentFunc(*FunCall.getParent()->getParent());
324       validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
325                        DefElemTy);
326       GR.setCurrentFunc(*CurMF);
327     }
328   }
329 }
330 
331 // Ensure there is no mismatch between actual and expected arg types: calls
332 // with a processed definition. Return Function pointer if it's a forward
333 // call (ahead of definition), and nullptr otherwise.
validateFunCall(const SPIRVSubtarget & STI,MachineRegisterInfo * CallMRI,SPIRVGlobalRegistry & GR,MachineInstr & FunCall)334 const Function *validateFunCall(const SPIRVSubtarget &STI,
335                                 MachineRegisterInfo *CallMRI,
336                                 SPIRVGlobalRegistry &GR,
337                                 MachineInstr &FunCall) {
338   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
339   const Function *F = dyn_cast<Function>(GV);
340   MachineInstr *FunDef =
341       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
342   if (!FunDef)
343     return F;
344   MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
345   validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
346   return nullptr;
347 }
348 
349 // Ensure there is no mismatch between actual and expected arg types: calls
350 // ahead of a processed definition.
validateForwardCalls(const SPIRVSubtarget & STI,MachineRegisterInfo * DefMRI,SPIRVGlobalRegistry & GR,MachineInstr & FunDef)351 void validateForwardCalls(const SPIRVSubtarget &STI,
352                           MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
353                           MachineInstr &FunDef) {
354   const Function *F = GR.getFunctionByDefinition(&FunDef);
355   if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
356     for (MachineInstr *FunCall : *FwdCalls) {
357       MachineRegisterInfo *CallMRI =
358           &FunCall->getParent()->getParent()->getRegInfo();
359       validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
360     }
361 }
362 
363 // Validation of an access chain.
validateAccessChain(const SPIRVSubtarget & STI,MachineRegisterInfo * MRI,SPIRVGlobalRegistry & GR,MachineInstr & I)364 void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
365                          SPIRVGlobalRegistry &GR, MachineInstr &I) {
366   SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
367   if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
368     SPIRVType *BaseElemType =
369         GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
370     validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
371   }
372 }
373 
374 // TODO: the logic of inserting additional bitcast's is to be moved
375 // to pre-IRTranslation passes eventually
finalizeLowering(MachineFunction & MF) const376 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
377   // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
378   // We'd like to avoid the needless second processing pass.
379   if (ProcessedMF.find(&MF) != ProcessedMF.end())
380     return;
381 
382   MachineRegisterInfo *MRI = &MF.getRegInfo();
383   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
384   GR.setCurrentFunc(MF);
385   for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
386     MachineBasicBlock *MBB = &*I;
387     SmallPtrSet<MachineInstr *, 8> ToMove;
388     for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
389          MBBI != MBBE;) {
390       MachineInstr &MI = *MBBI++;
391       switch (MI.getOpcode()) {
392       case SPIRV::OpAtomicLoad:
393       case SPIRV::OpAtomicExchange:
394       case SPIRV::OpAtomicCompareExchange:
395       case SPIRV::OpAtomicCompareExchangeWeak:
396       case SPIRV::OpAtomicIIncrement:
397       case SPIRV::OpAtomicIDecrement:
398       case SPIRV::OpAtomicIAdd:
399       case SPIRV::OpAtomicISub:
400       case SPIRV::OpAtomicSMin:
401       case SPIRV::OpAtomicUMin:
402       case SPIRV::OpAtomicSMax:
403       case SPIRV::OpAtomicUMax:
404       case SPIRV::OpAtomicAnd:
405       case SPIRV::OpAtomicOr:
406       case SPIRV::OpAtomicXor:
407         // for the above listed instructions
408         // OpAtomicXXX <ResType>, ptr %Op, ...
409         // implies that %Op is a pointer to <ResType>
410       case SPIRV::OpLoad:
411         // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
412         if (enforcePtrTypeCompatibility(MI, 2, 0))
413           break;
414 
415         validatePtrTypes(STI, MRI, GR, MI, 2,
416                          GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
417         break;
418       case SPIRV::OpAtomicStore:
419         // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
420         // implies that %Op points to the <Obj>'s type
421         validatePtrTypes(STI, MRI, GR, MI, 0,
422                          GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
423         break;
424       case SPIRV::OpStore:
425         // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
426         validatePtrTypes(STI, MRI, GR, MI, 0,
427                          GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
428         break;
429       case SPIRV::OpPtrCastToGeneric:
430       case SPIRV::OpGenericCastToPtr:
431       case SPIRV::OpGenericCastToPtrExplicit:
432         validateAccessChain(STI, MRI, GR, MI);
433         break;
434       case SPIRV::OpPtrAccessChain:
435       case SPIRV::OpInBoundsPtrAccessChain:
436         if (MI.getNumOperands() == 4)
437           validateAccessChain(STI, MRI, GR, MI);
438         break;
439 
440       case SPIRV::OpFunctionCall:
441         // ensure there is no mismatch between actual and expected arg types:
442         // calls with a processed definition
443         if (MI.getNumOperands() > 3)
444           if (const Function *F = validateFunCall(STI, MRI, GR, MI))
445             GR.addForwardCall(F, &MI);
446         break;
447       case SPIRV::OpFunction:
448         // ensure there is no mismatch between actual and expected arg types:
449         // calls ahead of a processed definition
450         validateForwardCalls(STI, MRI, GR, MI);
451         break;
452 
453       // ensure that LLVM IR add/sub instructions result in logical SPIR-V
454       // instructions when applied to bool type
455       case SPIRV::OpIAddS:
456       case SPIRV::OpIAddV:
457       case SPIRV::OpISubS:
458       case SPIRV::OpISubV:
459         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
460                                       SPIRV::OpTypeBool))
461           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
462         break;
463 
464       // ensure that LLVM IR bitwise instructions result in logical SPIR-V
465       // instructions when applied to bool type
466       case SPIRV::OpBitwiseOrS:
467       case SPIRV::OpBitwiseOrV:
468         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
469                                       SPIRV::OpTypeBool))
470           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
471         break;
472       case SPIRV::OpBitwiseAndS:
473       case SPIRV::OpBitwiseAndV:
474         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
475                                       SPIRV::OpTypeBool))
476           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
477         break;
478       case SPIRV::OpBitwiseXorS:
479       case SPIRV::OpBitwiseXorV:
480         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
481                                       SPIRV::OpTypeBool))
482           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
483         break;
484       case SPIRV::OpLifetimeStart:
485       case SPIRV::OpLifetimeStop:
486         if (MI.getOperand(1).getImm() > 0)
487           validateLifetimeStart(STI, MRI, GR, MI);
488         break;
489       case SPIRV::OpGroupAsyncCopy:
490         validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
491         validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
492         break;
493       case SPIRV::OpGroupWaitEvents:
494         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
495         validateGroupWaitEventsPtr(STI, MRI, GR, MI);
496         break;
497       case SPIRV::OpConstantI: {
498         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
499         if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
500             MI.getOperand(2).getImm() == 0) {
501           // Validate the null constant of a target extension type
502           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
503           for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
504             MI.removeOperand(i);
505         }
506       } break;
507       case SPIRV::OpPhi: {
508         // Phi refers to a type definition that goes after the Phi
509         // instruction, so that the virtual register definition of the type
510         // doesn't dominate all uses. Let's place the type definition
511         // instruction at the end of the predecessor.
512         MachineBasicBlock *Curr = MI.getParent();
513         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
514         if (Type->getParent() == Curr && !Curr->pred_empty())
515           ToMove.insert(const_cast<MachineInstr *>(Type));
516       } break;
517       case SPIRV::OpExtInst: {
518         // prefetch
519         if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
520             MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
521           continue;
522         switch (MI.getOperand(3).getImm()) {
523         case SPIRV::OpenCLExtInst::frexp:
524         case SPIRV::OpenCLExtInst::lgamma_r:
525         case SPIRV::OpenCLExtInst::remquo: {
526           // The last operand must be of a pointer to i32 or vector of i32
527           // values.
528           MachineIRBuilder MIB(MI);
529           SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
530           SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
531           assert(RetType && "Expected return type");
532           validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
533                            RetType->getOpcode() != SPIRV::OpTypeVector
534                                ? Int32Type
535                                : GR.getOrCreateSPIRVVectorType(
536                                      Int32Type, RetType->getOperand(2).getImm(),
537                                      MIB, false));
538         } break;
539         case SPIRV::OpenCLExtInst::fract:
540         case SPIRV::OpenCLExtInst::modf:
541         case SPIRV::OpenCLExtInst::sincos:
542           // The last operand must be of a pointer to the base type represented
543           // by the previous operand.
544           assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
545                  "Expected v-reg");
546           validatePtrTypes(
547               STI, MRI, GR, MI, MI.getNumOperands() - 1,
548               GR.getSPIRVTypeForVReg(
549                   MI.getOperand(MI.getNumOperands() - 2).getReg()));
550           break;
551         case SPIRV::OpenCLExtInst::prefetch:
552           // Expected `ptr` type is a pointer to float, integer or vector, but
553           // the pontee value can be wrapped into a struct.
554           assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
555                  "Expected v-reg");
556           validatePtrUnwrapStructField(STI, MRI, GR, MI,
557                                        MI.getNumOperands() - 2);
558           break;
559         }
560       } break;
561       }
562     }
563     for (MachineInstr *MI : ToMove) {
564       MachineBasicBlock *Curr = MI->getParent();
565       MachineBasicBlock *Pred = *Curr->pred_begin();
566       Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
567     }
568   }
569   ProcessedMF.insert(&MF);
570   TargetLowering::finalizeLowering(MF);
571 }
572 
573 // Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
574 // PtrOpIdx matches the type for operand OpIdx. Returns true if they already
575 // match or if the instruction was modified to make them match.
enforcePtrTypeCompatibility(MachineInstr & I,unsigned int PtrOpIdx,unsigned int OpIdx) const576 bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
577     MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
578   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
579   SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
580   SPIRVType *PointeeType = GR.getPointeeType(PtrType);
581   SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
582 
583   if (PointeeType == OpType)
584     return true;
585 
586   if (typesLogicallyMatch(PointeeType, OpType, GR)) {
587     // Apply OpCopyLogical to OpIdx.
588     if (I.getOperand(OpIdx).isDef() &&
589         insertLogicalCopyOnResult(I, PointeeType)) {
590       return true;
591     }
592 
593     llvm_unreachable("Unable to add OpCopyLogical yet.");
594     return false;
595   }
596 
597   return false;
598 }
599 
insertLogicalCopyOnResult(MachineInstr & I,SPIRVType * NewResultType) const600 bool SPIRVTargetLowering::insertLogicalCopyOnResult(
601     MachineInstr &I, SPIRVType *NewResultType) const {
602   MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
603   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
604 
605   Register NewResultReg =
606       createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
607   Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
608 
609   assert(std::distance(I.defs().begin(), I.defs().end()) == 1 &&
610          "Expected only one def");
611   MachineOperand &OldResult = *I.defs().begin();
612   Register OldResultReg = OldResult.getReg();
613   MachineOperand &OldType = *I.uses().begin();
614   Register OldTypeReg = OldType.getReg();
615 
616   OldResult.setReg(NewResultReg);
617   OldType.setReg(NewTypeReg);
618 
619   MachineIRBuilder MIB(*I.getNextNode());
620   return MIB.buildInstr(SPIRV::OpCopyLogical)
621       .addDef(OldResultReg)
622       .addUse(OldTypeReg)
623       .addUse(NewResultReg)
624       .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
625                         *STI.getRegBankInfo());
626 }
627