xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers
10 // and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/CodeGen/GlobalISel/CSEInfo.h"
20 #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 
25 #define DEBUG_TYPE "spirv-prelegalizer"
26 
27 using namespace llvm;
28 
29 namespace {
30 class SPIRVPreLegalizer : public MachineFunctionPass {
31 public:
32   static char ID;
SPIRVPreLegalizer()33   SPIRVPreLegalizer() : MachineFunctionPass(ID) {}
34   bool runOnMachineFunction(MachineFunction &MF) override;
35   void getAnalysisUsage(AnalysisUsage &AU) const override;
36 };
37 } // namespace
38 
getAnalysisUsage(AnalysisUsage & AU) const39 void SPIRVPreLegalizer::getAnalysisUsage(AnalysisUsage &AU) const {
40   AU.addPreserved<GISelValueTrackingAnalysisLegacy>();
41   MachineFunctionPass::getAnalysisUsage(AU);
42 }
43 
44 static void
addConstantsToTrack(MachineFunction & MF,SPIRVGlobalRegistry * GR,const SPIRVSubtarget & STI,DenseMap<MachineInstr *,Type * > & TargetExtConstTypes)45 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
46                     const SPIRVSubtarget &STI,
47                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
48   MachineRegisterInfo &MRI = MF.getRegInfo();
49   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
50   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
51   for (MachineBasicBlock &MBB : MF) {
52     for (MachineInstr &MI : MBB) {
53       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
54         continue;
55       ToErase.push_back(&MI);
56       Register SrcReg = MI.getOperand(2).getReg();
57       auto *Const =
58           cast<Constant>(cast<ConstantAsMetadata>(
59                              MI.getOperand(3).getMetadata()->getOperand(0))
60                              ->getValue());
61       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
62         Register Reg = GR->find(GV, &MF);
63         if (!Reg.isValid()) {
64           GR->add(GV, MRI.getVRegDef(SrcReg));
65           GR->addGlobalObject(GV, &MF, SrcReg);
66         } else
67           RegsAlreadyAddedToDT[&MI] = Reg;
68       } else {
69         Register Reg = GR->find(Const, &MF);
70         if (!Reg.isValid()) {
71           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
72             auto *BuildVec = MRI.getVRegDef(SrcReg);
73             assert(BuildVec &&
74                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
75             GR->add(Const, BuildVec);
76             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
77               // Ensure that OpConstantComposite reuses a constant when it's
78               // already created and available in the same machine function.
79               Constant *ElemConst = ConstVec->getElementAsConstant(i);
80               Register ElemReg = GR->find(ElemConst, &MF);
81               if (!ElemReg.isValid())
82                 GR->add(ElemConst,
83                         MRI.getVRegDef(BuildVec->getOperand(1 + i).getReg()));
84               else
85                 BuildVec->getOperand(1 + i).setReg(ElemReg);
86             }
87           }
88           if (Const->getType()->isTargetExtTy()) {
89             // remember association so that we can restore it when assign types
90             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
91             if (SrcMI)
92               GR->add(Const, SrcMI);
93             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
94                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
95               TargetExtConstTypes[SrcMI] = Const->getType();
96             if (Const->isNullValue()) {
97               MachineBasicBlock &DepMBB = MF.front();
98               MachineIRBuilder MIB(DepMBB, DepMBB.getFirstNonPHI());
99               SPIRVType *ExtType = GR->getOrCreateSPIRVType(
100                   Const->getType(), MIB, SPIRV::AccessQualifier::ReadWrite,
101                   true);
102               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
103               SrcMI->addOperand(MachineOperand::CreateReg(
104                   GR->getSPIRVTypeID(ExtType), false));
105             }
106           }
107         } else {
108           RegsAlreadyAddedToDT[&MI] = Reg;
109           // This MI is unused and will be removed. If the MI uses
110           // const_composite, it will be unused and should be removed too.
111           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
112           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
113           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
114             ToEraseComposites.push_back(SrcMI);
115         }
116       }
117     }
118   }
119   for (MachineInstr *MI : ToErase) {
120     Register Reg = MI->getOperand(2).getReg();
121     auto It = RegsAlreadyAddedToDT.find(MI);
122     if (It != RegsAlreadyAddedToDT.end())
123       Reg = It->second;
124     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
125     if (!MRI.getRegClassOrNull(Reg) && RC)
126       MRI.setRegClass(Reg, RC);
127     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
128     GR->invalidateMachineInstr(MI);
129     MI->eraseFromParent();
130   }
131   for (MachineInstr *MI : ToEraseComposites) {
132     GR->invalidateMachineInstr(MI);
133     MI->eraseFromParent();
134   }
135 }
136 
foldConstantsIntoIntrinsics(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)137 static void foldConstantsIntoIntrinsics(MachineFunction &MF,
138                                         SPIRVGlobalRegistry *GR,
139                                         MachineIRBuilder MIB) {
140   SmallVector<MachineInstr *, 64> ToErase;
141   for (MachineBasicBlock &MBB : MF) {
142     for (MachineInstr &MI : MBB) {
143       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
144         continue;
145       const MDNode *MD = MI.getOperand(2).getMetadata();
146       StringRef ValueName = cast<MDString>(MD->getOperand(0))->getString();
147       if (ValueName.size() > 0) {
148         MIB.setInsertPt(*MI.getParent(), MI);
149         buildOpName(MI.getOperand(1).getReg(), ValueName, MIB);
150       }
151       ToErase.push_back(&MI);
152     }
153     for (MachineInstr *MI : ToErase) {
154       GR->invalidateMachineInstr(MI);
155       MI->eraseFromParent();
156     }
157     ToErase.clear();
158   }
159 }
160 
findAssignTypeInstr(Register Reg,MachineRegisterInfo * MRI)161 static MachineInstr *findAssignTypeInstr(Register Reg,
162                                          MachineRegisterInfo *MRI) {
163   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
164                                                IE = MRI->use_instr_end();
165        I != IE; ++I) {
166     MachineInstr *UseMI = &*I;
167     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
168          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
169         UseMI->getOperand(1).getReg() == Reg)
170       return UseMI;
171   }
172   return nullptr;
173 }
174 
buildOpBitcast(SPIRVGlobalRegistry * GR,MachineIRBuilder & MIB,Register ResVReg,Register OpReg)175 static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
176                            Register ResVReg, Register OpReg) {
177   SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
178   SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg);
179   assert(ResType && OpType && "Operand types are expected");
180   if (!GR->isBitcastCompatible(ResType, OpType))
181     report_fatal_error("incompatible result and operand types in a bitcast");
182   MachineRegisterInfo *MRI = MIB.getMRI();
183   if (!MRI->getRegClassOrNull(ResVReg))
184     MRI->setRegClass(ResVReg, GR->getRegClass(ResType));
185   if (ResType == OpType)
186     MIB.buildInstr(TargetOpcode::COPY).addDef(ResVReg).addUse(OpReg);
187   else
188     MIB.buildInstr(SPIRV::OpBitcast)
189         .addDef(ResVReg)
190         .addUse(GR->getSPIRVTypeID(ResType))
191         .addUse(OpReg);
192 }
193 
194 // We do instruction selections early instead of calling MIB.buildBitcast()
195 // generating the general op code G_BITCAST. When MachineVerifier validates
196 // G_BITCAST we see a check of a kind: if Source Type is equal to Destination
197 // Type then report error "bitcast must change the type". This doesn't take into
198 // account the notion of a typed pointer that is important for SPIR-V where a
199 // user may and should use bitcast between pointers with different pointee types
200 // (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast).
201 // It's important for correct lowering in SPIR-V, because interpretation of the
202 // data type is not left to instructions that utilize the pointer, but encoded
203 // by the pointer declaration, and the SPIRV target can and must handle the
204 // declaration and use of pointers that specify the type of data they point to.
205 // It's not feasible to improve validation of G_BITCAST using just information
206 // provided by low level types of source and destination. Therefore we don't
207 // produce G_BITCAST as the general op code with semantics different from
208 // OpBitcast, but rather lower to OpBitcast immediately. As for now, the only
209 // difference would be that CombinerHelper couldn't transform known patterns
210 // around G_BUILD_VECTOR. See discussion
211 // in https://github.com/llvm/llvm-project/pull/110270 for even more context.
selectOpBitcasts(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)212 static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
213                              MachineIRBuilder MIB) {
214   SmallVector<MachineInstr *, 16> ToErase;
215   for (MachineBasicBlock &MBB : MF) {
216     for (MachineInstr &MI : MBB) {
217       if (MI.getOpcode() != TargetOpcode::G_BITCAST)
218         continue;
219       MIB.setInsertPt(*MI.getParent(), MI);
220       buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
221                      MI.getOperand(1).getReg());
222       ToErase.push_back(&MI);
223     }
224   }
225   for (MachineInstr *MI : ToErase) {
226     GR->invalidateMachineInstr(MI);
227     MI->eraseFromParent();
228   }
229 }
230 
insertBitcasts(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)231 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
232                            MachineIRBuilder MIB) {
233   // Get access to information about available extensions
234   const SPIRVSubtarget *ST =
235       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
236   SmallVector<MachineInstr *, 10> ToErase;
237   for (MachineBasicBlock &MBB : MF) {
238     for (MachineInstr &MI : MBB) {
239       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
240           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
241         continue;
242       assert(MI.getOperand(2).isReg());
243       MIB.setInsertPt(*MI.getParent(), MI);
244       ToErase.push_back(&MI);
245       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
246         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
247         continue;
248       }
249       Register Def = MI.getOperand(0).getReg();
250       Register Source = MI.getOperand(2).getReg();
251       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
252       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
253           ElemTy, MI,
254           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
255 
256       // If the ptrcast would be redundant, replace all uses with the source
257       // register.
258       MachineRegisterInfo *MRI = MIB.getMRI();
259       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
260         // Erase Def's assign type instruction if we are going to replace Def.
261         if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MRI))
262           ToErase.push_back(AssignMI);
263         MRI->replaceRegWith(Def, Source);
264       } else {
265         if (!GR->getSPIRVTypeForVReg(Def, &MF))
266           GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
267         MIB.buildBitcast(Def, Source);
268       }
269     }
270   }
271   for (MachineInstr *MI : ToErase) {
272     GR->invalidateMachineInstr(MI);
273     MI->eraseFromParent();
274   }
275 }
276 
277 // Translating GV, IRTranslator sometimes generates following IR:
278 //   %1 = G_GLOBAL_VALUE
279 //   %2 = COPY %1
280 //   %3 = G_ADDRSPACE_CAST %2
281 //
282 // or
283 //
284 //  %1 = G_ZEXT %2
285 //  G_MEMCPY ... %2 ...
286 //
287 // New registers have no SPIRVType and no register class info.
288 //
289 // Set SPIRVType for GV, propagate it from GV to other instructions,
290 // also set register classes.
propagateSPIRVType(MachineInstr * MI,SPIRVGlobalRegistry * GR,MachineRegisterInfo & MRI,MachineIRBuilder & MIB)291 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
292                                      MachineRegisterInfo &MRI,
293                                      MachineIRBuilder &MIB) {
294   SPIRVType *SpvType = nullptr;
295   assert(MI && "Machine instr is expected");
296   if (MI->getOperand(0).isReg()) {
297     Register Reg = MI->getOperand(0).getReg();
298     SpvType = GR->getSPIRVTypeForVReg(Reg);
299     if (!SpvType) {
300       switch (MI->getOpcode()) {
301       case TargetOpcode::G_FCONSTANT:
302       case TargetOpcode::G_CONSTANT: {
303         MIB.setInsertPt(*MI->getParent(), MI);
304         Type *Ty = MI->getOperand(1).getCImm()->getType();
305         SpvType = GR->getOrCreateSPIRVType(
306             Ty, MIB, SPIRV::AccessQualifier::ReadWrite, true);
307         break;
308       }
309       case TargetOpcode::G_GLOBAL_VALUE: {
310         MIB.setInsertPt(*MI->getParent(), MI);
311         const GlobalValue *Global = MI->getOperand(1).getGlobal();
312         Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
313         auto *Ty = TypedPointerType::get(ElementTy,
314                                          Global->getType()->getAddressSpace());
315         SpvType = GR->getOrCreateSPIRVType(
316             Ty, MIB, SPIRV::AccessQualifier::ReadWrite, true);
317         break;
318       }
319       case TargetOpcode::G_ANYEXT:
320       case TargetOpcode::G_SEXT:
321       case TargetOpcode::G_ZEXT: {
322         if (MI->getOperand(1).isReg()) {
323           if (MachineInstr *DefInstr =
324                   MRI.getVRegDef(MI->getOperand(1).getReg())) {
325             if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
326               unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
327               unsigned ExpectedBW =
328                   std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
329               unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
330               SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
331               if (NumElements > 1)
332                 SpvType = GR->getOrCreateSPIRVVectorType(SpvType, NumElements,
333                                                          MIB, true);
334             }
335           }
336         }
337         break;
338       }
339       case TargetOpcode::G_PTRTOINT:
340         SpvType = GR->getOrCreateSPIRVIntegerType(
341             MRI.getType(Reg).getScalarSizeInBits(), MIB);
342         break;
343       case TargetOpcode::G_TRUNC:
344       case TargetOpcode::G_ADDRSPACE_CAST:
345       case TargetOpcode::G_PTR_ADD:
346       case TargetOpcode::COPY: {
347         MachineOperand &Op = MI->getOperand(1);
348         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
349         if (Def)
350           SpvType = propagateSPIRVType(Def, GR, MRI, MIB);
351         break;
352       }
353       default:
354         break;
355       }
356       if (SpvType) {
357         // check if the address space needs correction
358         LLT RegType = MRI.getType(Reg);
359         if (SpvType->getOpcode() == SPIRV::OpTypePointer &&
360             RegType.isPointer() &&
361             storageClassToAddressSpace(GR->getPointerStorageClass(SpvType)) !=
362                 RegType.getAddressSpace()) {
363           const SPIRVSubtarget &ST =
364               MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
365           auto TSC = addressSpaceToStorageClass(RegType.getAddressSpace(), ST);
366           SpvType = GR->changePointerStorageClass(SpvType, TSC, *MI);
367         }
368         GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
369       }
370       if (!MRI.getRegClassOrNull(Reg))
371         MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
372                                      : &SPIRV::iIDRegClass);
373     }
374   }
375   return SpvType;
376 }
377 
// Maps a scalar bit width onto a width supported by the current approach:
// 1 stays 1 (no widening needed for 1-bit values). Any other width is rounded
// up to the next power of two, with a minimum of 8 and a cap of 64.
// Examples: 3 -> 8, 12 -> 16, 33 -> 64, 128 -> 64.
static unsigned widenBitWidthToNextPow2(unsigned BitWidth) {
  if (BitWidth == 1)
    return 1;
  unsigned Width = 8;
  while (Width < 64 && Width < BitWidth)
    Width <<= 1;
  return Width;
}
386 
widenScalarType(Register Reg,MachineRegisterInfo & MRI)387 static void widenScalarType(Register Reg, MachineRegisterInfo &MRI) {
388   LLT RegType = MRI.getType(Reg);
389   if (!RegType.isScalar())
390     return;
391   unsigned CurrentWidth = RegType.getScalarSizeInBits();
392   unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
393   if (NewWidth != CurrentWidth)
394     MRI.setType(Reg, LLT::scalar(NewWidth));
395 }
396 
widenCImmType(MachineOperand & MOP)397 static void widenCImmType(MachineOperand &MOP) {
398   const ConstantInt *CImmVal = MOP.getCImm();
399   unsigned CurrentWidth = CImmVal->getBitWidth();
400   unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
401   if (NewWidth != CurrentWidth) {
402     // Replace the immediate value with the widened version
403     MOP.setCImm(ConstantInt::get(CImmVal->getType()->getContext(),
404                                  CImmVal->getValue().zextOrTrunc(NewWidth)));
405   }
406 }
407 
setInsertPtAfterDef(MachineIRBuilder & MIB,MachineInstr * Def)408 static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
409   MachineBasicBlock &MBB = *Def->getParent();
410   MachineBasicBlock::iterator DefIt =
411       Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
412   // Skip all the PHI and debug instructions.
413   while (DefIt != MBB.end() &&
414          (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
415     DefIt = std::next(DefIt);
416   MIB.setInsertPt(MBB, DefIt);
417 }
418 
419 namespace llvm {
insertAssignInstr(Register Reg,Type * Ty,SPIRVType * SpvType,SPIRVGlobalRegistry * GR,MachineIRBuilder & MIB,MachineRegisterInfo & MRI)420 void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
421                        SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
422                        MachineRegisterInfo &MRI) {
423   assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
424   MachineInstr *Def = MRI.getVRegDef(Reg);
425   setInsertPtAfterDef(MIB, Def);
426   if (!SpvType)
427     SpvType = GR->getOrCreateSPIRVType(Ty, MIB,
428                                        SPIRV::AccessQualifier::ReadWrite, true);
429 
430   if (!isTypeFoldingSupported(Def->getOpcode())) {
431     // No need to generate SPIRV::ASSIGN_TYPE pseudo-instruction
432     if (!MRI.getRegClassOrNull(Reg))
433       MRI.setRegClass(Reg, GR->getRegClass(SpvType));
434     if (!MRI.getType(Reg).isValid())
435       MRI.setType(Reg, GR->getRegType(SpvType));
436     GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
437     return;
438   }
439 
440   // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
441   // present after each auto-folded instruction to take a type reference from.
442   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
443   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
444     MRI.setRegClass(NewReg, RC);
445   } else {
446     auto RegClass = GR->getRegClass(SpvType);
447     MRI.setRegClass(NewReg, RegClass);
448     MRI.setRegClass(Reg, RegClass);
449   }
450   GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
451   // This is to make it convenient for Legalizer to get the SPIRVType
452   // when processing the actual MI (i.e. not pseudo one).
453   GR->assignSPIRVTypeToVReg(SpvType, NewReg, MIB.getMF());
454   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
455   // the flags after instruction selection.
456   const uint32_t Flags = Def->getFlags();
457   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
458       .addDef(Reg)
459       .addUse(NewReg)
460       .addUse(GR->getSPIRVTypeID(SpvType))
461       .setMIFlags(Flags);
462   for (unsigned I = 0, E = Def->getNumDefs(); I != E; ++I) {
463     MachineOperand &MO = Def->getOperand(I);
464     if (MO.getReg() == Reg) {
465       MO.setReg(NewReg);
466       break;
467     }
468   }
469 }
470 
processInstr(MachineInstr & MI,MachineIRBuilder & MIB,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR,SPIRVType * KnownResType)471 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
472                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR,
473                   SPIRVType *KnownResType) {
474   MIB.setInsertPt(*MI.getParent(), MI.getIterator());
475   for (auto &Op : MI.operands()) {
476     if (!Op.isReg() || Op.isDef())
477       continue;
478     Register OpReg = Op.getReg();
479     SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OpReg);
480     if (!SpvType && KnownResType) {
481       SpvType = KnownResType;
482       GR->assignSPIRVTypeToVReg(KnownResType, OpReg, *MI.getMF());
483     }
484     assert(SpvType);
485     if (!MRI.getRegClassOrNull(OpReg))
486       MRI.setRegClass(OpReg, GR->getRegClass(SpvType));
487     if (!MRI.getType(OpReg).isValid())
488       MRI.setType(OpReg, GR->getRegType(SpvType));
489   }
490 }
491 } // namespace llvm
492 
493 static void
generateAssignInstrs(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB,DenseMap<MachineInstr *,Type * > & TargetExtConstTypes)494 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
495                      MachineIRBuilder MIB,
496                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
497   // Get access to information about available extensions
498   const SPIRVSubtarget *ST =
499       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
500 
501   MachineRegisterInfo &MRI = MF.getRegInfo();
502   SmallVector<MachineInstr *, 10> ToErase;
503   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
504 
505   bool IsExtendedInts =
506       ST->canUseExtension(
507           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
508       ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
509       ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
510 
511   for (MachineBasicBlock *MBB : post_order(&MF)) {
512     if (MBB->empty())
513       continue;
514 
515     bool ReachedBegin = false;
516     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
517          !ReachedBegin;) {
518       MachineInstr &MI = *MII;
519       unsigned MIOp = MI.getOpcode();
520 
521       if (!IsExtendedInts) {
522         // validate bit width of scalar registers and constant immediates
523         for (auto &MOP : MI.operands()) {
524           if (MOP.isReg())
525             widenScalarType(MOP.getReg(), MRI);
526           else if (MOP.isCImm())
527             widenCImmType(MOP);
528         }
529       }
530 
531       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
532         Register Reg = MI.getOperand(1).getReg();
533         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
534         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
535         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
536             ElementTy, MI,
537             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
538         MachineInstr *Def = MRI.getVRegDef(Reg);
539         assert(Def && "Expecting an instruction that defines the register");
540         // G_GLOBAL_VALUE already has type info.
541         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
542             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
543           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
544                             MF.getRegInfo());
545         ToErase.push_back(&MI);
546       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
547         Register Reg = MI.getOperand(1).getReg();
548         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
549         MachineInstr *Def = MRI.getVRegDef(Reg);
550         assert(Def && "Expecting an instruction that defines the register");
551         // G_GLOBAL_VALUE already has type info.
552         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
553             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
554           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
555         ToErase.push_back(&MI);
556       } else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
557         MachineInstr *MdMI = MI.getPrevNode();
558         if (MdMI && isSpvIntrinsic(*MdMI, Intrinsic::spv_value_md)) {
559           // It's an internal service info from before IRTranslator passes.
560           MachineInstr *Def = getVRegDef(MRI, MI.getOperand(0).getReg());
561           for (unsigned I = 1, E = MI.getNumOperands(); I != E && Def; ++I)
562             if (getVRegDef(MRI, MI.getOperand(I).getReg()) != Def)
563               Def = nullptr;
564           if (Def) {
565             const MDNode *MD = MdMI->getOperand(1).getMetadata();
566             StringRef ValueName =
567                 cast<MDString>(MD->getOperand(1))->getString();
568             const MDNode *TypeMD = cast<MDNode>(MD->getOperand(0));
569             Type *ValueTy = getMDOperandAsType(TypeMD, 0);
570             GR->addValueAttrs(Def, std::make_pair(ValueTy, ValueName.str()));
571           }
572           ToErase.push_back(MdMI);
573         }
574         ToErase.push_back(&MI);
575       } else if (MIOp == TargetOpcode::G_CONSTANT ||
576                  MIOp == TargetOpcode::G_FCONSTANT ||
577                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
578         // %rc = G_CONSTANT ty Val
579         // ===>
580         // %cty = OpType* ty
581         // %rctmp = G_CONSTANT ty Val
582         // %rc = ASSIGN_TYPE %rctmp, %cty
583         Register Reg = MI.getOperand(0).getReg();
584         bool NeedAssignType = true;
585         if (MRI.hasOneUse(Reg)) {
586           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
587           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
588               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
589             continue;
590           if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
591             NeedAssignType = false;
592         }
593         Type *Ty = nullptr;
594         if (MIOp == TargetOpcode::G_CONSTANT) {
595           auto TargetExtIt = TargetExtConstTypes.find(&MI);
596           Ty = TargetExtIt == TargetExtConstTypes.end()
597                    ? MI.getOperand(1).getCImm()->getType()
598                    : TargetExtIt->second;
599           const ConstantInt *OpCI = MI.getOperand(1).getCImm();
600           // TODO: we may wish to analyze here if OpCI is zero and LLT RegType =
601           // MRI.getType(Reg); RegType.isPointer() is true, so that we observe
602           // at this point not i64/i32 constant but null pointer in the
603           // corresponding address space of RegType.getAddressSpace(). This may
604           // help to successfully validate the case when a OpConstantComposite's
605           // constituent has type that does not match Result Type of
606           // OpConstantComposite (see, for example,
607           // pointers/PtrCast-null-in-OpSpecConstantOp.ll).
608           Register PrimaryReg = GR->find(OpCI, &MF);
609           if (!PrimaryReg.isValid()) {
610             GR->add(OpCI, &MI);
611           } else if (PrimaryReg != Reg &&
612                      MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
613             auto *RCReg = MRI.getRegClassOrNull(Reg);
614             auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
615             if (!RCReg || RCPrimary == RCReg) {
616               RegsAlreadyAddedToDT[&MI] = PrimaryReg;
617               ToErase.push_back(&MI);
618               NeedAssignType = false;
619             }
620           }
621         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
622           Ty = MI.getOperand(1).getFPImm()->getType();
623         } else {
624           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
625           Type *ElemTy = nullptr;
626           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
627           assert(ElemMI);
628 
629           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
630             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
631           } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
632             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
633           } else {
634             if (const SPIRVType *ElemSpvType =
635                     GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg(), &MF))
636               ElemTy = const_cast<Type *>(GR->getTypeForSPIRVType(ElemSpvType));
637             if (!ElemTy) {
638               // There may be a case when we already know Reg's type.
639               MachineInstr *NextMI = MI.getNextNode();
640               if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
641                   NextMI->getOperand(1).getReg() != Reg)
642                 llvm_unreachable("Unexpected opcode");
643             }
644           }
645           if (ElemTy)
646             Ty = VectorType::get(
647                 ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
648                 false);
649           else
650             NeedAssignType = false;
651         }
652         if (NeedAssignType)
653           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
654       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
655         propagateSPIRVType(&MI, GR, MRI, MIB);
656       }
657 
658       if (MII == Begin)
659         ReachedBegin = true;
660       else
661         --MII;
662     }
663   }
664   for (MachineInstr *MI : ToErase) {
665     auto It = RegsAlreadyAddedToDT.find(MI);
666     if (It != RegsAlreadyAddedToDT.end())
667       MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
668     GR->invalidateMachineInstr(MI);
669     MI->eraseFromParent();
670   }
671 
672   // Address the case when IRTranslator introduces instructions with new
673   // registers without SPIRVType associated.
674   for (MachineBasicBlock &MBB : MF) {
675     for (MachineInstr &MI : MBB) {
676       switch (MI.getOpcode()) {
677       case TargetOpcode::G_TRUNC:
678       case TargetOpcode::G_ANYEXT:
679       case TargetOpcode::G_SEXT:
680       case TargetOpcode::G_ZEXT:
681       case TargetOpcode::G_PTRTOINT:
682       case TargetOpcode::COPY:
683       case TargetOpcode::G_ADDRSPACE_CAST:
684         propagateSPIRVType(&MI, GR, MRI, MIB);
685         break;
686       }
687     }
688   }
689 }
690 
processInstrsWithTypeFolding(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)691 static void processInstrsWithTypeFolding(MachineFunction &MF,
692                                          SPIRVGlobalRegistry *GR,
693                                          MachineIRBuilder MIB) {
694   MachineRegisterInfo &MRI = MF.getRegInfo();
695   for (MachineBasicBlock &MBB : MF)
696     for (MachineInstr &MI : MBB)
697       if (isTypeFoldingSupported(MI.getOpcode()))
698         processInstr(MI, MIB, MRI, GR, nullptr);
699 }
700 
701 static Register
collectInlineAsmInstrOperands(MachineInstr * MI,SmallVector<unsigned,4> * Ops=nullptr)702 collectInlineAsmInstrOperands(MachineInstr *MI,
703                               SmallVector<unsigned, 4> *Ops = nullptr) {
704   Register DefReg;
705   unsigned StartOp = InlineAsm::MIOp_FirstOperand,
706            AsmDescOp = InlineAsm::MIOp_FirstOperand;
707   for (unsigned Idx = StartOp, MISz = MI->getNumOperands(); Idx != MISz;
708        ++Idx) {
709     const MachineOperand &MO = MI->getOperand(Idx);
710     if (MO.isMetadata())
711       continue;
712     if (Idx == AsmDescOp && MO.isImm()) {
713       // compute the index of the next operand descriptor
714       const InlineAsm::Flag F(MO.getImm());
715       AsmDescOp += 1 + F.getNumOperandRegisters();
716       continue;
717     }
718     if (MO.isReg() && MO.isDef()) {
719       if (!Ops)
720         return MO.getReg();
721       else
722         DefReg = MO.getReg();
723     } else if (Ops) {
724       Ops->push_back(Idx);
725     }
726   }
727   return DefReg;
728 }
729 
730 static void
insertInlineAsmProcess(MachineFunction & MF,SPIRVGlobalRegistry * GR,const SPIRVSubtarget & ST,MachineIRBuilder MIRBuilder,const SmallVector<MachineInstr * > & ToProcess)731 insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
732                        const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
733                        const SmallVector<MachineInstr *> &ToProcess) {
734   MachineRegisterInfo &MRI = MF.getRegInfo();
735   Register AsmTargetReg;
736   for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
737     MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
738     assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
739     MIRBuilder.setInsertPt(*I2->getParent(), *I2);
740 
741     if (!AsmTargetReg.isValid()) {
742       // define vendor specific assembly target or dialect
743       AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
744       MRI.setRegClass(AsmTargetReg, &SPIRV::iIDRegClass);
745       auto AsmTargetMIB =
746           MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
747       addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
748       GR->add(AsmTargetMIB.getInstr(), AsmTargetMIB);
749     }
750 
751     // create types
752     const MDNode *IAMD = I1->getOperand(1).getMetadata();
753     FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
754     SmallVector<SPIRVType *, 4> ArgTypes;
755     for (const auto &ArgTy : FTy->params())
756       ArgTypes.push_back(GR->getOrCreateSPIRVType(
757           ArgTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true));
758     SPIRVType *RetType =
759         GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder,
760                                  SPIRV::AccessQualifier::ReadWrite, true);
761     SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
762         FTy, RetType, ArgTypes, MIRBuilder);
763 
764     // define vendor specific assembly instructions string
765     Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
766     MRI.setRegClass(AsmReg, &SPIRV::iIDRegClass);
767     auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
768                       .addDef(AsmReg)
769                       .addUse(GR->getSPIRVTypeID(RetType))
770                       .addUse(GR->getSPIRVTypeID(FuncType))
771                       .addUse(AsmTargetReg);
772     // inline asm string:
773     addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
774                  AsmMIB);
775     // inline asm constraint string:
776     addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
777                      ->getString(),
778                  AsmMIB);
779     GR->add(AsmMIB.getInstr(), AsmMIB);
780 
781     // calls the inline assembly instruction
782     unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
783     if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
784       MIRBuilder.buildInstr(SPIRV::OpDecorate)
785           .addUse(AsmReg)
786           .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));
787 
788     Register DefReg = collectInlineAsmInstrOperands(I2);
789     if (!DefReg.isValid()) {
790       DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
791       MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
792       SPIRVType *VoidType = GR->getOrCreateSPIRVType(
793           Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder,
794           SPIRV::AccessQualifier::ReadWrite, true);
795       GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
796     }
797 
798     auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
799                        .addDef(DefReg)
800                        .addUse(GR->getSPIRVTypeID(RetType))
801                        .addUse(AsmReg);
802     for (unsigned IntrIdx = 3; IntrIdx < I1->getNumOperands(); ++IntrIdx)
803       AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
804   }
805   for (MachineInstr *MI : ToProcess) {
806     GR->invalidateMachineInstr(MI);
807     MI->eraseFromParent();
808   }
809 }
810 
insertInlineAsm(MachineFunction & MF,SPIRVGlobalRegistry * GR,const SPIRVSubtarget & ST,MachineIRBuilder MIRBuilder)811 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
812                             const SPIRVSubtarget &ST,
813                             MachineIRBuilder MIRBuilder) {
814   SmallVector<MachineInstr *> ToProcess;
815   for (MachineBasicBlock &MBB : MF) {
816     for (MachineInstr &MI : MBB) {
817       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
818           MI.getOpcode() == TargetOpcode::INLINEASM)
819         ToProcess.push_back(&MI);
820     }
821   }
822   if (ToProcess.size() == 0)
823     return;
824 
825   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
826     report_fatal_error("Inline assembly instructions require the "
827                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
828                        false);
829 
830   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
831 }
832 
// Return the raw 32-bit pattern of F, for use as a single SPIR-V literal word.
// std::memcpy is used for the bit copy. Reading the inactive member of a union
// (the previous approach) is undefined behaviour in standard C++.
static uint32_t convertFloatToSPIRVWord(float F) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "float must be exactly one 32-bit SPIR-V word");
  uint32_t Word;
  std::memcpy(&Word, &F, sizeof(Word));
  return Word;
}
841 
insertSpirvDecorations(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)842 static void insertSpirvDecorations(MachineFunction &MF, SPIRVGlobalRegistry *GR,
843                                    MachineIRBuilder MIB) {
844   SmallVector<MachineInstr *, 10> ToErase;
845   for (MachineBasicBlock &MBB : MF) {
846     for (MachineInstr &MI : MBB) {
847       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration) &&
848           !isSpvIntrinsic(MI, Intrinsic::spv_assign_aliasing_decoration) &&
849           !isSpvIntrinsic(MI, Intrinsic::spv_assign_fpmaxerror_decoration))
850         continue;
851       MIB.setInsertPt(*MI.getParent(), MI.getNextNode());
852       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration)) {
853         buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
854                                 MI.getOperand(2).getMetadata());
855       } else if (isSpvIntrinsic(MI,
856                                 Intrinsic::spv_assign_fpmaxerror_decoration)) {
857         ConstantFP *OpV = mdconst::dyn_extract<ConstantFP>(
858             MI.getOperand(2).getMetadata()->getOperand(0));
859         uint32_t OpValue =
860             convertFloatToSPIRVWord(OpV->getValueAPF().convertToFloat());
861 
862         buildOpDecorate(MI.getOperand(1).getReg(), MIB,
863                         SPIRV::Decoration::FPMaxErrorDecorationINTEL,
864                         {OpValue});
865       } else {
866         GR->buildMemAliasingOpDecorate(MI.getOperand(1).getReg(), MIB,
867                                        MI.getOperand(2).getImm(),
868                                        MI.getOperand(3).getMetadata());
869       }
870 
871       ToErase.push_back(&MI);
872     }
873   }
874   for (MachineInstr *MI : ToErase) {
875     GR->invalidateMachineInstr(MI);
876     MI->eraseFromParent();
877   }
878 }
879 
880 // LLVM allows the switches to use registers as cases, while SPIR-V required
881 // those to be immediate values. This function replaces such operands with the
882 // equivalent immediate constant.
processSwitchesConstants(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)883 static void processSwitchesConstants(MachineFunction &MF,
884                                      SPIRVGlobalRegistry *GR,
885                                      MachineIRBuilder MIB) {
886   MachineRegisterInfo &MRI = MF.getRegInfo();
887   for (MachineBasicBlock &MBB : MF) {
888     for (MachineInstr &MI : MBB) {
889       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
890         continue;
891 
892       SmallVector<MachineOperand, 8> NewOperands;
893       NewOperands.push_back(MI.getOperand(0)); // Opcode
894       NewOperands.push_back(MI.getOperand(1)); // Condition
895       NewOperands.push_back(MI.getOperand(2)); // Default
896       for (unsigned i = 3; i < MI.getNumOperands(); i += 2) {
897         Register Reg = MI.getOperand(i).getReg();
898         MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
899         NewOperands.push_back(
900             MachineOperand::CreateCImm(ConstInstr->getOperand(1).getCImm()));
901 
902         NewOperands.push_back(MI.getOperand(i + 1));
903       }
904 
905       assert(MI.getNumOperands() == NewOperands.size());
906       while (MI.getNumOperands() > 0)
907         MI.removeOperand(0);
908       for (auto &MO : NewOperands)
909         MI.addOperand(MO);
910     }
911   }
912 }
913 
914 // Some instructions are used during CodeGen but should never be emitted.
915 // Cleaning up those.
cleanupHelperInstructions(MachineFunction & MF,SPIRVGlobalRegistry * GR)916 static void cleanupHelperInstructions(MachineFunction &MF,
917                                       SPIRVGlobalRegistry *GR) {
918   SmallVector<MachineInstr *, 8> ToEraseMI;
919   for (MachineBasicBlock &MBB : MF) {
920     for (MachineInstr &MI : MBB) {
921       if (isSpvIntrinsic(MI, Intrinsic::spv_track_constant) ||
922           MI.getOpcode() == TargetOpcode::G_BRINDIRECT)
923         ToEraseMI.push_back(&MI);
924     }
925   }
926 
927   for (MachineInstr *MI : ToEraseMI) {
928     GR->invalidateMachineInstr(MI);
929     MI->eraseFromParent();
930   }
931 }
932 
933 // Find all usages of G_BLOCK_ADDR in our intrinsics and replace those
934 // operands/registers by the actual MBB it references.
processBlockAddr(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)935 static void processBlockAddr(MachineFunction &MF, SPIRVGlobalRegistry *GR,
936                              MachineIRBuilder MIB) {
937   // Gather the reverse-mapping BB -> MBB.
938   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
939   for (MachineBasicBlock &MBB : MF)
940     BB2MBB[MBB.getBasicBlock()] = &MBB;
941 
942   // Gather instructions requiring patching. For now, only those can use
943   // G_BLOCK_ADDR.
944   SmallVector<MachineInstr *, 8> InstructionsToPatch;
945   for (MachineBasicBlock &MBB : MF) {
946     for (MachineInstr &MI : MBB) {
947       if (isSpvIntrinsic(MI, Intrinsic::spv_switch) ||
948           isSpvIntrinsic(MI, Intrinsic::spv_loop_merge) ||
949           isSpvIntrinsic(MI, Intrinsic::spv_selection_merge))
950         InstructionsToPatch.push_back(&MI);
951     }
952   }
953 
954   // For each instruction to fix, we replace all the G_BLOCK_ADDR operands by
955   // the actual MBB it references. Once those references have been updated, we
956   // can cleanup remaining G_BLOCK_ADDR references.
957   SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
958   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
959   MachineRegisterInfo &MRI = MF.getRegInfo();
960   for (MachineInstr *MI : InstructionsToPatch) {
961     SmallVector<MachineOperand, 8> NewOps;
962     for (unsigned i = 0; i < MI->getNumOperands(); ++i) {
963       // The operand is not a register, keep as-is.
964       if (!MI->getOperand(i).isReg()) {
965         NewOps.push_back(MI->getOperand(i));
966         continue;
967       }
968 
969       Register Reg = MI->getOperand(i).getReg();
970       MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
971       // The register is not the result of G_BLOCK_ADDR, keep as-is.
972       if (!BuildMBB || BuildMBB->getOpcode() != TargetOpcode::G_BLOCK_ADDR) {
973         NewOps.push_back(MI->getOperand(i));
974         continue;
975       }
976 
977       assert(BuildMBB && BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
978              BuildMBB->getOperand(1).isBlockAddress() &&
979              BuildMBB->getOperand(1).getBlockAddress());
980       BasicBlock *BB =
981           BuildMBB->getOperand(1).getBlockAddress()->getBasicBlock();
982       auto It = BB2MBB.find(BB);
983       if (It == BB2MBB.end())
984         report_fatal_error("cannot find a machine basic block by a basic block "
985                            "in a switch statement");
986       MachineBasicBlock *ReferencedBlock = It->second;
987       NewOps.push_back(MachineOperand::CreateMBB(ReferencedBlock));
988 
989       ClearAddressTaken.insert(ReferencedBlock);
990       ToEraseMI.insert(BuildMBB);
991     }
992 
993     // Replace the operands.
994     assert(MI->getNumOperands() == NewOps.size());
995     while (MI->getNumOperands() > 0)
996       MI->removeOperand(0);
997     for (auto &MO : NewOps)
998       MI->addOperand(MO);
999 
1000     if (MachineInstr *Next = MI->getNextNode()) {
1001       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
1002         ToEraseMI.insert(Next);
1003         Next = MI->getNextNode();
1004       }
1005       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
1006         ToEraseMI.insert(Next);
1007     }
1008   }
1009 
1010   // BlockAddress operands were used to keep information between passes,
1011   // let's undo the "address taken" status to reflect that Succ doesn't
1012   // actually correspond to an IR-level basic block.
1013   for (MachineBasicBlock *Succ : ClearAddressTaken)
1014     Succ->setAddressTakenIRBlock(nullptr);
1015 
1016   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
1017   // this leaves their BasicBlock counterparts in a "address taken" status. This
1018   // would make AsmPrinter to generate a series of unneeded labels of a "Address
1019   // of block that was removed by CodeGen" kind. Let's first ensure that we
1020   // don't have a dangling BlockAddress constants by zapping the BlockAddress
1021   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
1022   Constant *Replacement =
1023       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
1024   for (MachineInstr *BlockAddrI : ToEraseMI) {
1025     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
1026       BlockAddress *BA = const_cast<BlockAddress *>(
1027           BlockAddrI->getOperand(1).getBlockAddress());
1028       BA->replaceAllUsesWith(
1029           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
1030       BA->destroyConstant();
1031     }
1032     GR->invalidateMachineInstr(BlockAddrI);
1033     BlockAddrI->eraseFromParent();
1034   }
1035 }
1036 
isImplicitFallthrough(MachineBasicBlock & MBB)1037 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
1038   if (MBB.empty())
1039     return true;
1040 
1041   // Branching SPIR-V intrinsics are not detected by this generic method.
1042   // Thus, we can only trust negative result.
1043   if (!MBB.canFallThrough())
1044     return false;
1045 
1046   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
1047   // prevent an implicit fallthrough.
1048   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
1049        It != E; ++It) {
1050     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
1051       return false;
1052   }
1053   return true;
1054 }
1055 
removeImplicitFallthroughs(MachineFunction & MF,MachineIRBuilder MIB)1056 static void removeImplicitFallthroughs(MachineFunction &MF,
1057                                        MachineIRBuilder MIB) {
1058   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
1059   // In such cases, they will simply fallthrough their immediate successor.
1060   for (MachineBasicBlock &MBB : MF) {
1061     if (!isImplicitFallthrough(MBB))
1062       continue;
1063 
1064     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
1065            1);
1066     MIB.setInsertPt(MBB, MBB.end());
1067     MIB.buildBr(**MBB.successors().begin());
1068   }
1069 }
1070 
runOnMachineFunction(MachineFunction & MF)1071 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
1072   // Initialize the type registry.
1073   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
1074   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1075   GR->setCurrentFunc(MF);
1076   MachineIRBuilder MIB(MF);
1077   // a registry of target extension constants
1078   DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
1079   // to keep record of tracked constants
1080   addConstantsToTrack(MF, GR, ST, TargetExtConstTypes);
1081   foldConstantsIntoIntrinsics(MF, GR, MIB);
1082   insertBitcasts(MF, GR, MIB);
1083   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
1084 
1085   processSwitchesConstants(MF, GR, MIB);
1086   processBlockAddr(MF, GR, MIB);
1087   cleanupHelperInstructions(MF, GR);
1088 
1089   processInstrsWithTypeFolding(MF, GR, MIB);
1090   removeImplicitFallthroughs(MF, MIB);
1091   insertSpirvDecorations(MF, GR, MIB);
1092   insertInlineAsm(MF, GR, ST, MIB);
1093   selectOpBitcasts(MF, GR, MIB);
1094 
1095   return true;
1096 }
1097 
1098 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
1099                 false)
1100 
1101 char SPIRVPreLegalizer::ID = 0;
1102 
createSPIRVPreLegalizerPass()1103 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
1104   return new SPIRVPreLegalizer();
1105 }
1106