xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 #include "SPIRVInstrInfo.h"
16 #include "SPIRVRegisterBankInfo.h"
17 #include "SPIRVRegisterInfo.h"
18 #include "SPIRVSubtarget.h"
19 #include "SPIRVTargetMachine.h"
20 #include "llvm/CodeGen/MachineInstrBuilder.h"
21 #include "llvm/CodeGen/MachineRegisterInfo.h"
22 #include "llvm/IR/IntrinsicsSPIRV.h"
23 
24 #define DEBUG_TYPE "spirv-lower"
25 
26 using namespace llvm;
27 
28 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
29     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
30   // This code avoids CallLowering fail inside getVectorTypeBreakdown
31   // on v3i1 arguments. Maybe we need to return 1 for all types.
32   // TODO: remove it once this case is supported by the default implementation.
33   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
34       (VT.getVectorElementType() == MVT::i1 ||
35        VT.getVectorElementType() == MVT::i8))
36     return 1;
37   if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
38     return 1;
39   return getNumRegisters(Context, VT);
40 }
41 
42 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
43                                                        CallingConv::ID CC,
44                                                        EVT VT) const {
45   // This code avoids CallLowering fail inside getVectorTypeBreakdown
46   // on v3i1 arguments. Maybe we need to return i32 for all types.
47   // TODO: remove it once this case is supported by the default implementation.
48   if (VT.isVector() && VT.getVectorNumElements() == 3) {
49     if (VT.getVectorElementType() == MVT::i1)
50       return MVT::v4i1;
51     else if (VT.getVectorElementType() == MVT::i8)
52       return MVT::v4i8;
53   }
54   return getRegisterType(Context, VT);
55 }
56 
57 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
58                                              const CallInst &I,
59                                              MachineFunction &MF,
60                                              unsigned Intrinsic) const {
61   unsigned AlignIdx = 3;
62   switch (Intrinsic) {
63   case Intrinsic::spv_load:
64     AlignIdx = 2;
65     [[fallthrough]];
66   case Intrinsic::spv_store: {
67     if (I.getNumOperands() >= AlignIdx + 1) {
68       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
69       Info.align = Align(AlignOp->getZExtValue());
70     }
71     Info.flags = static_cast<MachineMemOperand::Flags>(
72         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
73     Info.memVT = MVT::i64;
74     // TODO: take into account opaque pointers (don't use getElementType).
75     // MVT::getVT(PtrTy->getElementType());
76     return true;
77     break;
78   }
79   default:
80     break;
81   }
82   return false;
83 }
84 
85 std::pair<unsigned, const TargetRegisterClass *>
86 SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
87                                                   StringRef Constraint,
88                                                   MVT VT) const {
89   const TargetRegisterClass *RC = nullptr;
90   if (Constraint.starts_with("{"))
91     return std::make_pair(0u, RC);
92 
93   if (VT.isFloatingPoint())
94     RC = VT.isVector() ? &SPIRV::vfIDRegClass
95                        : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
96                                                         : &SPIRV::fIDRegClass);
97   else if (VT.isInteger())
98     RC = VT.isVector() ? &SPIRV::vIDRegClass
99                        : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass
100                                                         : &SPIRV::IDRegClass);
101   else
102     RC = &SPIRV::IDRegClass;
103 
104   return std::make_pair(0u, RC);
105 }
106 
107 inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
108   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
109   return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
110              ? TypeInst->getOperand(1).getReg()
111              : OpReg;
112 }
113 
114 static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
115                             SPIRVGlobalRegistry &GR, MachineInstr &I,
116                             Register OpReg, unsigned OpIdx,
117                             SPIRVType *NewPtrType) {
118   Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
119   MachineIRBuilder MIB(I);
120   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
121                  .addDef(NewReg)
122                  .addUse(GR.getSPIRVTypeID(NewPtrType))
123                  .addUse(OpReg)
124                  .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
125                                    *STI.getRegBankInfo());
126   if (!Res)
127     report_fatal_error("insert validation bitcast: cannot constrain all uses");
128   MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
129   GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
130   I.getOperand(OpIdx).setReg(NewReg);
131 }
132 
133 static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
134                                    SPIRVType *OpType, bool ReuseType,
135                                    bool EmitIR, SPIRVType *ResType,
136                                    const Type *ResTy) {
137   SPIRV::StorageClass::StorageClass SC =
138       static_cast<SPIRV::StorageClass::StorageClass>(
139           OpType->getOperand(1).getImm());
140   MachineIRBuilder MIB(I);
141   SPIRVType *NewBaseType =
142       ReuseType ? ResType
143                 : GR.getOrCreateSPIRVType(
144                       ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
145   return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
146 }
147 
148 // Insert a bitcast before the instruction to keep SPIR-V code valid
149 // when there is a type mismatch between results and operand types.
150 static void validatePtrTypes(const SPIRVSubtarget &STI,
151                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
152                              MachineInstr &I, unsigned OpIdx,
153                              SPIRVType *ResType, const Type *ResTy = nullptr) {
154   // Get operand type
155   MachineFunction *MF = I.getParent()->getParent();
156   Register OpReg = I.getOperand(OpIdx).getReg();
157   Register OpTypeReg = getTypeReg(MRI, OpReg);
158   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
159   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
160     return;
161   // Get operand's pointee type
162   Register ElemTypeReg = OpType->getOperand(2).getReg();
163   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
164   if (!ElemType)
165     return;
166   // Check if we need a bitcast to make a statement valid
167   bool IsSameMF = MF == ResType->getParent()->getParent();
168   bool IsEqualTypes = IsSameMF ? ElemType == ResType
169                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
170   if (IsEqualTypes)
171     return;
172   // There is a type mismatch between results and operand types
173   // and we insert a bitcast before the instruction to keep SPIR-V code valid
174   SPIRVType *NewPtrType =
175       createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
176   if (!GR.isBitcastCompatible(NewPtrType, OpType))
177     report_fatal_error(
178         "insert validation bitcast: incompatible result and operand types");
179   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
180 }
181 
182 // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
183 // that doesn't point to OpTypeEvent.
184 static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
185                                        MachineRegisterInfo *MRI,
186                                        SPIRVGlobalRegistry &GR,
187                                        MachineInstr &I) {
188   constexpr unsigned OpIdx = 2;
189   MachineFunction *MF = I.getParent()->getParent();
190   Register OpReg = I.getOperand(OpIdx).getReg();
191   Register OpTypeReg = getTypeReg(MRI, OpReg);
192   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
193   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
194     return;
195   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
196   if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
197     return;
198   // Insert a bitcast before the instruction to keep SPIR-V code valid.
199   LLVMContext &Context = MF->getFunction().getContext();
200   SPIRVType *NewPtrType =
201       createNewPtrType(GR, I, OpType, false, true, nullptr,
202                        TargetExtType::get(Context, "spirv.Event"));
203   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
204 }
205 
206 static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
207                                       MachineRegisterInfo *MRI,
208                                       SPIRVGlobalRegistry &GR, MachineInstr &I,
209                                       unsigned OpIdx) {
210   MachineFunction *MF = I.getParent()->getParent();
211   Register OpReg = I.getOperand(OpIdx).getReg();
212   Register OpTypeReg = getTypeReg(MRI, OpReg);
213   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
214   if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
215     return;
216   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
217   if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
218       ElemType->getNumOperands() != 2)
219     return;
220   // It's a structure-wrapper around another type with a single member field.
221   SPIRVType *MemberType =
222       GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
223   if (!MemberType)
224     return;
225   unsigned MemberTypeOp = MemberType->getOpcode();
226   if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
227       MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
228     return;
229   // It's a structure-wrapper around a valid type. Insert a bitcast before the
230   // instruction to keep SPIR-V code valid.
231   SPIRV::StorageClass::StorageClass SC =
232       static_cast<SPIRV::StorageClass::StorageClass>(
233           OpType->getOperand(1).getImm());
234   MachineIRBuilder MIB(I);
235   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
236   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
237 }
238 
239 // Insert a bitcast before the function call instruction to keep SPIR-V code
240 // valid when there is a type mismatch between actual and expected types of an
241 // argument:
242 // %formal = OpFunctionParameter %formal_type
243 // ...
244 // %res = OpFunctionCall %ty %fun %actual ...
245 // implies that %actual is of %formal_type, and in case of opaque pointers.
246 // We may need to insert a bitcast to ensure this.
247 void validateFunCallMachineDef(const SPIRVSubtarget &STI,
248                                MachineRegisterInfo *DefMRI,
249                                MachineRegisterInfo *CallMRI,
250                                SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
251                                MachineInstr *FunDef) {
252   if (FunDef->getOpcode() != SPIRV::OpFunction)
253     return;
254   unsigned OpIdx = 3;
255   for (FunDef = FunDef->getNextNode();
256        FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
257        OpIdx < FunCall.getNumOperands();
258        FunDef = FunDef->getNextNode(), OpIdx++) {
259     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
260     SPIRVType *DefElemType =
261         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
262             ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
263                                      DefPtrType->getParent()->getParent())
264             : nullptr;
265     if (DefElemType) {
266       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
267       // validatePtrTypes() works in the context if the call site
268       // When we process historical records about forward calls
269       // we need to switch context to the (forward) call site and
270       // then restore it back to the current machine function.
271       MachineFunction *CurMF =
272           GR.setCurrentFunc(*FunCall.getParent()->getParent());
273       validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
274                        DefElemTy);
275       GR.setCurrentFunc(*CurMF);
276     }
277   }
278 }
279 
280 // Ensure there is no mismatch between actual and expected arg types: calls
281 // with a processed definition. Return Function pointer if it's a forward
282 // call (ahead of definition), and nullptr otherwise.
283 const Function *validateFunCall(const SPIRVSubtarget &STI,
284                                 MachineRegisterInfo *CallMRI,
285                                 SPIRVGlobalRegistry &GR,
286                                 MachineInstr &FunCall) {
287   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
288   const Function *F = dyn_cast<Function>(GV);
289   MachineInstr *FunDef =
290       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
291   if (!FunDef)
292     return F;
293   MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
294   validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
295   return nullptr;
296 }
297 
298 // Ensure there is no mismatch between actual and expected arg types: calls
299 // ahead of a processed definition.
300 void validateForwardCalls(const SPIRVSubtarget &STI,
301                           MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
302                           MachineInstr &FunDef) {
303   const Function *F = GR.getFunctionByDefinition(&FunDef);
304   if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
305     for (MachineInstr *FunCall : *FwdCalls) {
306       MachineRegisterInfo *CallMRI =
307           &FunCall->getParent()->getParent()->getRegInfo();
308       validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
309     }
310 }
311 
312 // Validation of an access chain.
313 void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
314                          SPIRVGlobalRegistry &GR, MachineInstr &I) {
315   SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
316   if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
317     SPIRVType *BaseElemType =
318         GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
319     validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
320   }
321 }
322 
323 // TODO: the logic of inserting additional bitcast's is to be moved
324 // to pre-IRTranslation passes eventually
325 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
326   // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
327   // We'd like to avoid the needless second processing pass.
328   if (ProcessedMF.find(&MF) != ProcessedMF.end())
329     return;
330 
331   MachineRegisterInfo *MRI = &MF.getRegInfo();
332   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
333   GR.setCurrentFunc(MF);
334   for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
335     MachineBasicBlock *MBB = &*I;
336     for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
337          MBBI != MBBE;) {
338       MachineInstr &MI = *MBBI++;
339       switch (MI.getOpcode()) {
340       case SPIRV::OpAtomicLoad:
341       case SPIRV::OpAtomicExchange:
342       case SPIRV::OpAtomicCompareExchange:
343       case SPIRV::OpAtomicCompareExchangeWeak:
344       case SPIRV::OpAtomicIIncrement:
345       case SPIRV::OpAtomicIDecrement:
346       case SPIRV::OpAtomicIAdd:
347       case SPIRV::OpAtomicISub:
348       case SPIRV::OpAtomicSMin:
349       case SPIRV::OpAtomicUMin:
350       case SPIRV::OpAtomicSMax:
351       case SPIRV::OpAtomicUMax:
352       case SPIRV::OpAtomicAnd:
353       case SPIRV::OpAtomicOr:
354       case SPIRV::OpAtomicXor:
355         // for the above listed instructions
356         // OpAtomicXXX <ResType>, ptr %Op, ...
357         // implies that %Op is a pointer to <ResType>
358       case SPIRV::OpLoad:
359         // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
360         validatePtrTypes(STI, MRI, GR, MI, 2,
361                          GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
362         break;
363       case SPIRV::OpAtomicStore:
364         // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
365         // implies that %Op points to the <Obj>'s type
366         validatePtrTypes(STI, MRI, GR, MI, 0,
367                          GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
368         break;
369       case SPIRV::OpStore:
370         // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
371         validatePtrTypes(STI, MRI, GR, MI, 0,
372                          GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
373         break;
374       case SPIRV::OpPtrCastToGeneric:
375       case SPIRV::OpGenericCastToPtr:
376         validateAccessChain(STI, MRI, GR, MI);
377         break;
378       case SPIRV::OpInBoundsPtrAccessChain:
379         if (MI.getNumOperands() == 4)
380           validateAccessChain(STI, MRI, GR, MI);
381         break;
382 
383       case SPIRV::OpFunctionCall:
384         // ensure there is no mismatch between actual and expected arg types:
385         // calls with a processed definition
386         if (MI.getNumOperands() > 3)
387           if (const Function *F = validateFunCall(STI, MRI, GR, MI))
388             GR.addForwardCall(F, &MI);
389         break;
390       case SPIRV::OpFunction:
391         // ensure there is no mismatch between actual and expected arg types:
392         // calls ahead of a processed definition
393         validateForwardCalls(STI, MRI, GR, MI);
394         break;
395 
396       // ensure that LLVM IR bitwise instructions result in logical SPIR-V
397       // instructions when applied to bool type
398       case SPIRV::OpBitwiseOrS:
399       case SPIRV::OpBitwiseOrV:
400         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
401                                       SPIRV::OpTypeBool))
402           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
403         break;
404       case SPIRV::OpBitwiseAndS:
405       case SPIRV::OpBitwiseAndV:
406         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
407                                       SPIRV::OpTypeBool))
408           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
409         break;
410       case SPIRV::OpBitwiseXorS:
411       case SPIRV::OpBitwiseXorV:
412         if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
413                                       SPIRV::OpTypeBool))
414           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
415         break;
416       case SPIRV::OpGroupAsyncCopy:
417         validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
418         validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
419         break;
420       case SPIRV::OpGroupWaitEvents:
421         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
422         validateGroupWaitEventsPtr(STI, MRI, GR, MI);
423         break;
424       case SPIRV::OpConstantI: {
425         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
426         if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
427             MI.getOperand(2).getImm() == 0) {
428           // Validate the null constant of a target extension type
429           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
430           for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
431             MI.removeOperand(i);
432         }
433       } break;
434       }
435     }
436   }
437   ProcessedMF.insert(&MF);
438   TargetLowering::finalizeLowering(MF);
439 }
440