xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision c66a499e037efd268a744e487e7d0c45a4944a9b)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers
10 // and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
42   MachineRegisterInfo &MRI = MF.getRegInfo();
43   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
44   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
45   for (MachineBasicBlock &MBB : MF) {
46     for (MachineInstr &MI : MBB) {
47       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
48         continue;
49       ToErase.push_back(&MI);
50       auto *Const =
51           cast<Constant>(cast<ConstantAsMetadata>(
52                              MI.getOperand(3).getMetadata()->getOperand(0))
53                              ->getValue());
54       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
55         Register Reg = GR->find(GV, &MF);
56         if (!Reg.isValid())
57           GR->add(GV, &MF, MI.getOperand(2).getReg());
58         else
59           RegsAlreadyAddedToDT[&MI] = Reg;
60       } else {
61         Register Reg = GR->find(Const, &MF);
62         if (!Reg.isValid()) {
63           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
64             auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
65             assert(BuildVec &&
66                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
67             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
68               GR->add(ConstVec->getElementAsConstant(i), &MF,
69                       BuildVec->getOperand(1 + i).getReg());
70           }
71           GR->add(Const, &MF, MI.getOperand(2).getReg());
72         } else {
73           RegsAlreadyAddedToDT[&MI] = Reg;
74           // This MI is unused and will be removed. If the MI uses
75           // const_composite, it will be unused and should be removed too.
76           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
77           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
78           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
79             ToEraseComposites.push_back(SrcMI);
80         }
81       }
82     }
83   }
84   for (MachineInstr *MI : ToErase) {
85     Register Reg = MI->getOperand(2).getReg();
86     if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
87       Reg = RegsAlreadyAddedToDT[MI];
88     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
89     MI->eraseFromParent();
90   }
91   for (MachineInstr *MI : ToEraseComposites)
92     MI->eraseFromParent();
93 }
94 
95 static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
96   SmallVector<MachineInstr *, 10> ToErase;
97   MachineRegisterInfo &MRI = MF.getRegInfo();
98   const unsigned AssignNameOperandShift = 2;
99   for (MachineBasicBlock &MBB : MF) {
100     for (MachineInstr &MI : MBB) {
101       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
102         continue;
103       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
104       while (MI.getOperand(NumOp).isReg()) {
105         MachineOperand &MOp = MI.getOperand(NumOp);
106         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
107         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
108         MI.removeOperand(NumOp);
109         MI.addOperand(MachineOperand::CreateImm(
110             ConstMI->getOperand(1).getCImm()->getZExtValue()));
111         if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
112           ToErase.push_back(ConstMI);
113       }
114     }
115   }
116   for (MachineInstr *MI : ToErase)
117     MI->eraseFromParent();
118 }
119 
120 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
121                            MachineIRBuilder MIB) {
122   SmallVector<MachineInstr *, 10> ToErase;
123   for (MachineBasicBlock &MBB : MF) {
124     for (MachineInstr &MI : MBB) {
125       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
126         continue;
127       assert(MI.getOperand(2).isReg());
128       MIB.setInsertPt(*MI.getParent(), MI);
129       MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
130       ToErase.push_back(&MI);
131     }
132   }
133   for (MachineInstr *MI : ToErase)
134     MI->eraseFromParent();
135 }
136 
137 // Translating GV, IRTranslator sometimes generates following IR:
138 //   %1 = G_GLOBAL_VALUE
139 //   %2 = COPY %1
140 //   %3 = G_ADDRSPACE_CAST %2
141 // New registers have no SPIRVType and no register class info.
142 //
143 // Set SPIRVType for GV, propagate it from GV to other instructions,
144 // also set register classes.
145 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
146                                      MachineRegisterInfo &MRI,
147                                      MachineIRBuilder &MIB) {
148   SPIRVType *SpirvTy = nullptr;
149   assert(MI && "Machine instr is expected");
150   if (MI->getOperand(0).isReg()) {
151     Register Reg = MI->getOperand(0).getReg();
152     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
153     if (!SpirvTy) {
154       switch (MI->getOpcode()) {
155       case TargetOpcode::G_CONSTANT: {
156         MIB.setInsertPt(*MI->getParent(), MI);
157         Type *Ty = MI->getOperand(1).getCImm()->getType();
158         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
159         break;
160       }
161       case TargetOpcode::G_GLOBAL_VALUE: {
162         MIB.setInsertPt(*MI->getParent(), MI);
163         Type *Ty = MI->getOperand(1).getGlobal()->getType();
164         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
165         break;
166       }
167       case TargetOpcode::G_TRUNC:
168       case TargetOpcode::G_ADDRSPACE_CAST:
169       case TargetOpcode::G_PTR_ADD:
170       case TargetOpcode::COPY: {
171         MachineOperand &Op = MI->getOperand(1);
172         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
173         if (Def)
174           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
175         break;
176       }
177       default:
178         break;
179       }
180       if (SpirvTy)
181         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
182       if (!MRI.getRegClassOrNull(Reg))
183         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
184     }
185   }
186   return SpirvTy;
187 }
188 
189 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
190 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
191 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
192 // It's used also in SPIRVBuiltins.cpp.
193 // TODO: maybe move to SPIRVUtils.
194 namespace llvm {
195 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
196                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
197                            MachineRegisterInfo &MRI) {
198   MachineInstr *Def = MRI.getVRegDef(Reg);
199   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
200   MIB.setInsertPt(*Def->getParent(),
201                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
202                                       : Def->getParent()->end()));
203   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
204   if (auto *RC = MRI.getRegClassOrNull(Reg))
205     MRI.setRegClass(NewReg, RC);
206   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
207   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
208   // This is to make it convenient for Legalizer to get the SPIRVType
209   // when processing the actual MI (i.e. not pseudo one).
210   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
211   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
212   // the flags after instruction selection.
213   const uint16_t Flags = Def->getFlags();
214   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
215       .addDef(Reg)
216       .addUse(NewReg)
217       .addUse(GR->getSPIRVTypeID(SpirvTy))
218       .setMIFlags(Flags);
219   Def->getOperand(0).setReg(NewReg);
220   MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
221   return NewReg;
222 }
223 } // namespace llvm
224 
225 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
226                                  MachineIRBuilder MIB) {
227   MachineRegisterInfo &MRI = MF.getRegInfo();
228   SmallVector<MachineInstr *, 10> ToErase;
229 
230   for (MachineBasicBlock *MBB : post_order(&MF)) {
231     if (MBB->empty())
232       continue;
233 
234     bool ReachedBegin = false;
235     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
236          !ReachedBegin;) {
237       MachineInstr &MI = *MII;
238 
239       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
240         Register Reg = MI.getOperand(1).getReg();
241         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
242         MachineInstr *Def = MRI.getVRegDef(Reg);
243         assert(Def && "Expecting an instruction that defines the register");
244         // G_GLOBAL_VALUE already has type info.
245         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
246           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
247         ToErase.push_back(&MI);
248       } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
249                  MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
250                  MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
251         // %rc = G_CONSTANT ty Val
252         // ===>
253         // %cty = OpType* ty
254         // %rctmp = G_CONSTANT ty Val
255         // %rc = ASSIGN_TYPE %rctmp, %cty
256         Register Reg = MI.getOperand(0).getReg();
257         if (MRI.hasOneUse(Reg)) {
258           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
259           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
260               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
261             continue;
262         }
263         Type *Ty = nullptr;
264         if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
265           Ty = MI.getOperand(1).getCImm()->getType();
266         else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
267           Ty = MI.getOperand(1).getFPImm()->getType();
268         else {
269           assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
270           Type *ElemTy = nullptr;
271           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
272           assert(ElemMI);
273 
274           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
275             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
276           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
277             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
278           else
279             llvm_unreachable("Unexpected opcode");
280           unsigned NumElts =
281               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
282           Ty = VectorType::get(ElemTy, NumElts, false);
283         }
284         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
285       } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
286                  MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
287                  MI.getOpcode() == TargetOpcode::COPY ||
288                  MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
289         propagateSPIRVType(&MI, GR, MRI, MIB);
290       }
291 
292       if (MII == Begin)
293         ReachedBegin = true;
294       else
295         --MII;
296     }
297   }
298   for (MachineInstr *MI : ToErase)
299     MI->eraseFromParent();
300 }
301 
302 static std::pair<Register, unsigned>
303 createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
304                const SPIRVGlobalRegistry &GR) {
305   LLT NewT = LLT::scalar(32);
306   SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
307   assert(SpvType && "VReg is expected to have SPIRV type");
308   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
309   bool IsVectorFloat =
310       SpvType->getOpcode() == SPIRV::OpTypeVector &&
311       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
312           SPIRV::OpTypeFloat;
313   IsFloat |= IsVectorFloat;
314   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
315   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
316   if (MRI.getType(ValReg).isPointer()) {
317     NewT = LLT::pointer(0, 32);
318     GetIdOp = SPIRV::GET_pID;
319     DstClass = &SPIRV::pIDRegClass;
320   } else if (MRI.getType(ValReg).isVector()) {
321     NewT = LLT::fixed_vector(2, NewT);
322     GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
323     DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
324   }
325   Register IdReg = MRI.createGenericVirtualRegister(NewT);
326   MRI.setRegClass(IdReg, DstClass);
327   return {IdReg, GetIdOp};
328 }
329 
330 static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
331                          MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
332   unsigned Opc = MI.getOpcode();
333   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
334   MachineInstr &AssignTypeInst =
335       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
336   auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
337   AssignTypeInst.getOperand(1).setReg(NewReg);
338   MI.getOperand(0).setReg(NewReg);
339   MIB.setInsertPt(*MI.getParent(),
340                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
341                                     : MI.getParent()->end()));
342   for (auto &Op : MI.operands()) {
343     if (!Op.isReg() || Op.isDef())
344       continue;
345     auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
346     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
347     Op.setReg(IdOpInfo.first);
348   }
349 }
350 
351 // Defined in SPIRVLegalizerInfo.cpp.
352 extern bool isTypeFoldingSupported(unsigned Opcode);
353 
354 static void processInstrsWithTypeFolding(MachineFunction &MF,
355                                          SPIRVGlobalRegistry *GR,
356                                          MachineIRBuilder MIB) {
357   MachineRegisterInfo &MRI = MF.getRegInfo();
358   for (MachineBasicBlock &MBB : MF) {
359     for (MachineInstr &MI : MBB) {
360       if (isTypeFoldingSupported(MI.getOpcode()))
361         processInstr(MI, MIB, MRI, GR);
362     }
363   }
364   for (MachineBasicBlock &MBB : MF) {
365     for (MachineInstr &MI : MBB) {
366       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
367       // to perform tblgen'erated selection and we can't do that on Legalizer
368       // as it operates on gMIR only.
369       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
370         continue;
371       Register SrcReg = MI.getOperand(1).getReg();
372       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
373       if (!isTypeFoldingSupported(Opcode))
374         continue;
375       Register DstReg = MI.getOperand(0).getReg();
376       if (MRI.getType(DstReg).isVector())
377         MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
378       // Don't need to reset type of register holding constant and used in
379       // G_ADDRSPACE_CAST, since it braaks legalizer.
380       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
381         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
382         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
383           continue;
384       }
385       MRI.setType(DstReg, LLT::scalar(32));
386     }
387   }
388 }
389 
390 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
391                             MachineIRBuilder MIB) {
392   // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
393   // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
394   // + G_BR triples. A switch with two cases may be transformed to this MIR
395   // sequence:
396   //
397   //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
398   //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
399   //   G_BRCOND %Dst0, %bb.2
400   //   G_BR %bb.5
401   // bb.5.entry:
402   //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
403   //   G_BRCOND %Dst1, %bb.3
404   //   G_BR %bb.4
405   // bb.2.sw.bb:
406   //   ...
407   // bb.3.sw.bb1:
408   //   ...
409   // bb.4.sw.epilog:
410   //   ...
411   //
412   // Sometimes (in case of range-compare switches), additional G_SUBs
413   // instructions are inserted before G_ICMPs. Those need to be additionally
414   // processed and require type assignment.
415   //
416   // This function modifies spv_switch call's operands to include destination
417   // MBBs (default and for each constant value).
418   // Note that this function does not remove G_ICMP + G_BRCOND + G_BR sequences,
419   // but they are marked by ModuleAnalysis as skipped and as a result AsmPrinter
420   // does not output them.
421 
422   MachineRegisterInfo &MRI = MF.getRegInfo();
423 
424   // Collect all MIs relevant to switches across all MBBs in MF.
425   std::vector<MachineInstr *> RelevantInsts;
426 
427   // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
428   // spv_switch use these registers.
429   DenseSet<Register> CompareRegs;
430   for (MachineBasicBlock &MBB : MF) {
431     for (MachineInstr &MI : MBB) {
432       // Calls to spv_switch intrinsics representing IR switches.
433       if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
434         assert(MI.getOperand(1).isReg());
435         CompareRegs.insert(MI.getOperand(1).getReg());
436         RelevantInsts.push_back(&MI);
437       }
438 
439       // G_SUBs coming from range-compare switch lowering. G_SUBs are found
440       // after spv_switch but before G_ICMP.
441       if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
442           CompareRegs.contains(MI.getOperand(1).getReg())) {
443         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
444         Register Dst = MI.getOperand(0).getReg();
445         CompareRegs.insert(Dst);
446         SPIRVType *Ty = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
447         insertAssignInstr(Dst, nullptr, Ty, GR, MIB, MRI);
448       }
449 
450       // G_ICMPs relating to switches.
451       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
452           CompareRegs.contains(MI.getOperand(2).getReg())) {
453         Register Dst = MI.getOperand(0).getReg();
454         // Set type info for destination register of switch's ICMP instruction.
455         if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
456           MIB.setInsertPt(*MI.getParent(), MI);
457           Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
458           SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
459           MRI.setRegClass(Dst, &SPIRV::IDRegClass);
460           GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
461         }
462         RelevantInsts.push_back(&MI);
463       }
464     }
465   }
466 
467   // Update each spv_switch with destination MBBs.
468   for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
469     if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
470       continue;
471 
472     // Currently considered spv_switch.
473     MachineInstr *Switch = *i;
474     // Set the first successor as default MBB to support empty switches.
475     MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
476     // Container for mapping values to MMBs.
477     SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
478 
479     // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
480     // spv_switch (i) and break at any spv_switch with the same compare
481     // register (indicating we are back at the same scope).
482     Register CompareReg = Switch->getOperand(1).getReg();
483     for (auto j = i + 1; j != RelevantInsts.end(); j++) {
484       if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
485           (*j)->getOperand(1).getReg() == CompareReg)
486         break;
487 
488       if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
489             (*j)->getOperand(2).getReg() == CompareReg))
490         continue;
491 
492       MachineInstr *ICMP = *j;
493       Register Dst = ICMP->getOperand(0).getReg();
494       MachineOperand &PredOp = ICMP->getOperand(1);
495       const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
496       assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
497              MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
498       uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
499       MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
500       assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
501       MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
502 
503       // Map switch case Value to target MBB.
504       ValuesToMBBs[Value] = MBB;
505 
506       // The next MI is always G_BR to either the next case or the default.
507       MachineInstr *NextMI = CBr->getNextNode();
508       assert(NextMI->getOpcode() == SPIRV::G_BR &&
509              NextMI->getOperand(0).isMBB());
510       MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
511       // Default MBB does not begin with G_ICMP using spv_switch compare
512       // register.
513       if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
514           (NextMBB->front().getOperand(2).isReg() &&
515            NextMBB->front().getOperand(2).getReg() != CompareReg))
516         DefaultMBB = NextMBB;
517     }
518 
519     // Modify considered spv_switch operands using collected Values and
520     // MBBs.
521     SmallVector<const ConstantInt *, 3> Values;
522     SmallVector<MachineBasicBlock *, 3> MBBs;
523     for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
524       Register CReg = Switch->getOperand(k).getReg();
525       uint64_t Val = getIConstVal(CReg, &MRI);
526       MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
527       if (!ValuesToMBBs[Val])
528         continue;
529 
530       Values.push_back(ConstInstr->getOperand(1).getCImm());
531       MBBs.push_back(ValuesToMBBs[Val]);
532     }
533 
534     for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
535       Switch->removeOperand(k);
536 
537     Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
538     for (unsigned k = 0; k < Values.size(); k++) {
539       Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
540       Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
541     }
542   }
543 }
544 
545 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
546   // Initialize the type registry.
547   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
548   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
549   GR->setCurrentFunc(MF);
550   MachineIRBuilder MIB(MF);
551   addConstantsToTrack(MF, GR);
552   foldConstantsIntoIntrinsics(MF);
553   insertBitcasts(MF, GR, MIB);
554   generateAssignInstrs(MF, GR, MIB);
555   processSwitches(MF, GR, MIB);
556   processInstrsWithTypeFolding(MF, GR, MIB);
557 
558   return true;
559 }
560 
561 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
562                 false)
563 
564 char SPIRVPreLegalizer::ID = 0;
565 
566 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
567   return new SPIRVPreLegalizer();
568 }
569