//===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes the intrinsics that held these types during IR translation.
// It also processes constants and registers them in GR to avoid duplication.
//
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Target/TargetIntrinsicInfo.h"

#define DEBUG_TYPE "spirv-prelegalizer"

using namespace llvm;

namespace {
class SPIRVPreLegalizer : public MachineFunctionPass {
public:
  static char ID;
  SPIRVPreLegalizer() : MachineFunctionPass(ID) {
    initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
  }
  bool runOnMachineFunction(MachineFunction &MF) override;
};
} // namespace

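// Register the constants referenced by spv_track_constant intrinsics in GR so
// that duplicates resolve to a single register, then erase the intrinsic
// calls. A simplified sketch of the MIR shape this expects (registers, types
// and the metadata node are illustrative, not taken from a real test):
//
//   %c:_(s32) = G_CONSTANT i32 42
//   %t:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.track.constant), %c:_(s32), !1
//
// Uses of %t are rewired to the register recorded in GR for the constant (on
// first occurrence, the tracked register %c itself), and the call is erased.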
static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
  SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
        continue;
      ToErase.push_back(&MI);
      auto *Const =
          cast<Constant>(cast<ConstantAsMetadata>(
                             MI.getOperand(3).getMetadata()->getOperand(0))
                             ->getValue());
      if (auto *GV = dyn_cast<GlobalValue>(Const)) {
        Register Reg = GR->find(GV, &MF);
        if (!Reg.isValid())
          GR->add(GV, &MF, MI.getOperand(2).getReg());
        else
          RegsAlreadyAddedToDT[&MI] = Reg;
      } else {
        Register Reg = GR->find(Const, &MF);
        if (!Reg.isValid()) {
          if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
            auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
            assert(BuildVec &&
                   BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
              GR->add(ConstVec->getElementAsConstant(i), &MF,
                      BuildVec->getOperand(1 + i).getReg());
          }
          GR->add(Const, &MF, MI.getOperand(2).getReg());
        } else {
          RegsAlreadyAddedToDT[&MI] = Reg;
          // This MI is unused and will be removed. If the MI's source was
          // produced by a const_composite intrinsic, that intrinsic becomes
          // unused as well and should also be removed.
          assert(MI.getOperand(2).isReg() && "Reg operand is expected");
          MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
          if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
            ToEraseComposites.push_back(SrcMI);
        }
      }
    }
  }
  for (MachineInstr *MI : ToErase) {
    Register Reg = MI->getOperand(2).getReg();
    if (RegsAlreadyAddedToDT.contains(MI))
      Reg = RegsAlreadyAddedToDT[MI];
    auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
    if (!MRI.getRegClassOrNull(Reg) && RC)
      MRI.setRegClass(Reg, RC);
    MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
    MI->eraseFromParent();
  }
  for (MachineInstr *MI : ToEraseComposites)
    MI->eraseFromParent();
}

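// Fold the trailing register operands of spv_assign_name intrinsics (each
// defined by a G_CONSTANT holding part of the encoded name) into immediate
// operands, erasing the G_CONSTANTs once they become unused. Illustrative
// sketch (values and surrounding operands are made up):
//
//   %c:_(s32) = G_CONSTANT i32 116
//   intrinsic(@llvm.spv.assign.name), %v, %c
//   ===>
//   intrinsic(@llvm.spv.assign.name), %v, 116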
static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
  SmallVector<MachineInstr *, 10> ToErase;
  MachineRegisterInfo &MRI = MF.getRegInfo();
  const unsigned AssignNameOperandShift = 2;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
        continue;
      unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
      while (MI.getOperand(NumOp).isReg()) {
        MachineOperand &MOp = MI.getOperand(NumOp);
        MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
        assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
        MI.removeOperand(NumOp);
        MI.addOperand(MachineOperand::CreateImm(
            ConstMI->getOperand(1).getCImm()->getZExtValue()));
        if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
          ToErase.push_back(ConstMI);
      }
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}

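// Lower spv_bitcast and spv_ptrcast intrinsics into G_BITCAST, or drop a
// ptrcast entirely when the source register already carries the requested
// pointer type. A rough sketch of the ptrcast handling (registers, the
// pointee type and the address-space operand are illustrative):
//
//   %dst = intrinsic(@llvm.spv.ptrcast), %src, !{i8}, <addrspace imm>
//   ===>
//   ; the pointer-to-i8 SPIR-V type is recorded for %dst in GR
//   %dst = G_BITCAST %src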
static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                           MachineIRBuilder MIB) {
  SmallVector<MachineInstr *, 10> ToErase;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
          !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
        continue;
      assert(MI.getOperand(2).isReg());
      MIB.setInsertPt(*MI.getParent(), MI);
      ToErase.push_back(&MI);
      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
        MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
        continue;
      }
      Register Def = MI.getOperand(0).getReg();
      Register Source = MI.getOperand(2).getReg();
      SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
          getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
      SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
          BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
          addressSpaceToStorageClass(MI.getOperand(4).getImm()));

      // If the bitcast would be redundant, replace all uses with the source
      // register.
      if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
        MIB.getMRI()->replaceRegWith(Def, Source);
      } else {
        GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
        MIB.buildBitcast(Def, Source);
      }
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}

// When translating a GV, IRTranslator sometimes generates the following IR:
//   %1 = G_GLOBAL_VALUE
//   %2 = COPY %1
//   %3 = G_ADDRSPACE_CAST %2
// The new registers have no SPIRVType and no register class info.
//
// Set the SPIRVType for the GV and propagate it from the GV to the other
// instructions; also set the register classes.
static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
                                     MachineRegisterInfo &MRI,
                                     MachineIRBuilder &MIB) {
  SPIRVType *SpirvTy = nullptr;
  assert(MI && "Machine instr is expected");
  if (MI->getOperand(0).isReg()) {
    Register Reg = MI->getOperand(0).getReg();
    SpirvTy = GR->getSPIRVTypeForVReg(Reg);
    if (!SpirvTy) {
      switch (MI->getOpcode()) {
      case TargetOpcode::G_CONSTANT: {
        MIB.setInsertPt(*MI->getParent(), MI);
        Type *Ty = MI->getOperand(1).getCImm()->getType();
        SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_GLOBAL_VALUE: {
        MIB.setInsertPt(*MI->getParent(), MI);
        Type *Ty = MI->getOperand(1).getGlobal()->getType();
        SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_TRUNC:
      case TargetOpcode::G_ADDRSPACE_CAST:
      case TargetOpcode::G_PTR_ADD:
      case TargetOpcode::COPY: {
        MachineOperand &Op = MI->getOperand(1);
        MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
        if (Def)
          SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
        break;
      }
      default:
        break;
      }
      if (SpirvTy)
        GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
      if (!MRI.getRegClassOrNull(Reg))
        MRI.setRegClass(Reg, &SPIRV::IDRegClass);
    }
  }
  return SpirvTy;
}

// Insert an ASSIGN_TYPE instruction between Reg and its definition, set NewReg
// as the dst of the definition, and assign the SPIRVType to both registers. If
// SpirvTy is provided, use it as the SPIRVType in ASSIGN_TYPE, otherwise
// create it from Ty. This function is also used in SPIRVBuiltins.cpp.
// TODO: maybe move to SPIRVUtils.
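// For illustration (opcode and register names are arbitrary), a definition
//   %reg = G_FOO ...
// becomes
//   %newreg = G_FOO ...
//   %reg = ASSIGN_TYPE %newreg, %spirv_type_id
// with the same SPIRVType recorded for both %reg and %newreg in GR.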
namespace llvm {
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                           SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                           MachineRegisterInfo &MRI) {
  MachineInstr *Def = MRI.getVRegDef(Reg);
  assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
  MIB.setInsertPt(*Def->getParent(),
                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                      : Def->getParent()->end()));
  Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
    MRI.setRegClass(NewReg, RC);
  } else {
    MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
    MRI.setRegClass(Reg, &SPIRV::IDRegClass);
  }
  SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
  GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
  // This is to make it convenient for the Legalizer to get the SPIRVType
  // when processing the actual MI (i.e. not the pseudo one).
  GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
  // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
  // the flags after instruction selection.
  const uint32_t Flags = Def->getFlags();
  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
      .addDef(Reg)
      .addUse(NewReg)
      .addUse(GR->getSPIRVTypeID(SpirvTy))
      .setMIFlags(Flags);
  Def->getOperand(0).setReg(NewReg);
  return NewReg;
}
} // namespace llvm

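// Apply the type-carrying intrinsics and materialize ASSIGN_TYPE pseudos:
// spv_assign_type / spv_assign_ptr_type calls are turned into ASSIGN_TYPE on
// their target registers, constants and G_BUILD_VECTORs get ASSIGN_TYPE
// inserted for them, and G_TRUNC / G_GLOBAL_VALUE / COPY / G_ADDRSPACE_CAST
// have their SPIR-V types propagated. The intrinsic calls themselves are
// erased once processed.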
static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                 MachineIRBuilder MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  SmallVector<MachineInstr *, 10> ToErase;

  for (MachineBasicBlock *MBB : post_order(&MF)) {
    if (MBB->empty())
      continue;

    bool ReachedBegin = false;
    for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
         !ReachedBegin;) {
      MachineInstr &MI = *MII;

      if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
        Register Reg = MI.getOperand(1).getReg();
        MIB.setInsertPt(*MI.getParent(), MI.getIterator());
        SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
            getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
        SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
            BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
            addressSpaceToStorageClass(MI.getOperand(3).getImm()));
        MachineInstr *Def = MRI.getVRegDef(Reg);
        assert(Def && "Expecting an instruction that defines the register");
        insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
                          MF.getRegInfo());
        ToErase.push_back(&MI);
      } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
        Register Reg = MI.getOperand(1).getReg();
        Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
        MachineInstr *Def = MRI.getVRegDef(Reg);
        assert(Def && "Expecting an instruction that defines the register");
        // G_GLOBAL_VALUE already has type info.
        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
        ToErase.push_back(&MI);
      } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
                 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
                 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
        // %rc = G_CONSTANT ty Val
        // ===>
        // %cty = OpType* ty
        // %rctmp = G_CONSTANT ty Val
        // %rc = ASSIGN_TYPE %rctmp, %cty
        Register Reg = MI.getOperand(0).getReg();
        if (MRI.hasOneUse(Reg)) {
          MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
          if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
              isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
            continue;
        }
        Type *Ty = nullptr;
        if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
          Ty = MI.getOperand(1).getCImm()->getType();
        else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
          Ty = MI.getOperand(1).getFPImm()->getType();
        else {
          assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
          Type *ElemTy = nullptr;
          MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
          assert(ElemMI);

          if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
            ElemTy = ElemMI->getOperand(1).getCImm()->getType();
          else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
            ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
          else
            llvm_unreachable("Unexpected opcode");
          unsigned NumElts =
              MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
          Ty = VectorType::get(ElemTy, NumElts, false);
        }
        insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
      } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
                 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
                 MI.getOpcode() == TargetOpcode::COPY ||
                 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
        propagateSPIRVType(&MI, GR, MRI, MIB);
      }

      if (MII == Begin)
        ReachedBegin = true;
      else
        --MII;
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}

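// Create a new virtual register in one of the post-selection ID register
// classes for ValReg, chosen from its SPIR-V type and LLT, and return it
// together with the matching GET_* opcode. The mapping below summarizes the
// code and is not an exhaustive specification:
//   scalar int   -> GET_ID   / IDRegClass
//   scalar float -> GET_fID  / fIDRegClass
//   pointer      -> GET_pID  / pIDRegClass
//   int vector   -> GET_vID  / vIDRegClass
//   float vector -> GET_vfID / vfIDRegClass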
static std::pair<Register, unsigned>
createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
               const SPIRVGlobalRegistry &GR) {
  LLT NewT = LLT::scalar(32);
  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
  assert(SpvType && "VReg is expected to have SPIRV type");
  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
  bool IsVectorFloat =
      SpvType->getOpcode() == SPIRV::OpTypeVector &&
      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
          SPIRV::OpTypeFloat;
  IsFloat |= IsVectorFloat;
  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
  if (MRI.getType(ValReg).isPointer()) {
    NewT = LLT::pointer(0, 32);
    GetIdOp = SPIRV::GET_pID;
    DstClass = &SPIRV::pIDRegClass;
  } else if (MRI.getType(ValReg).isVector()) {
    NewT = LLT::fixed_vector(2, NewT);
    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
  }
  Register IdReg = MRI.createGenericVirtualRegister(NewT);
  MRI.setRegClass(IdReg, DstClass);
  return {IdReg, GetIdOp};
}

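// Rewire an instruction whose single-use result feeds an ASSIGN_TYPE so that
// its destination and source operands go through freshly created ID-class
// registers, inserting GET_* instructions that copy the original sources.
// A rough sketch (opcode and register numbers are illustrative only, and the
// textual order of the emitted GET_* instructions is not shown exactly):
//
//   %2 = G_ADD %0, %1
//   %r = ASSIGN_TYPE %2, %ty
//   ===>
//   %4 = GET_ID %0
//   %5 = GET_ID %1
//   %6 = G_ADD %4, %5
//   %r = ASSIGN_TYPE %6, %ty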
static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
  unsigned Opc = MI.getOpcode();
  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
  MachineInstr &AssignTypeInst =
      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
  AssignTypeInst.getOperand(1).setReg(NewReg);
  MI.getOperand(0).setReg(NewReg);
  MIB.setInsertPt(*MI.getParent(),
                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
                                    : MI.getParent()->end()));
  for (auto &Op : MI.operands()) {
    if (!Op.isReg() || Op.isDef())
      continue;
    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
    Op.setReg(IdOpInfo.first);
  }
}

// Defined in SPIRVLegalizerInfo.cpp.
extern bool isTypeFoldingSupported(unsigned Opcode);

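// For each instruction whose opcode supports type folding (see
// isTypeFoldingSupported), rewrite its registers via processInstr, then
// narrow the destination registers of the corresponding ASSIGN_TYPE pseudos
// to a 32-bit scalar LLT so that tblgen'erated selection patterns match
// (with the G_CONSTANT / G_ADDRSPACE_CAST exception noted below).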
static void processInstrsWithTypeFolding(MachineFunction &MF,
                                         SPIRVGlobalRegistry *GR,
                                         MachineIRBuilder MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (isTypeFoldingSupported(MI.getOpcode()))
        processInstr(MI, MIB, MRI, GR);
    }
  }
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
      // to perform tblgen'erated selection and we can't do that on Legalizer
      // as it operates on gMIR only.
      if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
        continue;
      Register SrcReg = MI.getOperand(1).getReg();
      unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
      if (!isTypeFoldingSupported(Opcode))
        continue;
      Register DstReg = MI.getOperand(0).getReg();
      if (MRI.getType(DstReg).isVector())
        MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
      // There's no need to reset the type of a register that holds a constant
      // used in G_ADDRSPACE_CAST, since doing so breaks the legalizer.
      if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
        MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
        if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
          continue;
      }
      MRI.setType(DstReg, LLT::scalar(32));
    }
  }
}

static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                            MachineIRBuilder MIB) {
  // Before the IRTranslator pass, calls to the spv_switch intrinsic are
  // inserted before each switch instruction. IRTranslator lowers switches to
  // G_ICMP + G_BRCOND + G_BR triples. A switch with two cases may be
  // transformed into this MIR sequence:
  //
  //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
  //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
  //   G_BRCOND %Dst0, %bb.2
  //   G_BR %bb.5
  // bb.5.entry:
  //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
  //   G_BRCOND %Dst1, %bb.3
  //   G_BR %bb.4
  // bb.2.sw.bb:
  //   ...
  // bb.3.sw.bb1:
  //   ...
  // bb.4.sw.epilog:
  //   ...
  //
  // Sometimes (in the case of range-compare switches), additional G_SUB
  // instructions are inserted before the G_ICMPs. Those need additional
  // processing.
  //
  // This function modifies the spv_switch call's operands to include the
  // destination MBBs (the default one and one for each constant value).
  //
  // At the end, the function removes the redundant [G_SUB] + G_ICMP +
  // G_BRCOND + G_BR sequences.
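  //
  // After this function, the spv_switch call from the example above roughly
  // becomes (the immediate values stand for whatever %Const0 and %Const1
  // held):
  //
  //   intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
  //
  // i.e. the default MBB followed by (case constant, target MBB) pairs.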

  MachineRegisterInfo &MRI = MF.getRegInfo();

  // Collect spv_switches and G_ICMPs across all MBBs in MF.
  std::vector<MachineInstr *> RelevantInsts;

  // Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
  // After updating spv_switches, the instructions can be removed.
  std::vector<MachineInstr *> PostUpdateArtifacts;

  // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
  // spv_switch use these registers.
  DenseSet<Register> CompareRegs;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      // Calls to spv_switch intrinsics representing IR switches.
      if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
        assert(MI.getOperand(1).isReg());
        CompareRegs.insert(MI.getOperand(1).getReg());
        RelevantInsts.push_back(&MI);
      }

      // G_SUBs coming from range-compare switch lowering. G_SUBs are found
      // after spv_switch but before G_ICMP.
      if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
          CompareRegs.contains(MI.getOperand(1).getReg())) {
        assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
        Register Dst = MI.getOperand(0).getReg();
        CompareRegs.insert(Dst);
        PostUpdateArtifacts.push_back(&MI);
      }

      // G_ICMPs relating to switches.
      if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
          CompareRegs.contains(MI.getOperand(2).getReg())) {
        Register Dst = MI.getOperand(0).getReg();
        RelevantInsts.push_back(&MI);
        PostUpdateArtifacts.push_back(&MI);
        MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
        assert(CBr->getOpcode() == SPIRV::G_BRCOND);
        PostUpdateArtifacts.push_back(CBr);
        MachineInstr *Br = CBr->getNextNode();
        assert(Br->getOpcode() == SPIRV::G_BR);
        PostUpdateArtifacts.push_back(Br);
      }
    }
  }

  // Update each spv_switch with destination MBBs.
  for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
    if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
      continue;

    // Currently considered spv_switch.
    MachineInstr *Switch = *i;
    // Set the first successor as default MBB to support empty switches.
    MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
    // Container for mapping values to MBBs.
    SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;

    // Walk all G_ICMPs to collect ValuesToMBBs. Start at the currently
    // considered spv_switch (i) and break at any spv_switch with the same
    // compare register (indicating we are back in the same scope).
    Register CompareReg = Switch->getOperand(1).getReg();
    for (auto j = i + 1; j != RelevantInsts.end(); j++) {
      if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
          (*j)->getOperand(1).getReg() == CompareReg)
        break;

      if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
            (*j)->getOperand(2).getReg() == CompareReg))
        continue;

      MachineInstr *ICMP = *j;
      Register Dst = ICMP->getOperand(0).getReg();
      MachineOperand &PredOp = ICMP->getOperand(1);
      const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
      assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
             MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
      uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
      MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
      assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
      MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();

      // Map switch case Value to target MBB.
      ValuesToMBBs[Value] = MBB;

      // Add target MBB as successor to the switch's MBB.
      Switch->getParent()->addSuccessor(MBB);

      // The next MI is always G_BR to either the next case or the default.
      MachineInstr *NextMI = CBr->getNextNode();
      assert(NextMI->getOpcode() == SPIRV::G_BR &&
             NextMI->getOperand(0).isMBB());
      MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
      // The default MBB does not begin with a G_ICMP that uses the spv_switch
      // compare register.
      if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
          (NextMBB->front().getOperand(2).isReg() &&
           NextMBB->front().getOperand(2).getReg() != CompareReg)) {
        // Set default MBB and add it as successor to the switch's MBB.
        DefaultMBB = NextMBB;
        Switch->getParent()->addSuccessor(DefaultMBB);
      }
    }

    // Modify considered spv_switch operands using collected Values and
    // MBBs.
    SmallVector<const ConstantInt *, 3> Values;
    SmallVector<MachineBasicBlock *, 3> MBBs;
    for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
      Register CReg = Switch->getOperand(k).getReg();
      uint64_t Val = getIConstVal(CReg, &MRI);
      MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
      if (!ValuesToMBBs[Val])
        continue;

      Values.push_back(ConstInstr->getOperand(1).getCImm());
      MBBs.push_back(ValuesToMBBs[Val]);
    }

    for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
      Switch->removeOperand(k);

    Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
    for (unsigned k = 0; k < Values.size(); k++) {
      Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
      Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
    }
  }

  for (MachineInstr *MI : PostUpdateArtifacts) {
    MachineBasicBlock *ParentMBB = MI->getParent();
    MI->eraseFromParent();
    // If G_ICMP + G_BRCOND + G_BR were the only MIs in the MBB, erase this
    // MBB. It can safely be assumed that there are no branches or phis
    // directed into this MBB. However, we still need to remove this MBB from
    // the CFG. MBBs must be erased top-down.
    if (ParentMBB->empty()) {
      while (!ParentMBB->pred_empty())
        (*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);

      while (!ParentMBB->succ_empty())
        ParentMBB->removeSuccessor(ParentMBB->succ_begin());

      ParentMBB->eraseFromParent();
    }
  }
}

static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
  if (MBB.empty())
    return true;

  // Branching SPIR-V intrinsics are not detected by this generic method.
  // Thus, we can only trust a negative result.
  if (!MBB.canFallThrough())
    return false;

  // Otherwise, we must manually check whether we have a SPIR-V intrinsic that
  // prevents an implicit fallthrough.
  for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
       It != E; ++It) {
    if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
      return false;
  }
  return true;
}

static void removeImplicitFallthroughs(MachineFunction &MF,
                                       MachineIRBuilder MIB) {
  // It is valid for MachineBasicBlocks not to end with a branch instruction.
  // In such cases, they simply fall through to their immediate successor.
  for (MachineBasicBlock &MBB : MF) {
    if (!isImplicitFallthrough(MBB))
      continue;

    assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
           1);
    MIB.setInsertPt(MBB, MBB.end());
    MIB.buildBr(**MBB.successors().begin());
  }
}

bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
  // Initialize the type registry.
  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
  GR->setCurrentFunc(MF);
  MachineIRBuilder MIB(MF);
  addConstantsToTrack(MF, GR);
  foldConstantsIntoIntrinsics(MF);
  insertBitcasts(MF, GR, MIB);
  generateAssignInstrs(MF, GR, MIB);
  processSwitches(MF, GR, MIB);
  processInstrsWithTypeFolding(MF, GR, MIB);
  removeImplicitFallthroughs(MF, MIB);

  return true;
}

INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
                false)

char SPIRVPreLegalizer::ID = 0;

FunctionPass *llvm::createSPIRVPreLegalizerPass() {
  return new SPIRVPreLegalizer();
}