xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes the intrinsics that held these types during IR translation.
// It also processes constants and registers them in the global registry (GR)
// to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
SPIRVPreLegalizer()34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void
addConstantsToTrack(MachineFunction & MF,SPIRVGlobalRegistry * GR,const SPIRVSubtarget & STI,DenseMap<MachineInstr *,Type * > & TargetExtConstTypes,SmallSet<Register,4> & TrackedConstRegs)42 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
43                     const SPIRVSubtarget &STI,
44                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
45                     SmallSet<Register, 4> &TrackedConstRegs) {
46   MachineRegisterInfo &MRI = MF.getRegInfo();
47   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
48   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
49   for (MachineBasicBlock &MBB : MF) {
50     for (MachineInstr &MI : MBB) {
51       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
52         continue;
53       ToErase.push_back(&MI);
54       Register SrcReg = MI.getOperand(2).getReg();
55       auto *Const =
56           cast<Constant>(cast<ConstantAsMetadata>(
57                              MI.getOperand(3).getMetadata()->getOperand(0))
58                              ->getValue());
59       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
60         Register Reg = GR->find(GV, &MF);
61         if (!Reg.isValid())
62           GR->add(GV, &MF, SrcReg);
63         else
64           RegsAlreadyAddedToDT[&MI] = Reg;
65       } else {
66         Register Reg = GR->find(Const, &MF);
67         if (!Reg.isValid()) {
68           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
69             auto *BuildVec = MRI.getVRegDef(SrcReg);
70             assert(BuildVec &&
71                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
72             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
73               // Ensure that OpConstantComposite reuses a constant when it's
74               // already created and available in the same machine function.
75               Constant *ElemConst = ConstVec->getElementAsConstant(i);
76               Register ElemReg = GR->find(ElemConst, &MF);
77               if (!ElemReg.isValid())
78                 GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
79               else
80                 BuildVec->getOperand(1 + i).setReg(ElemReg);
81             }
82           }
83           GR->add(Const, &MF, SrcReg);
84           TrackedConstRegs.insert(SrcReg);
85           if (Const->getType()->isTargetExtTy()) {
86             // remember association so that we can restore it when assign types
87             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
88             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
89                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
90               TargetExtConstTypes[SrcMI] = Const->getType();
91             if (Const->isNullValue()) {
92               MachineIRBuilder MIB(MF);
93               SPIRVType *ExtType =
94                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
95               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
96               SrcMI->addOperand(MachineOperand::CreateReg(
97                   GR->getSPIRVTypeID(ExtType), false));
98             }
99           }
100         } else {
101           RegsAlreadyAddedToDT[&MI] = Reg;
102           // This MI is unused and will be removed. If the MI uses
103           // const_composite, it will be unused and should be removed too.
104           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
105           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
106           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
107             ToEraseComposites.push_back(SrcMI);
108         }
109       }
110     }
111   }
112   for (MachineInstr *MI : ToErase) {
113     Register Reg = MI->getOperand(2).getReg();
114     if (RegsAlreadyAddedToDT.contains(MI))
115       Reg = RegsAlreadyAddedToDT[MI];
116     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
117     if (!MRI.getRegClassOrNull(Reg) && RC)
118       MRI.setRegClass(Reg, RC);
119     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
120     MI->eraseFromParent();
121   }
122   for (MachineInstr *MI : ToEraseComposites)
123     MI->eraseFromParent();
124 }
125 
126 static void
foldConstantsIntoIntrinsics(MachineFunction & MF,const SmallSet<Register,4> & TrackedConstRegs)127 foldConstantsIntoIntrinsics(MachineFunction &MF,
128                             const SmallSet<Register, 4> &TrackedConstRegs) {
129   SmallVector<MachineInstr *, 10> ToErase;
130   MachineRegisterInfo &MRI = MF.getRegInfo();
131   const unsigned AssignNameOperandShift = 2;
132   for (MachineBasicBlock &MBB : MF) {
133     for (MachineInstr &MI : MBB) {
134       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
135         continue;
136       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
137       while (MI.getOperand(NumOp).isReg()) {
138         MachineOperand &MOp = MI.getOperand(NumOp);
139         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
140         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
141         MI.removeOperand(NumOp);
142         MI.addOperand(MachineOperand::CreateImm(
143             ConstMI->getOperand(1).getCImm()->getZExtValue()));
144         Register DefReg = ConstMI->getOperand(0).getReg();
145         if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
146           ToErase.push_back(ConstMI);
147       }
148     }
149   }
150   for (MachineInstr *MI : ToErase)
151     MI->eraseFromParent();
152 }
153 
findAssignTypeInstr(Register Reg,MachineRegisterInfo * MRI)154 static MachineInstr *findAssignTypeInstr(Register Reg,
155                                          MachineRegisterInfo *MRI) {
156   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
157                                                IE = MRI->use_instr_end();
158        I != IE; ++I) {
159     MachineInstr *UseMI = &*I;
160     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
161          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
162         UseMI->getOperand(1).getReg() == Reg)
163       return UseMI;
164   }
165   return nullptr;
166 }
167 
insertBitcasts(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)168 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
169                            MachineIRBuilder MIB) {
170   // Get access to information about available extensions
171   const SPIRVSubtarget *ST =
172       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
173   SmallVector<MachineInstr *, 10> ToErase;
174   for (MachineBasicBlock &MBB : MF) {
175     for (MachineInstr &MI : MBB) {
176       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
177           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
178         continue;
179       assert(MI.getOperand(2).isReg());
180       MIB.setInsertPt(*MI.getParent(), MI);
181       ToErase.push_back(&MI);
182       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
183         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
184         continue;
185       }
186       Register Def = MI.getOperand(0).getReg();
187       Register Source = MI.getOperand(2).getReg();
188       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
189       SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
190       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
191           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
192           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
193 
194       // If the ptrcast would be redundant, replace all uses with the source
195       // register.
196       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
197         // Erase Def's assign type instruction if we are going to replace Def.
198         if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MIB.getMRI()))
199           ToErase.push_back(AssignMI);
200         MIB.getMRI()->replaceRegWith(Def, Source);
201       } else {
202         GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
203         MIB.buildBitcast(Def, Source);
204       }
205     }
206   }
207   for (MachineInstr *MI : ToErase)
208     MI->eraseFromParent();
209 }
210 
211 // Translating GV, IRTranslator sometimes generates following IR:
212 //   %1 = G_GLOBAL_VALUE
213 //   %2 = COPY %1
214 //   %3 = G_ADDRSPACE_CAST %2
215 //
216 // or
217 //
218 //  %1 = G_ZEXT %2
219 //  G_MEMCPY ... %2 ...
220 //
221 // New registers have no SPIRVType and no register class info.
222 //
223 // Set SPIRVType for GV, propagate it from GV to other instructions,
224 // also set register classes.
propagateSPIRVType(MachineInstr * MI,SPIRVGlobalRegistry * GR,MachineRegisterInfo & MRI,MachineIRBuilder & MIB)225 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
226                                      MachineRegisterInfo &MRI,
227                                      MachineIRBuilder &MIB) {
228   SPIRVType *SpirvTy = nullptr;
229   assert(MI && "Machine instr is expected");
230   if (MI->getOperand(0).isReg()) {
231     Register Reg = MI->getOperand(0).getReg();
232     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
233     if (!SpirvTy) {
234       switch (MI->getOpcode()) {
235       case TargetOpcode::G_CONSTANT: {
236         MIB.setInsertPt(*MI->getParent(), MI);
237         Type *Ty = MI->getOperand(1).getCImm()->getType();
238         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
239         break;
240       }
241       case TargetOpcode::G_GLOBAL_VALUE: {
242         MIB.setInsertPt(*MI->getParent(), MI);
243         const GlobalValue *Global = MI->getOperand(1).getGlobal();
244         Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
245         auto *Ty = TypedPointerType::get(ElementTy,
246                                          Global->getType()->getAddressSpace());
247         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
248         break;
249       }
250       case TargetOpcode::G_ANYEXT:
251       case TargetOpcode::G_SEXT:
252       case TargetOpcode::G_ZEXT: {
253         if (MI->getOperand(1).isReg()) {
254           if (MachineInstr *DefInstr =
255                   MRI.getVRegDef(MI->getOperand(1).getReg())) {
256             if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
257               unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
258               unsigned ExpectedBW =
259                   std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
260               unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
261               SpirvTy = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
262               if (NumElements > 1)
263                 SpirvTy =
264                     GR->getOrCreateSPIRVVectorType(SpirvTy, NumElements, MIB);
265             }
266           }
267         }
268         break;
269       }
270       case TargetOpcode::G_PTRTOINT:
271         SpirvTy = GR->getOrCreateSPIRVIntegerType(
272             MRI.getType(Reg).getScalarSizeInBits(), MIB);
273         break;
274       case TargetOpcode::G_TRUNC:
275       case TargetOpcode::G_ADDRSPACE_CAST:
276       case TargetOpcode::G_PTR_ADD:
277       case TargetOpcode::COPY: {
278         MachineOperand &Op = MI->getOperand(1);
279         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
280         if (Def)
281           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
282         break;
283       }
284       default:
285         break;
286       }
287       if (SpirvTy)
288         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
289       if (!MRI.getRegClassOrNull(Reg))
290         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
291     }
292   }
293   return SpirvTy;
294 }
295 
296 // To support current approach and limitations wrt. bit width here we widen a
297 // scalar register with a bit width greater than 1 to valid sizes and cap it to
298 // 64 width.
widenScalarLLTNextPow2(Register Reg,MachineRegisterInfo & MRI)299 static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
300   LLT RegType = MRI.getType(Reg);
301   if (!RegType.isScalar())
302     return;
303   unsigned Sz = RegType.getScalarSizeInBits();
304   if (Sz == 1)
305     return;
306   unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
307   if (NewSz != Sz)
308     MRI.setType(Reg, LLT::scalar(NewSz));
309 }
310 
311 static std::pair<Register, unsigned>
createNewIdReg(SPIRVType * SpvType,Register SrcReg,MachineRegisterInfo & MRI,const SPIRVGlobalRegistry & GR)312 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
313                const SPIRVGlobalRegistry &GR) {
314   if (!SpvType)
315     SpvType = GR.getSPIRVTypeForVReg(SrcReg);
316   assert(SpvType && "VReg is expected to have SPIRV type");
317   LLT SrcLLT = MRI.getType(SrcReg);
318   LLT NewT = LLT::scalar(32);
319   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
320   bool IsVectorFloat =
321       SpvType->getOpcode() == SPIRV::OpTypeVector &&
322       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
323           SPIRV::OpTypeFloat;
324   IsFloat |= IsVectorFloat;
325   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
326   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
327   if (SrcLLT.isPointer()) {
328     unsigned PtrSz = GR.getPointerSize();
329     NewT = LLT::pointer(0, PtrSz);
330     bool IsVec = SrcLLT.isVector();
331     if (IsVec)
332       NewT = LLT::fixed_vector(2, NewT);
333     if (PtrSz == 64) {
334       if (IsVec) {
335         GetIdOp = SPIRV::GET_vpID64;
336         DstClass = &SPIRV::vpID64RegClass;
337       } else {
338         GetIdOp = SPIRV::GET_pID64;
339         DstClass = &SPIRV::pID64RegClass;
340       }
341     } else {
342       if (IsVec) {
343         GetIdOp = SPIRV::GET_vpID32;
344         DstClass = &SPIRV::vpID32RegClass;
345       } else {
346         GetIdOp = SPIRV::GET_pID32;
347         DstClass = &SPIRV::pID32RegClass;
348       }
349     }
350   } else if (SrcLLT.isVector()) {
351     NewT = LLT::fixed_vector(2, NewT);
352     if (IsFloat) {
353       GetIdOp = SPIRV::GET_vfID;
354       DstClass = &SPIRV::vfIDRegClass;
355     } else {
356       GetIdOp = SPIRV::GET_vID;
357       DstClass = &SPIRV::vIDRegClass;
358     }
359   }
360   Register IdReg = MRI.createGenericVirtualRegister(NewT);
361   MRI.setRegClass(IdReg, DstClass);
362   return {IdReg, GetIdOp};
363 }
364 
365 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
366 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
367 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
368 // It's used also in SPIRVBuiltins.cpp.
369 // TODO: maybe move to SPIRVUtils.
370 namespace llvm {
insertAssignInstr(Register Reg,Type * Ty,SPIRVType * SpirvTy,SPIRVGlobalRegistry * GR,MachineIRBuilder & MIB,MachineRegisterInfo & MRI)371 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
372                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
373                            MachineRegisterInfo &MRI) {
374   MachineInstr *Def = MRI.getVRegDef(Reg);
375   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
376   MIB.setInsertPt(*Def->getParent(),
377                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
378                                       : Def->getParent()->end()));
379   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
380   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
381   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
382     MRI.setRegClass(NewReg, RC);
383   } else {
384     MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
385     MRI.setRegClass(Reg, &SPIRV::IDRegClass);
386   }
387   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
388   // This is to make it convenient for Legalizer to get the SPIRVType
389   // when processing the actual MI (i.e. not pseudo one).
390   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
391   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
392   // the flags after instruction selection.
393   const uint32_t Flags = Def->getFlags();
394   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
395       .addDef(Reg)
396       .addUse(NewReg)
397       .addUse(GR->getSPIRVTypeID(SpirvTy))
398       .setMIFlags(Flags);
399   Def->getOperand(0).setReg(NewReg);
400   return NewReg;
401 }
402 
processInstr(MachineInstr & MI,MachineIRBuilder & MIB,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)403 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
404                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
405   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
406   MachineInstr &AssignTypeInst =
407       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
408   auto NewReg =
409       createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
410   AssignTypeInst.getOperand(1).setReg(NewReg);
411   MI.getOperand(0).setReg(NewReg);
412   MIB.setInsertPt(*MI.getParent(),
413                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
414                                     : MI.getParent()->end()));
415   for (auto &Op : MI.operands()) {
416     if (!Op.isReg() || Op.isDef())
417       continue;
418     auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
419     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
420     Op.setReg(IdOpInfo.first);
421   }
422 }
423 } // namespace llvm
424 
425 static void
generateAssignInstrs(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB,DenseMap<MachineInstr *,Type * > & TargetExtConstTypes)426 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
427                      MachineIRBuilder MIB,
428                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
429   // Get access to information about available extensions
430   const SPIRVSubtarget *ST =
431       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
432 
433   MachineRegisterInfo &MRI = MF.getRegInfo();
434   SmallVector<MachineInstr *, 10> ToErase;
435   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
436 
437   for (MachineBasicBlock *MBB : post_order(&MF)) {
438     if (MBB->empty())
439       continue;
440 
441     bool ReachedBegin = false;
442     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
443          !ReachedBegin;) {
444       MachineInstr &MI = *MII;
445       unsigned MIOp = MI.getOpcode();
446 
447       // validate bit width of scalar registers
448       for (const auto &MOP : MI.operands())
449         if (MOP.isReg())
450           widenScalarLLTNextPow2(MOP.getReg(), MRI);
451 
452       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
453         Register Reg = MI.getOperand(1).getReg();
454         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
455         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
456         SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
457         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
458             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
459             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
460         MachineInstr *Def = MRI.getVRegDef(Reg);
461         assert(Def && "Expecting an instruction that defines the register");
462         // G_GLOBAL_VALUE already has type info.
463         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
464             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
465           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
466                             MF.getRegInfo());
467         ToErase.push_back(&MI);
468       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
469         Register Reg = MI.getOperand(1).getReg();
470         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
471         MachineInstr *Def = MRI.getVRegDef(Reg);
472         assert(Def && "Expecting an instruction that defines the register");
473         // G_GLOBAL_VALUE already has type info.
474         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
475             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
476           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
477         ToErase.push_back(&MI);
478       } else if (MIOp == TargetOpcode::G_CONSTANT ||
479                  MIOp == TargetOpcode::G_FCONSTANT ||
480                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
481         // %rc = G_CONSTANT ty Val
482         // ===>
483         // %cty = OpType* ty
484         // %rctmp = G_CONSTANT ty Val
485         // %rc = ASSIGN_TYPE %rctmp, %cty
486         Register Reg = MI.getOperand(0).getReg();
487         bool NeedAssignType = true;
488         if (MRI.hasOneUse(Reg)) {
489           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
490           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
491               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
492             continue;
493         }
494         Type *Ty = nullptr;
495         if (MIOp == TargetOpcode::G_CONSTANT) {
496           auto TargetExtIt = TargetExtConstTypes.find(&MI);
497           Ty = TargetExtIt == TargetExtConstTypes.end()
498                    ? MI.getOperand(1).getCImm()->getType()
499                    : TargetExtIt->second;
500           const ConstantInt *OpCI = MI.getOperand(1).getCImm();
501           Register PrimaryReg = GR->find(OpCI, &MF);
502           if (!PrimaryReg.isValid()) {
503             GR->add(OpCI, &MF, Reg);
504           } else if (PrimaryReg != Reg &&
505                      MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
506             auto *RCReg = MRI.getRegClassOrNull(Reg);
507             auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
508             if (!RCReg || RCPrimary == RCReg) {
509               RegsAlreadyAddedToDT[&MI] = PrimaryReg;
510               ToErase.push_back(&MI);
511               NeedAssignType = false;
512             }
513           }
514         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
515           Ty = MI.getOperand(1).getFPImm()->getType();
516         } else {
517           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
518           Type *ElemTy = nullptr;
519           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
520           assert(ElemMI);
521 
522           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
523             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
524           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
525             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
526           else
527             llvm_unreachable("Unexpected opcode");
528           unsigned NumElts =
529               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
530           Ty = VectorType::get(ElemTy, NumElts, false);
531         }
532         if (NeedAssignType)
533           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
534       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
535         propagateSPIRVType(&MI, GR, MRI, MIB);
536       }
537 
538       if (MII == Begin)
539         ReachedBegin = true;
540       else
541         --MII;
542     }
543   }
544   for (MachineInstr *MI : ToErase) {
545     auto It = RegsAlreadyAddedToDT.find(MI);
546     if (RegsAlreadyAddedToDT.contains(MI))
547       MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
548     MI->eraseFromParent();
549   }
550 
551   // Address the case when IRTranslator introduces instructions with new
552   // registers without SPIRVType associated.
553   for (MachineBasicBlock &MBB : MF) {
554     for (MachineInstr &MI : MBB) {
555       switch (MI.getOpcode()) {
556       case TargetOpcode::G_TRUNC:
557       case TargetOpcode::G_ANYEXT:
558       case TargetOpcode::G_SEXT:
559       case TargetOpcode::G_ZEXT:
560       case TargetOpcode::G_PTRTOINT:
561       case TargetOpcode::COPY:
562       case TargetOpcode::G_ADDRSPACE_CAST:
563         propagateSPIRVType(&MI, GR, MRI, MIB);
564         break;
565       }
566     }
567   }
568 }
569 
570 // Defined in SPIRVLegalizerInfo.cpp.
571 extern bool isTypeFoldingSupported(unsigned Opcode);
572 
processInstrsWithTypeFolding(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)573 static void processInstrsWithTypeFolding(MachineFunction &MF,
574                                          SPIRVGlobalRegistry *GR,
575                                          MachineIRBuilder MIB) {
576   MachineRegisterInfo &MRI = MF.getRegInfo();
577   for (MachineBasicBlock &MBB : MF) {
578     for (MachineInstr &MI : MBB) {
579       if (isTypeFoldingSupported(MI.getOpcode()))
580         processInstr(MI, MIB, MRI, GR);
581     }
582   }
583 
584   for (MachineBasicBlock &MBB : MF) {
585     for (MachineInstr &MI : MBB) {
586       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
587       // to perform tblgen'erated selection and we can't do that on Legalizer
588       // as it operates on gMIR only.
589       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
590         continue;
591       Register SrcReg = MI.getOperand(1).getReg();
592       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
593       if (!isTypeFoldingSupported(Opcode))
594         continue;
595       Register DstReg = MI.getOperand(0).getReg();
596       bool IsDstPtr = MRI.getType(DstReg).isPointer();
597       bool isDstVec = MRI.getType(DstReg).isVector();
598       if (IsDstPtr || isDstVec)
599         MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
600       // Don't need to reset type of register holding constant and used in
601       // G_ADDRSPACE_CAST, since it breaks legalizer.
602       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
603         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
604         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
605           continue;
606       }
607       MRI.setType(DstReg, IsDstPtr ? LLT::pointer(0, GR->getPointerSize())
608                                    : LLT::scalar(32));
609     }
610   }
611 }
612 
613 static void
insertInlineAsmProcess(MachineFunction & MF,SPIRVGlobalRegistry * GR,const SPIRVSubtarget & ST,MachineIRBuilder MIRBuilder,const SmallVector<MachineInstr * > & ToProcess)614 insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
615                        const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
616                        const SmallVector<MachineInstr *> &ToProcess) {
617   MachineRegisterInfo &MRI = MF.getRegInfo();
618   Register AsmTargetReg;
619   for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
620     MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
621     assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
622     MIRBuilder.setInsertPt(*I1->getParent(), *I1);
623 
624     if (!AsmTargetReg.isValid()) {
625       // define vendor specific assembly target or dialect
626       AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
627       MRI.setRegClass(AsmTargetReg, &SPIRV::IDRegClass);
628       auto AsmTargetMIB =
629           MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
630       addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
631       GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
632     }
633 
634     // create types
635     const MDNode *IAMD = I1->getOperand(1).getMetadata();
636     FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
637     SmallVector<SPIRVType *, 4> ArgTypes;
638     for (const auto &ArgTy : FTy->params())
639       ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
640     SPIRVType *RetType =
641         GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
642     SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
643         FTy, RetType, ArgTypes, MIRBuilder);
644 
645     // define vendor specific assembly instructions string
646     Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
647     MRI.setRegClass(AsmReg, &SPIRV::IDRegClass);
648     auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
649                       .addDef(AsmReg)
650                       .addUse(GR->getSPIRVTypeID(RetType))
651                       .addUse(GR->getSPIRVTypeID(FuncType))
652                       .addUse(AsmTargetReg);
653     // inline asm string:
654     addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
655                  AsmMIB);
656     // inline asm constraint string:
657     addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
658                      ->getString(),
659                  AsmMIB);
660     GR->add(AsmMIB.getInstr(), &MF, AsmReg);
661 
662     // calls the inline assembly instruction
663     unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
664     if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
665       MIRBuilder.buildInstr(SPIRV::OpDecorate)
666           .addUse(AsmReg)
667           .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));
668     Register DefReg;
669     SmallVector<unsigned, 4> Ops;
670     unsigned StartOp = InlineAsm::MIOp_FirstOperand,
671              AsmDescOp = InlineAsm::MIOp_FirstOperand;
672     unsigned I2Sz = I2->getNumOperands();
673     for (unsigned Idx = StartOp; Idx != I2Sz; ++Idx) {
674       const MachineOperand &MO = I2->getOperand(Idx);
675       if (MO.isMetadata())
676         continue;
677       if (Idx == AsmDescOp && MO.isImm()) {
678         // compute the index of the next operand descriptor
679         const InlineAsm::Flag F(MO.getImm());
680         AsmDescOp += 1 + F.getNumOperandRegisters();
681       } else {
682         if (MO.isReg() && MO.isDef())
683           DefReg = MO.getReg();
684         else
685           Ops.push_back(Idx);
686       }
687     }
688     if (!DefReg.isValid()) {
689       DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
690       MRI.setRegClass(DefReg, &SPIRV::IDRegClass);
691       SPIRVType *VoidType = GR->getOrCreateSPIRVType(
692           Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
693       GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
694     }
695     auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
696                        .addDef(DefReg)
697                        .addUse(GR->getSPIRVTypeID(RetType))
698                        .addUse(AsmReg);
699     unsigned IntrIdx = 2;
700     for (unsigned Idx : Ops) {
701       ++IntrIdx;
702       const MachineOperand &MO = I2->getOperand(Idx);
703       if (MO.isReg())
704         AsmCall.addUse(MO.getReg());
705       else
706         AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
707     }
708   }
709   for (MachineInstr *MI : ToProcess)
710     MI->eraseFromParent();
711 }
712 
insertInlineAsm(MachineFunction & MF,SPIRVGlobalRegistry * GR,const SPIRVSubtarget & ST,MachineIRBuilder MIRBuilder)713 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
714                             const SPIRVSubtarget &ST,
715                             MachineIRBuilder MIRBuilder) {
716   SmallVector<MachineInstr *> ToProcess;
717   for (MachineBasicBlock &MBB : MF) {
718     for (MachineInstr &MI : MBB) {
719       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
720           MI.getOpcode() == TargetOpcode::INLINEASM)
721         ToProcess.push_back(&MI);
722     }
723   }
724   if (ToProcess.size() == 0)
725     return;
726 
727   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
728     report_fatal_error("Inline assembly instructions require the "
729                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
730                        false);
731 
732   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
733 }
734 
insertSpirvDecorations(MachineFunction & MF,MachineIRBuilder MIB)735 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
736   SmallVector<MachineInstr *, 10> ToErase;
737   for (MachineBasicBlock &MBB : MF) {
738     for (MachineInstr &MI : MBB) {
739       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
740         continue;
741       MIB.setInsertPt(*MI.getParent(), MI);
742       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
743                               MI.getOperand(2).getMetadata());
744       ToErase.push_back(&MI);
745     }
746   }
747   for (MachineInstr *MI : ToErase)
748     MI->eraseFromParent();
749 }
750 
751 // Find basic blocks of the switch and replace registers in spv_switch() by its
752 // MBB equivalent.
processSwitches(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)753 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
754                             MachineIRBuilder MIB) {
755   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
756   SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
757       Switches;
758   for (MachineBasicBlock &MBB : MF) {
759     MachineRegisterInfo &MRI = MF.getRegInfo();
760     BB2MBB[MBB.getBasicBlock()] = &MBB;
761     for (MachineInstr &MI : MBB) {
762       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
763         continue;
764       // Calls to spv_switch intrinsics representing IR switches.
765       SmallVector<MachineInstr *, 8> NewOps;
766       for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
767         Register Reg = MI.getOperand(i).getReg();
768         if (i % 2 == 1) {
769           MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
770           NewOps.push_back(ConstInstr);
771         } else {
772           MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
773           assert(BuildMBB &&
774                  BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
775                  BuildMBB->getOperand(1).isBlockAddress() &&
776                  BuildMBB->getOperand(1).getBlockAddress());
777           NewOps.push_back(BuildMBB);
778         }
779       }
780       Switches.push_back(std::make_pair(&MI, NewOps));
781     }
782   }
783 
784   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
785   for (auto &SwIt : Switches) {
786     MachineInstr &MI = *SwIt.first;
787     SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
788     SmallVector<MachineOperand, 8> NewOps;
789     for (unsigned i = 0; i < Ins.size(); ++i) {
790       if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
791         BasicBlock *CaseBB =
792             Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
793         auto It = BB2MBB.find(CaseBB);
794         if (It == BB2MBB.end())
795           report_fatal_error("cannot find a machine basic block by a basic "
796                              "block in a switch statement");
797         NewOps.push_back(MachineOperand::CreateMBB(It->second));
798         MI.getParent()->addSuccessor(It->second);
799         ToEraseMI.insert(Ins[i]);
800       } else {
801         NewOps.push_back(
802             MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
803       }
804     }
805     for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
806       MI.removeOperand(i);
807     for (auto &MO : NewOps)
808       MI.addOperand(MO);
809     if (MachineInstr *Next = MI.getNextNode()) {
810       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
811         ToEraseMI.insert(Next);
812         Next = MI.getNextNode();
813       }
814       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
815         ToEraseMI.insert(Next);
816     }
817   }
818 
819   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
820   // this leaves their BasicBlock counterparts in a "address taken" status. This
821   // would make AsmPrinter to generate a series of unneeded labels of a "Address
822   // of block that was removed by CodeGen" kind. Let's first ensure that we
823   // don't have a dangling BlockAddress constants by zapping the BlockAddress
824   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
825   Constant *Replacement =
826       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
827   for (MachineInstr *BlockAddrI : ToEraseMI) {
828     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
829       BlockAddress *BA = const_cast<BlockAddress *>(
830           BlockAddrI->getOperand(1).getBlockAddress());
831       BA->replaceAllUsesWith(
832           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
833       BA->destroyConstant();
834     }
835     BlockAddrI->eraseFromParent();
836   }
837 }
838 
isImplicitFallthrough(MachineBasicBlock & MBB)839 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
840   if (MBB.empty())
841     return true;
842 
843   // Branching SPIR-V intrinsics are not detected by this generic method.
844   // Thus, we can only trust negative result.
845   if (!MBB.canFallThrough())
846     return false;
847 
848   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
849   // prevent an implicit fallthrough.
850   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
851        It != E; ++It) {
852     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
853       return false;
854   }
855   return true;
856 }
857 
removeImplicitFallthroughs(MachineFunction & MF,MachineIRBuilder MIB)858 static void removeImplicitFallthroughs(MachineFunction &MF,
859                                        MachineIRBuilder MIB) {
860   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
861   // In such cases, they will simply fallthrough their immediate successor.
862   for (MachineBasicBlock &MBB : MF) {
863     if (!isImplicitFallthrough(MBB))
864       continue;
865 
866     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
867            1);
868     MIB.setInsertPt(MBB, MBB.end());
869     MIB.buildBr(**MBB.successors().begin());
870   }
871 }
872 
// Pass entry point: runs the pre-legalization steps over MF in a fixed order.
// Later steps consume state produced by earlier ones: TrackedConstRegs and
// TargetExtConstTypes are populated by addConstantsToTrack and then handed to
// foldConstantsIntoIntrinsics and generateAssignInstrs respectively. Do not
// reorder the calls casually. The function is always reported as modified.
bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
  // Initialize the type registry.
  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
  GR->setCurrentFunc(MF);
  MachineIRBuilder MIB(MF);
  // a registry of target extension constants
  DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
  // to keep record of tracked constants
  SmallSet<Register, 4> TrackedConstRegs;
  addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
  foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
  insertBitcasts(MF, GR, MIB);
  generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
  processSwitches(MF, GR, MIB);
  processInstrsWithTypeFolding(MF, GR, MIB);
  // Runs after processSwitches so that blocks ending in a rewritten
  // spv_switch are recognized as not falling through.
  removeImplicitFallthroughs(MF, MIB);
  insertSpirvDecorations(MF, MIB);
  insertInlineAsm(MF, GR, ST, MIB);

  return true;
}
895 
// Register the pass with the legacy pass registry under DEBUG_TYPE
// ("spirv-prelegalizer"). The trailing `false, false` are the CFG-only and
// is-analysis flags of INITIALIZE_PASS.
INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
                false)

// Pass identifier: its address, not its value, identifies the pass.
char SPIRVPreLegalizer::ID = 0;
900 
// Factory for the SPIR-V pre-legalizer pass. Returns a newly allocated
// instance; the caller takes ownership.
FunctionPass *llvm::createSPIRVPreLegalizerPass() {
  return new SPIRVPreLegalizer();
}
904