xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision e9b261f297cac146f0c9f895c16debe1c4cf8978)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers
10 // and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Target/TargetIntrinsicInfo.h"
#include <map>
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
42   MachineRegisterInfo &MRI = MF.getRegInfo();
43   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
44   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
45   for (MachineBasicBlock &MBB : MF) {
46     for (MachineInstr &MI : MBB) {
47       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
48         continue;
49       ToErase.push_back(&MI);
50       auto *Const =
51           cast<Constant>(cast<ConstantAsMetadata>(
52                              MI.getOperand(3).getMetadata()->getOperand(0))
53                              ->getValue());
54       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
55         Register Reg = GR->find(GV, &MF);
56         if (!Reg.isValid())
57           GR->add(GV, &MF, MI.getOperand(2).getReg());
58         else
59           RegsAlreadyAddedToDT[&MI] = Reg;
60       } else {
61         Register Reg = GR->find(Const, &MF);
62         if (!Reg.isValid()) {
63           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
64             auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
65             assert(BuildVec &&
66                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
67             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
68               GR->add(ConstVec->getElementAsConstant(i), &MF,
69                       BuildVec->getOperand(1 + i).getReg());
70           }
71           GR->add(Const, &MF, MI.getOperand(2).getReg());
72         } else {
73           RegsAlreadyAddedToDT[&MI] = Reg;
74           // This MI is unused and will be removed. If the MI uses
75           // const_composite, it will be unused and should be removed too.
76           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
77           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
78           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
79             ToEraseComposites.push_back(SrcMI);
80         }
81       }
82     }
83   }
84   for (MachineInstr *MI : ToErase) {
85     Register Reg = MI->getOperand(2).getReg();
86     if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
87       Reg = RegsAlreadyAddedToDT[MI];
88     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
89     if (!MRI.getRegClassOrNull(Reg) && RC)
90       MRI.setRegClass(Reg, RC);
91     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
92     MI->eraseFromParent();
93   }
94   for (MachineInstr *MI : ToEraseComposites)
95     MI->eraseFromParent();
96 }
97 
98 static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
99   SmallVector<MachineInstr *, 10> ToErase;
100   MachineRegisterInfo &MRI = MF.getRegInfo();
101   const unsigned AssignNameOperandShift = 2;
102   for (MachineBasicBlock &MBB : MF) {
103     for (MachineInstr &MI : MBB) {
104       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
105         continue;
106       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
107       while (MI.getOperand(NumOp).isReg()) {
108         MachineOperand &MOp = MI.getOperand(NumOp);
109         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
110         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
111         MI.removeOperand(NumOp);
112         MI.addOperand(MachineOperand::CreateImm(
113             ConstMI->getOperand(1).getCImm()->getZExtValue()));
114         if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
115           ToErase.push_back(ConstMI);
116       }
117     }
118   }
119   for (MachineInstr *MI : ToErase)
120     MI->eraseFromParent();
121 }
122 
123 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
124                            MachineIRBuilder MIB) {
125   SmallVector<MachineInstr *, 10> ToErase;
126   for (MachineBasicBlock &MBB : MF) {
127     for (MachineInstr &MI : MBB) {
128       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
129         continue;
130       assert(MI.getOperand(2).isReg());
131       MIB.setInsertPt(*MI.getParent(), MI);
132       MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
133       ToErase.push_back(&MI);
134     }
135   }
136   for (MachineInstr *MI : ToErase)
137     MI->eraseFromParent();
138 }
139 
140 // Translating GV, IRTranslator sometimes generates following IR:
141 //   %1 = G_GLOBAL_VALUE
142 //   %2 = COPY %1
143 //   %3 = G_ADDRSPACE_CAST %2
144 // New registers have no SPIRVType and no register class info.
145 //
146 // Set SPIRVType for GV, propagate it from GV to other instructions,
147 // also set register classes.
148 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
149                                      MachineRegisterInfo &MRI,
150                                      MachineIRBuilder &MIB) {
151   SPIRVType *SpirvTy = nullptr;
152   assert(MI && "Machine instr is expected");
153   if (MI->getOperand(0).isReg()) {
154     Register Reg = MI->getOperand(0).getReg();
155     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
156     if (!SpirvTy) {
157       switch (MI->getOpcode()) {
158       case TargetOpcode::G_CONSTANT: {
159         MIB.setInsertPt(*MI->getParent(), MI);
160         Type *Ty = MI->getOperand(1).getCImm()->getType();
161         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
162         break;
163       }
164       case TargetOpcode::G_GLOBAL_VALUE: {
165         MIB.setInsertPt(*MI->getParent(), MI);
166         Type *Ty = MI->getOperand(1).getGlobal()->getType();
167         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
168         break;
169       }
170       case TargetOpcode::G_TRUNC:
171       case TargetOpcode::G_ADDRSPACE_CAST:
172       case TargetOpcode::G_PTR_ADD:
173       case TargetOpcode::COPY: {
174         MachineOperand &Op = MI->getOperand(1);
175         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
176         if (Def)
177           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
178         break;
179       }
180       default:
181         break;
182       }
183       if (SpirvTy)
184         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
185       if (!MRI.getRegClassOrNull(Reg))
186         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
187     }
188   }
189   return SpirvTy;
190 }
191 
192 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
193 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
194 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
195 // It's used also in SPIRVBuiltins.cpp.
196 // TODO: maybe move to SPIRVUtils.
197 namespace llvm {
198 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
199                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
200                            MachineRegisterInfo &MRI) {
201   MachineInstr *Def = MRI.getVRegDef(Reg);
202   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
203   MIB.setInsertPt(*Def->getParent(),
204                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
205                                       : Def->getParent()->end()));
206   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
207   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
208     MRI.setRegClass(NewReg, RC);
209   } else {
210     MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
211     MRI.setRegClass(Reg, &SPIRV::IDRegClass);
212   }
213   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
214   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
215   // This is to make it convenient for Legalizer to get the SPIRVType
216   // when processing the actual MI (i.e. not pseudo one).
217   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
218   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
219   // the flags after instruction selection.
220   const uint32_t Flags = Def->getFlags();
221   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
222       .addDef(Reg)
223       .addUse(NewReg)
224       .addUse(GR->getSPIRVTypeID(SpirvTy))
225       .setMIFlags(Flags);
226   Def->getOperand(0).setReg(NewReg);
227   return NewReg;
228 }
229 } // namespace llvm
230 
231 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
232                                  MachineIRBuilder MIB) {
233   MachineRegisterInfo &MRI = MF.getRegInfo();
234   SmallVector<MachineInstr *, 10> ToErase;
235 
236   for (MachineBasicBlock *MBB : post_order(&MF)) {
237     if (MBB->empty())
238       continue;
239 
240     bool ReachedBegin = false;
241     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
242          !ReachedBegin;) {
243       MachineInstr &MI = *MII;
244 
245       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
246         Register Reg = MI.getOperand(1).getReg();
247         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
248         MachineInstr *Def = MRI.getVRegDef(Reg);
249         assert(Def && "Expecting an instruction that defines the register");
250         // G_GLOBAL_VALUE already has type info.
251         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
252           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
253         ToErase.push_back(&MI);
254       } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
255                  MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
256                  MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
257         // %rc = G_CONSTANT ty Val
258         // ===>
259         // %cty = OpType* ty
260         // %rctmp = G_CONSTANT ty Val
261         // %rc = ASSIGN_TYPE %rctmp, %cty
262         Register Reg = MI.getOperand(0).getReg();
263         if (MRI.hasOneUse(Reg)) {
264           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
265           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
266               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
267             continue;
268         }
269         Type *Ty = nullptr;
270         if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
271           Ty = MI.getOperand(1).getCImm()->getType();
272         else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
273           Ty = MI.getOperand(1).getFPImm()->getType();
274         else {
275           assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
276           Type *ElemTy = nullptr;
277           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
278           assert(ElemMI);
279 
280           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
281             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
282           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
283             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
284           else
285             llvm_unreachable("Unexpected opcode");
286           unsigned NumElts =
287               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
288           Ty = VectorType::get(ElemTy, NumElts, false);
289         }
290         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
291       } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
292                  MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
293                  MI.getOpcode() == TargetOpcode::COPY ||
294                  MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
295         propagateSPIRVType(&MI, GR, MRI, MIB);
296       }
297 
298       if (MII == Begin)
299         ReachedBegin = true;
300       else
301         --MII;
302     }
303   }
304   for (MachineInstr *MI : ToErase)
305     MI->eraseFromParent();
306 }
307 
308 static std::pair<Register, unsigned>
309 createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
310                const SPIRVGlobalRegistry &GR) {
311   LLT NewT = LLT::scalar(32);
312   SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
313   assert(SpvType && "VReg is expected to have SPIRV type");
314   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
315   bool IsVectorFloat =
316       SpvType->getOpcode() == SPIRV::OpTypeVector &&
317       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
318           SPIRV::OpTypeFloat;
319   IsFloat |= IsVectorFloat;
320   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
321   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
322   if (MRI.getType(ValReg).isPointer()) {
323     NewT = LLT::pointer(0, 32);
324     GetIdOp = SPIRV::GET_pID;
325     DstClass = &SPIRV::pIDRegClass;
326   } else if (MRI.getType(ValReg).isVector()) {
327     NewT = LLT::fixed_vector(2, NewT);
328     GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
329     DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
330   }
331   Register IdReg = MRI.createGenericVirtualRegister(NewT);
332   MRI.setRegClass(IdReg, DstClass);
333   return {IdReg, GetIdOp};
334 }
335 
336 static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
337                          MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
338   unsigned Opc = MI.getOpcode();
339   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
340   MachineInstr &AssignTypeInst =
341       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
342   auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
343   AssignTypeInst.getOperand(1).setReg(NewReg);
344   MI.getOperand(0).setReg(NewReg);
345   MIB.setInsertPt(*MI.getParent(),
346                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
347                                     : MI.getParent()->end()));
348   for (auto &Op : MI.operands()) {
349     if (!Op.isReg() || Op.isDef())
350       continue;
351     auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
352     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
353     Op.setReg(IdOpInfo.first);
354   }
355 }
356 
357 // Defined in SPIRVLegalizerInfo.cpp.
358 extern bool isTypeFoldingSupported(unsigned Opcode);
359 
360 static void processInstrsWithTypeFolding(MachineFunction &MF,
361                                          SPIRVGlobalRegistry *GR,
362                                          MachineIRBuilder MIB) {
363   MachineRegisterInfo &MRI = MF.getRegInfo();
364   for (MachineBasicBlock &MBB : MF) {
365     for (MachineInstr &MI : MBB) {
366       if (isTypeFoldingSupported(MI.getOpcode()))
367         processInstr(MI, MIB, MRI, GR);
368     }
369   }
370   for (MachineBasicBlock &MBB : MF) {
371     for (MachineInstr &MI : MBB) {
372       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
373       // to perform tblgen'erated selection and we can't do that on Legalizer
374       // as it operates on gMIR only.
375       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
376         continue;
377       Register SrcReg = MI.getOperand(1).getReg();
378       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
379       if (!isTypeFoldingSupported(Opcode))
380         continue;
381       Register DstReg = MI.getOperand(0).getReg();
382       if (MRI.getType(DstReg).isVector())
383         MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
384       // Don't need to reset type of register holding constant and used in
385       // G_ADDRSPACE_CAST, since it braaks legalizer.
386       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
387         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
388         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
389           continue;
390       }
391       MRI.setType(DstReg, LLT::scalar(32));
392     }
393   }
394 }
395 
396 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
397                             MachineIRBuilder MIB) {
398   // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
399   // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
400   // + G_BR triples. A switch with two cases may be transformed to this MIR
401   // sequence:
402   //
403   //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
404   //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
405   //   G_BRCOND %Dst0, %bb.2
406   //   G_BR %bb.5
407   // bb.5.entry:
408   //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
409   //   G_BRCOND %Dst1, %bb.3
410   //   G_BR %bb.4
411   // bb.2.sw.bb:
412   //   ...
413   // bb.3.sw.bb1:
414   //   ...
415   // bb.4.sw.epilog:
416   //   ...
417   //
418   // Sometimes (in case of range-compare switches), additional G_SUBs
419   // instructions are inserted before G_ICMPs. Those need to be additionally
420   // processed.
421   //
422   // This function modifies spv_switch call's operands to include destination
423   // MBBs (default and for each constant value).
424   //
425   // At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND +
426   // G_BR sequences.
427 
428   MachineRegisterInfo &MRI = MF.getRegInfo();
429 
430   // Collect spv_switches and G_ICMPs across all MBBs in MF.
431   std::vector<MachineInstr *> RelevantInsts;
432 
433   // Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
434   // After updating spv_switches, the instructions can be removed.
435   std::vector<MachineInstr *> PostUpdateArtifacts;
436 
437   // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
438   // spv_switch use these registers.
439   DenseSet<Register> CompareRegs;
440   for (MachineBasicBlock &MBB : MF) {
441     for (MachineInstr &MI : MBB) {
442       // Calls to spv_switch intrinsics representing IR switches.
443       if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
444         assert(MI.getOperand(1).isReg());
445         CompareRegs.insert(MI.getOperand(1).getReg());
446         RelevantInsts.push_back(&MI);
447       }
448 
449       // G_SUBs coming from range-compare switch lowering. G_SUBs are found
450       // after spv_switch but before G_ICMP.
451       if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
452           CompareRegs.contains(MI.getOperand(1).getReg())) {
453         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
454         Register Dst = MI.getOperand(0).getReg();
455         CompareRegs.insert(Dst);
456         PostUpdateArtifacts.push_back(&MI);
457       }
458 
459       // G_ICMPs relating to switches.
460       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
461           CompareRegs.contains(MI.getOperand(2).getReg())) {
462         Register Dst = MI.getOperand(0).getReg();
463         RelevantInsts.push_back(&MI);
464         PostUpdateArtifacts.push_back(&MI);
465         MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
466         assert(CBr->getOpcode() == SPIRV::G_BRCOND);
467         PostUpdateArtifacts.push_back(CBr);
468         MachineInstr *Br = CBr->getNextNode();
469         assert(Br->getOpcode() == SPIRV::G_BR);
470         PostUpdateArtifacts.push_back(Br);
471       }
472     }
473   }
474 
475   // Update each spv_switch with destination MBBs.
476   for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
477     if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
478       continue;
479 
480     // Currently considered spv_switch.
481     MachineInstr *Switch = *i;
482     // Set the first successor as default MBB to support empty switches.
483     MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
484     // Container for mapping values to MMBs.
485     SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
486 
487     // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
488     // spv_switch (i) and break at any spv_switch with the same compare
489     // register (indicating we are back at the same scope).
490     Register CompareReg = Switch->getOperand(1).getReg();
491     for (auto j = i + 1; j != RelevantInsts.end(); j++) {
492       if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
493           (*j)->getOperand(1).getReg() == CompareReg)
494         break;
495 
496       if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
497             (*j)->getOperand(2).getReg() == CompareReg))
498         continue;
499 
500       MachineInstr *ICMP = *j;
501       Register Dst = ICMP->getOperand(0).getReg();
502       MachineOperand &PredOp = ICMP->getOperand(1);
503       const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
504       assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
505              MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
506       uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
507       MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
508       assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
509       MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
510 
511       // Map switch case Value to target MBB.
512       ValuesToMBBs[Value] = MBB;
513 
514       // Add target MBB as successor to the switch's MBB.
515       Switch->getParent()->addSuccessor(MBB);
516 
517       // The next MI is always G_BR to either the next case or the default.
518       MachineInstr *NextMI = CBr->getNextNode();
519       assert(NextMI->getOpcode() == SPIRV::G_BR &&
520              NextMI->getOperand(0).isMBB());
521       MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
522       // Default MBB does not begin with G_ICMP using spv_switch compare
523       // register.
524       if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
525           (NextMBB->front().getOperand(2).isReg() &&
526            NextMBB->front().getOperand(2).getReg() != CompareReg)) {
527         // Set default MBB and add it as successor to the switch's MBB.
528         DefaultMBB = NextMBB;
529         Switch->getParent()->addSuccessor(DefaultMBB);
530       }
531     }
532 
533     // Modify considered spv_switch operands using collected Values and
534     // MBBs.
535     SmallVector<const ConstantInt *, 3> Values;
536     SmallVector<MachineBasicBlock *, 3> MBBs;
537     for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
538       Register CReg = Switch->getOperand(k).getReg();
539       uint64_t Val = getIConstVal(CReg, &MRI);
540       MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
541       if (!ValuesToMBBs[Val])
542         continue;
543 
544       Values.push_back(ConstInstr->getOperand(1).getCImm());
545       MBBs.push_back(ValuesToMBBs[Val]);
546     }
547 
548     for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
549       Switch->removeOperand(k);
550 
551     Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
552     for (unsigned k = 0; k < Values.size(); k++) {
553       Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
554       Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
555     }
556   }
557 
558   for (MachineInstr *MI : PostUpdateArtifacts) {
559     MachineBasicBlock *ParentMBB = MI->getParent();
560     MI->eraseFromParent();
561     // If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It
562     // can be safely assumed, there are no breaks or phis directing into this
563     // MBB. However, we need to remove this MBB from the CFG graph. MBBs must be
564     // erased top-down.
565     if (ParentMBB->empty()) {
566       while (!ParentMBB->pred_empty())
567         (*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);
568 
569       while (!ParentMBB->succ_empty())
570         ParentMBB->removeSuccessor(ParentMBB->succ_begin());
571 
572       ParentMBB->eraseFromParent();
573     }
574   }
575 }
576 
577 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
578   // Initialize the type registry.
579   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
580   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
581   GR->setCurrentFunc(MF);
582   MachineIRBuilder MIB(MF);
583   addConstantsToTrack(MF, GR);
584   foldConstantsIntoIntrinsics(MF);
585   insertBitcasts(MF, GR, MIB);
586   generateAssignInstrs(MF, GR, MIB);
587   processSwitches(MF, GR, MIB);
588   processInstrsWithTypeFolding(MF, GR, MIB);
589 
590   return true;
591 }
592 
593 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
594                 false)
595 
596 char SPIRVPreLegalizer::ID = 0;
597 
598 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
599   return new SPIRVPreLegalizer();
600 }
601