xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86FastPreTileConfig.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- X86FastPreTileConfig.cpp - Fast Tile Register Configure------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to preconfig the shape of physical tile registers
/// It inserts ldtilecfg ahead of each group of tile registers. The algorithm
/// walks each instruction of a basic block in reverse order. All the tile
/// registers that live out of the basic block are spilled and reloaded
/// before their users. It also checks the dependency of the shapes to ensure
/// each shape is defined before its ldtilecfg.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "X86.h"
19 #include "X86InstrBuilder.h"
20 #include "X86MachineFunctionInfo.h"
21 #include "X86RegisterInfo.h"
22 #include "X86Subtarget.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunctionPass.h"
27 #include "llvm/CodeGen/MachineInstr.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/TargetInstrInfo.h"
31 #include "llvm/CodeGen/TargetRegisterInfo.h"
32 #include "llvm/Support/Debug.h"
33 
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "fastpretileconfig"
37 
38 STATISTIC(NumStores, "Number of stores added");
39 STATISTIC(NumLoads, "Number of loads added");
40 
41 namespace {
42 
class X86FastPreTileConfig : public MachineFunctionPass {
  // Cached per-function objects, set up in runOnMachineFunction().
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineRegisterInfo *MRI = nullptr;
  X86MachineFunctionInfo *X86FI = nullptr;
  MachineFrameInfo *MFI = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  // Basic block currently being configured (set by configBasicBlock).
  MachineBasicBlock *MBB = nullptr;
  // Frame index of the stack slot holding the tile configure data;
  // -1 until the first ldtilecfg is inserted.
  int CfgSS = -1;
  // Row/column shape registers and spill-slot address register created
  // while converting one tile PHI node.
  struct PHIInfo {
    Register Row;
    Register Col;
    Register StackAddr;
  };
  // Maps a tile PHI to the row/col/stack-address PHIs created for it, so
  // circular PHI references can be resolved during convertPHI().
  DenseMap<MachineInstr *, struct PHIInfo> VisitedPHIs;

  /// Maps virtual regs to the frame index where these values are spilled.
  IndexedMap<int, VirtReg2IndexFunctor> StackSlotForVirtReg;

  /// Has a bit set for tile virtual register for which it was determined
  /// that it is alive across blocks.
  BitVector MayLiveAcrossBlocks;

  // See the definitions below for the individual responsibilities.
  int getStackSpaceFor(Register VirtReg);
  void InitializeTileConfigStackSpace();
  bool mayLiveOut(Register VirtReg, MachineInstr *CfgMI);
  void spill(MachineBasicBlock::iterator Before, Register VirtReg, bool Kill);
  void reload(MachineBasicBlock::iterator UseMI, Register VirtReg,
              MachineOperand *RowMO, MachineOperand *ColMO);
  void canonicalizePHIs(MachineBasicBlock &MBB);
  void convertPHI(MachineBasicBlock *MBB, MachineInstr &PHI);
  void convertPHIs(MachineBasicBlock &MBB);
  bool configBasicBlock(MachineBasicBlock &MBB);

public:
  X86FastPreTileConfig() : MachineFunctionPass(ID), StackSlotForVirtReg(-1) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Fast Tile Register Preconfigure";
  }

  /// Perform tile register configure.
  bool runOnMachineFunction(MachineFunction &MFunc) override;

  static char ID;
};
91 
92 } // end anonymous namespace
93 
char X86FastPreTileConfig::ID = 0;

// Register the pass so it is known to the pass infrastructure by DEBUG_TYPE.
INITIALIZE_PASS_BEGIN(X86FastPreTileConfig, DEBUG_TYPE,
                      "Fast Tile Register Preconfigure", false, false)
INITIALIZE_PASS_END(X86FastPreTileConfig, DEBUG_TYPE,
                    "Fast Tile Register Preconfigure", false, false)
100 
dominates(MachineBasicBlock & MBB,MachineBasicBlock::const_iterator A,MachineBasicBlock::const_iterator B)101 static bool dominates(MachineBasicBlock &MBB,
102                       MachineBasicBlock::const_iterator A,
103                       MachineBasicBlock::const_iterator B) {
104   auto MBBEnd = MBB.end();
105   if (B == MBBEnd)
106     return true;
107 
108   MachineBasicBlock::const_iterator I = MBB.begin();
109   for (; &*I != A && &*I != B; ++I)
110     ;
111 
112   return &*I == A;
113 }
114 
115 /// This allocates space for the specified virtual register to be held on the
116 /// stack.
getStackSpaceFor(Register VirtReg)117 int X86FastPreTileConfig::getStackSpaceFor(Register VirtReg) {
118   // Find the location Reg would belong...
119   int SS = StackSlotForVirtReg[VirtReg];
120   // Already has space allocated?
121   if (SS != -1)
122     return SS;
123 
124   // Allocate a new stack object for this spill location...
125   const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
126   unsigned Size = TRI->getSpillSize(RC);
127   Align Alignment = TRI->getSpillAlign(RC);
128   int FrameIdx = MFI->CreateSpillStackObject(Size, Alignment);
129 
130   // Assign the slot.
131   StackSlotForVirtReg[VirtReg] = FrameIdx;
132   return FrameIdx;
133 }
134 
/// Returns false if \p VirtReg is known to not live out of the current config.
/// If \p VirtReg lives out of the current MBB, it must live out of the current
/// config. A positive answer is cached in MayLiveAcrossBlocks.
bool X86FastPreTileConfig::mayLiveOut(Register VirtReg, MachineInstr *CfgMI) {
  // Fast path: cached positive answer from an earlier query.
  if (MayLiveAcrossBlocks.test(VirtReg.virtRegIndex()))
    return true;

  for (const MachineInstr &UseInst : MRI->use_nodbg_instructions(VirtReg)) {
    // Any use outside the current MBB means the value lives across blocks.
    if (UseInst.getParent() != MBB) {
      MayLiveAcrossBlocks.set(VirtReg.virtRegIndex());
      return true;
    }

    // The use and def are in the same MBB. If the tile register is
    // reconfigured, it is clobbered and we need to spill and reload
    // tile register.
    if (CfgMI) {
      if (dominates(*MBB, *CfgMI, UseInst)) {
        MayLiveAcrossBlocks.set(VirtReg.virtRegIndex());
        return true;
      }
    }
  }

  return false;
}
161 
InitializeTileConfigStackSpace()162 void X86FastPreTileConfig::InitializeTileConfigStackSpace() {
163   MachineBasicBlock &MBB = MF->front();
164   MachineInstr *MI = &*MBB.getFirstNonPHI();
165   DebugLoc DL;
166   if (ST->hasAVX512()) {
167     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
168     BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
169     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), CfgSS)
170         .addReg(Zmm);
171   } else if (ST->hasAVX2()) {
172     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
173     BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
174     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS)
175         .addReg(Ymm);
176     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS,
177                       32)
178         .addReg(Ymm);
179   } else {
180     assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
181     unsigned StoreOpc = ST->hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
182     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
183     BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
184     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS)
185         .addReg(Xmm);
186     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 16)
187         .addReg(Xmm);
188     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 32)
189         .addReg(Xmm);
190     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 48)
191         .addReg(Xmm);
192   }
193   // Fill in the palette first.
194   addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), CfgSS)
195       .addImm(1);
196 }
197 
/// Insert spill instruction for \p VirtReg before \p Before.
/// \p Kill indicates whether the spilled use is the last use of the register.
/// TODO: Update DBG_VALUEs with \p VirtReg operands with the stack slot.
void X86FastPreTileConfig::spill(MachineBasicBlock::iterator Before,
                                 Register VirtReg, bool Kill) {
  LLVM_DEBUG(dbgs() << "Spilling " << printReg(VirtReg, TRI) << " \n");
  int FI = getStackSpaceFor(VirtReg);
  LLVM_DEBUG(dbgs() << " to stack slot #" << FI << '\n');

  const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
  // Don't need shape information for tile store, because it is adjacent to
  // the tile def instruction.
  TII->storeRegToStackSlot(*MBB, Before, VirtReg, Kill, FI, &RC, TRI,
                           Register());
  ++NumStores;

  // TODO: update DBG_VALUEs
}
215 
/// Insert a reload (PTILELOADDV) of \p OrigReg's spill slot before \p UseMI.
/// If \p UseMI is a COPY, the load is folded into it: the tile load defines
/// the COPY's destination directly and the COPY is erased. Otherwise a new
/// virtual tile register is created and substituted into \p UseMI.
void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,
                                  Register OrigReg, MachineOperand *RowMO,
                                  MachineOperand *ColMO) {
  int FI = getStackSpaceFor(OrigReg);
  const TargetRegisterClass &RC = *MRI->getRegClass(OrigReg);
  Register TileReg;
  // Fold copy to tileload
  // BB1:
  // spill src to s
  //
  // BB2:
  // t = copy src
  // -->
  // t = tileload (s)
  if (UseMI->isCopy())
    TileReg = UseMI->getOperand(0).getReg();
  else
    TileReg = MRI->createVirtualRegister(&RC);
  // Can't use TII->loadRegFromStackSlot(), because we need the shape
  // information for reload.
  // tileloadd (%sp, %idx), %tmm
  unsigned Opc = X86::PTILELOADDV;
  Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
  // The stride between tile rows is a constant 64 bytes.
  // FIXME: MBB is not the parent of UseMI.
  MachineInstr *NewMI = BuildMI(*UseMI->getParent(), UseMI, DebugLoc(),
                                TII->get(X86::MOV64ri), StrideReg)
                            .addImm(64);
  NewMI = addFrameReference(
      BuildMI(*UseMI->getParent(), UseMI, DebugLoc(), TII->get(Opc), TileReg)
          .addReg(RowMO->getReg())
          .addReg(ColMO->getReg()),
      FI);
  // Rewrite operand 5 of the frame reference (presumably the index register
  // of the memory operand — verify against PTILELOADDV's operand layout) to
  // carry the 64-byte stride loaded above.
  MachineOperand &MO = NewMI->getOperand(5);
  MO.setReg(StrideReg);
  MO.setIsKill(true);
  // The shape registers gained a new use at the reload, so they must not
  // stay marked as killed at their previous use.
  RowMO->setIsKill(false);
  ColMO->setIsKill(false);
  // Erase copy instruction after it is folded.
  if (UseMI->isCopy()) {
    UseMI->eraseFromParent();
  } else {
    // Replace the register in the user MI.
    for (auto &MO : UseMI->operands()) {
      if (MO.isReg() && MO.getReg() == OrigReg)
        MO.setReg(TileReg);
    }
  }

  ++NumLoads;
  LLVM_DEBUG(dbgs() << "Reloading " << printReg(OrigReg, TRI) << " into "
                    << printReg(TileReg, TRI) << '\n');
}
269 
getTileDefNum(MachineRegisterInfo * MRI,Register Reg)270 static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) {
271   if (Reg.isVirtual()) {
272     unsigned RegClassID = MRI->getRegClass(Reg)->getID();
273     if (RegClassID == X86::TILERegClassID)
274       return 1;
275     if (RegClassID == X86::TILEPAIRRegClassID)
276       return 2;
277   } else {
278     if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
279       return 1;
280     if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
281       return 2;
282   }
283   return 0;
284 }
285 
isTileRegister(MachineRegisterInfo * MRI,Register VirtReg)286 static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) {
287   return getTileDefNum(MRI, VirtReg) > 0;
288 }
289 
isTileDef(MachineRegisterInfo * MRI,MachineInstr & MI)290 static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
291   // The instruction must have 3 operands: tile def, row, col.
292   if (MI.isDebugInstr() || MI.getNumOperands() < 3 || !MI.isPseudo())
293     return false;
294   MachineOperand &MO = MI.getOperand(0);
295 
296   if (!MO.isReg())
297     return false;
298 
299   return getTileDefNum(MRI, MO.getReg()) > 0;
300 }
301 
getShape(MachineRegisterInfo * MRI,Register TileReg)302 static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) {
303   MachineInstr *MI = MRI->getVRegDef(TileReg);
304   if (isTileDef(MRI, *MI)) {
305     MachineOperand *RowMO = &MI->getOperand(1);
306     MachineOperand *ColMO = &MI->getOperand(2);
307     return ShapeT(RowMO, ColMO, MRI);
308   } else if (MI->isCopy()) {
309     TileReg = MI->getOperand(1).getReg();
310     return getShape(MRI, TileReg);
311   }
312 
313   // The def should not be PHI node, because we walk the MBB in reverse post
314   // order.
315   assert(MI->isPHI() && "Unexpected PHI when get shape.");
316   llvm_unreachable("Unexpected MI when get shape.");
317 }
318 
// Convert a tile PHI into a tile load from PHI-merged spill slots:
//
// BB0:
// spill t0 to s0
// BB1:
// spill t1 to s1
//
// BB2:
// t = phi [t0, bb0] [t1, bb1]
// -->
// row = phi [r0, bb0] [r1, bb1]
// col = phi [c0, bb0] [c1, bb1]
//   s = phi [s0, bb0] [s1, bb1]
//   t = tileload row, col, s
// The new instruction is inserted at the end of the phi node. The order
// of the original phi node is not ensured.
void X86FastPreTileConfig::convertPHI(MachineBasicBlock *MBB,
                                      MachineInstr &PHI) {
  // 1. Create instruction to get stack slot address of each incoming block.
  // 2. Create PHI node for the stack address.
  // 3. Create PHI node for shape. If one of the incoming shape is immediate
  //    use the immediate and delete the PHI node.
  // 4. Create tileload instruction from the stack address.
  Register StackAddrReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
  MachineInstrBuilder AddrPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
                                        TII->get(X86::PHI), StackAddrReg);
  Register RowReg = MRI->createVirtualRegister(&X86::GR16RegClass);
  MachineInstrBuilder RowPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
                                       TII->get(X86::PHI), RowReg);
  Register ColReg = MRI->createVirtualRegister(&X86::GR16RegClass);
  MachineInstrBuilder ColPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
                                       TII->get(X86::PHI), ColReg);
  // Record the mapping of phi node and its row/column information.
  VisitedPHIs[&PHI] = {RowReg, ColReg, StackAddrReg};

  // Visit each (value, block) incoming pair of the original tile PHI.
  for (unsigned I = 1, E = PHI.getNumOperands(); I != E; I += 2) {
    // Get the 2 incoming value of tile register and MBB.
    Register InTileReg = PHI.getOperand(I).getReg();
    // Mark it as liveout, so that it will be spilled when visit
    // the incoming MBB. Otherwise since phi will be deleted, it
    // would miss spill when visit incoming MBB.
    MayLiveAcrossBlocks.set(InTileReg.virtRegIndex());
    MachineBasicBlock *InMBB = PHI.getOperand(I + 1).getMBB();

    MachineInstr *TileDefMI = MRI->getVRegDef(InTileReg);
    MachineBasicBlock::iterator InsertPos;
    if (TileDefMI->isPHI()) {
      InsertPos = TileDefMI->getParent()->getFirstNonPHI();
      if (auto It = VisitedPHIs.find(TileDefMI);
          It != VisitedPHIs.end()) { // circular phi reference
        //        def t1
        //       /       \
        //  def t2       t3 = phi(t1, t4) <--
        //       \       /                  |
        //      t4 = phi(t2, t3)-------------
        //
        // For each (row, column and stack address) append phi incoming value.
        // Create r3 = phi(r1, r4)
        // Create r4 = phi(r2, r3)
        Register InRowReg = It->second.Row;
        Register InColReg = It->second.Col;
        Register InStackAddrReg = It->second.StackAddr;
        RowPHI.addReg(InRowReg).addMBB(InMBB);
        ColPHI.addReg(InColReg).addMBB(InMBB);
        AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
        continue;
      } else {
        // Recursively convert PHI to tileload
        convertPHI(TileDefMI->getParent(), *TileDefMI);
        // The PHI node is coverted to tileload instruction. Get the stack
        // address from tileload operands.
        MachineInstr *TileLoad = MRI->getVRegDef(InTileReg);
        assert(TileLoad && TileLoad->getOpcode() == X86::PTILELOADDV);
        Register InRowReg = TileLoad->getOperand(1).getReg();
        Register InColReg = TileLoad->getOperand(2).getReg();
        Register InStackAddrReg = TileLoad->getOperand(3).getReg();
        RowPHI.addReg(InRowReg).addMBB(InMBB);
        ColPHI.addReg(InColReg).addMBB(InMBB);
        AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
      }
    } else {
      InsertPos = TileDefMI->getIterator();

      // Fill the incoming operand of row/column phi instruction.
      ShapeT Shape = getShape(MRI, InTileReg);
      Shape.getRow()->setIsKill(false);
      Shape.getCol()->setIsKill(false);
      RowPHI.addReg(Shape.getRow()->getReg()).addMBB(InMBB);
      ColPHI.addReg(Shape.getCol()->getReg()).addMBB(InMBB);

      // The incoming tile register live out of its def BB, it would be spilled.
      // Create MI to get the spill stack slot address for the tile register
      int FI = getStackSpaceFor(InTileReg);
      Register InStackAddrReg =
          MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
      addOffset(BuildMI(*TileDefMI->getParent(), InsertPos, DebugLoc(),
                        TII->get(X86::LEA64r), InStackAddrReg)
                    .addFrameIndex(FI),
                0);
      AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
    }
  }

  // Materialize the merged value: load the tile from the PHI-selected
  // stack address with the PHI-selected shape, using a 64-byte stride.
  MachineBasicBlock::iterator InsertPos = MBB->getFirstNonPHI();
  Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
  BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::MOV64ri), StrideReg)
      .addImm(64);
  Register TileReg = PHI.getOperand(0).getReg();
  MachineInstr *NewMI = addDirectMem(
      BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::PTILELOADDV), TileReg)
          .addReg(RowReg)
          .addReg(ColReg),
      StackAddrReg);
  MachineOperand &MO = NewMI->getOperand(5);
  MO.setReg(StrideReg);
  MO.setIsKill(true);
  PHI.eraseFromParent();
  VisitedPHIs.erase(&PHI);
}
436 
isTileRegDef(MachineRegisterInfo * MRI,MachineInstr & MI)437 static bool isTileRegDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
438   MachineOperand &MO = MI.getOperand(0);
439   if (MO.isReg() && MO.getReg().isVirtual() && isTileRegister(MRI, MO.getReg()))
440     return true;
441   return false;
442 }
443 
canonicalizePHIs(MachineBasicBlock & MBB)444 void X86FastPreTileConfig::canonicalizePHIs(MachineBasicBlock &MBB) {
445   SmallVector<MachineInstr *, 8> PHIs;
446 
447   for (MachineInstr &MI : MBB) {
448     if (!MI.isPHI())
449       break;
450     if (!isTileRegDef(MRI, MI))
451       continue;
452     PHIs.push_back(&MI);
453   }
454   // Canonicalize the phi node first. One tile phi may depeneds previous
455   // phi node. For below case, we need convert %t4.
456   //
457   // BB0:
458   // %t3 = phi (t1 BB1, t2 BB0)
459   // %t4 = phi (t5 BB1, t3 BB0)
460   // -->
461   // %t3 = phi (t1 BB1, t2 BB0)
462   // %t4 = phi (t5 BB1, t2 BB0)
463   //
464   while (!PHIs.empty()) {
465     MachineInstr *PHI = PHIs.pop_back_val();
466 
467     // Find the operand that is incoming from the same MBB and the def
468     // is also phi node.
469     MachineOperand *InMO = nullptr;
470     MachineInstr *DefMI = nullptr;
471     for (unsigned I = 1, E = PHI->getNumOperands(); I != E; I += 2) {
472       Register InTileReg = PHI->getOperand(I).getReg();
473       MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
474       DefMI = MRI->getVRegDef(InTileReg);
475       if (InMBB != &MBB || !DefMI->isPHI())
476         continue;
477 
478       InMO = &PHI->getOperand(I);
479       break;
480     }
481     // If can't find such operand, do nothing.
482     if (!InMO)
483       continue;
484 
485     // Current phi node depends on previous phi node. Break the
486     // dependency.
487     Register DefTileReg;
488     for (unsigned I = 1, E = DefMI->getNumOperands(); I != E; I += 2) {
489       MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
490       if (InMBB != &MBB)
491         continue;
492       DefTileReg = DefMI->getOperand(I).getReg();
493       InMO->setReg(DefTileReg);
494       break;
495     }
496   }
497 }
498 
convertPHIs(MachineBasicBlock & MBB)499 void X86FastPreTileConfig::convertPHIs(MachineBasicBlock &MBB) {
500   SmallVector<MachineInstr *, 8> PHIs;
501   for (MachineInstr &MI : MBB) {
502     if (!MI.isPHI())
503       break;
504     if (!isTileRegDef(MRI, MI))
505       continue;
506     PHIs.push_back(&MI);
507   }
508   while (!PHIs.empty()) {
509     MachineInstr *MI = PHIs.pop_back_val();
510     VisitedPHIs.clear();
511     convertPHI(&MBB, *MI);
512   }
513 }
514 
// PreTileConfig should configure the tile registers based on basic
// block. Walks the block in reverse, tracking the last shape def and the
// last inserted ldtilecfg, and spills/reloads tile values that would cross
// a reconfiguration. Returns true if any ldtilecfg was inserted.
bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
  this->MBB = &MBB;
  bool Change = false;
  MachineInstr *LastShapeMI = nullptr;
  MachineInstr *LastTileCfg = nullptr;
  bool HasUnconfigTile = false;

  // Insert a ldtilecfg before \p Before, creating the config stack slot on
  // first use. Resets the shape tracking since a new config starts here.
  auto Config = [&](MachineInstr &Before) {
    if (CfgSS == -1)
      CfgSS = MFI->CreateStackObject(ST->getTileConfigSize(),
                                     ST->getTileConfigAlignment(), false);
    LastTileCfg = addFrameReference(
        BuildMI(MBB, Before, DebugLoc(), TII->get(X86::PLDTILECFGV)), CfgSS);
    LastShapeMI = nullptr;
    Change = true;
  };
  // True if any operand of \p MI is a virtual tile register.
  auto HasTileOperand = [](MachineRegisterInfo *MRI, MachineInstr &MI) {
    for (const MachineOperand &MO : MI.operands()) {
      if (!MO.isReg())
        continue;
      Register Reg = MO.getReg();
      if (Reg.isVirtual() && isTileRegister(MRI, Reg))
        return true;
    }
    return false;
  };
  for (MachineInstr &MI : reverse(MBB)) {
    // We have transformed phi node before configuring BB.
    if (MI.isPHI())
      break;
    // Don't collect the shape of used tile, the tile should be defined
    // before the tile use. Spill and reload would happen if there is only
    // tile use after ldtilecfg, so the shape can be collected from reload.
    // Take below code for example. %t would be reloaded before tilestore
    // call
    // ....
    // tilestore %r, %c, %t
    // -->
    // call
    // ldtilecfg
    // %t = tileload %r, %c
    // tilestore %r, %c, %t
    if (HasTileOperand(MRI, MI))
      HasUnconfigTile = true;
    // According to AMX ABI, all the tile registers including config register
    // are volatile. Caller need to save/restore config register.
    if (MI.isCall() && HasUnconfigTile) {
      MachineBasicBlock::iterator I;
      if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
        I = ++LastShapeMI->getIterator();
      else
        I = ++MI.getIterator();
      Config(*I);
      HasUnconfigTile = false;
      continue;
    }
    if (!isTileDef(MRI, MI))
      continue;
    //
    //---------------------------------------------------------------------
    // Don't handle COPY instruction. If the src and dst of the COPY can be
    // in the same config in below case, we just check the shape of t0.
    // def row0
    // def col0
    // ldtilecfg
    // t0 = tilezero(row0, col0)
    // t1 = copy t0
    // ...
    // If the src and dst of the COPY can NOT be in the same config in below
    // case. Reload would be generated before the copy instruction.
    // def row0
    // def col0
    // t0 = tilezero(row0, col0)
    // spill t0
    // ...
    // def row1
    // def col1
    // ldtilecfg
    // t1 = tilezero(row1, col1)
    // reload t0
    // t1 = copy t0
    //---------------------------------------------------------------------
    //
    // If MI dominate the last shape def instruction, we need insert
    // ldtilecfg after LastShapeMI now. The config doesn't include
    // current MI.
    //   def row0
    //   def col0
    //   tilezero(row0, col0)  <- MI
    //   def row1
    //   def col1
    //   ldtilecfg             <- insert
    //   tilezero(row1, col1)
    if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
      Config(*(++LastShapeMI->getIterator()));
    MachineOperand *RowMO = &MI.getOperand(1);
    MachineOperand *ColMO = &MI.getOperand(2);
    MachineInstr *RowMI = MRI->getVRegDef(RowMO->getReg());
    MachineInstr *ColMI = MRI->getVRegDef(ColMO->getReg());
    // If the shape is defined in current MBB, check the domination.
    // FIXME how about loop?
    if (RowMI->getParent() == &MBB) {
      if (!LastShapeMI)
        LastShapeMI = RowMI;
      else if (dominates(MBB, LastShapeMI, RowMI))
        LastShapeMI = RowMI;
    }
    if (ColMI->getParent() == &MBB) {
      if (!LastShapeMI)
        LastShapeMI = ColMI;
      else if (dominates(MBB, LastShapeMI, ColMI))
        LastShapeMI = ColMI;
    }
    // Tile pairs have extra column operands; track those shape defs too.
    unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg());
    if (TileDefNum > 1) {
      for (unsigned I = 1; I < TileDefNum; I++) {
        MachineOperand *ColxMO = &MI.getOperand(2 + I);
        MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg());
        if (ColxMI->getParent() == &MBB) {
          if (!LastShapeMI)
            LastShapeMI = ColxMI;
          else if (dominates(MBB, LastShapeMI, ColxMI))
            LastShapeMI = ColxMI;
        }
      }
    }
    // If there is user live out of the tilecfg, spill it and reload in
    // before the user.
    Register TileReg = MI.getOperand(0).getReg();
    if (mayLiveOut(TileReg, LastTileCfg))
      spill(++MI.getIterator(), TileReg, false);
    for (MachineInstr &UseMI : MRI->use_instructions(TileReg)) {
      if (UseMI.getParent() == &MBB) {
        // check user should not across ldtilecfg
        if (!LastTileCfg || !dominates(MBB, LastTileCfg, UseMI))
          continue;
        // reload before UseMI
        reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
      } else {
        // Don't reload for phi instruction, we handle phi reload separately.
        // TODO: merge the reload for the same user MBB.
        if (!UseMI.isPHI())
          reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
      }
    }
  }

  // Configure tile registers at the head of the MBB
  if (HasUnconfigTile) {
    MachineInstr *Before;
    if (LastShapeMI == nullptr || LastShapeMI->isPHI())
      Before = &*MBB.getFirstNonPHI();
    else
      Before = &*(++LastShapeMI->getIterator());

    Config(*Before);
  }

  return Change;
}
677 
runOnMachineFunction(MachineFunction & MFunc)678 bool X86FastPreTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
679   X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
680   // Early exit in the common case of non-AMX code.
681   if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
682     return false;
683 
684   MF = &MFunc;
685   MRI = &MFunc.getRegInfo();
686   ST = &MFunc.getSubtarget<X86Subtarget>();
687   TII = ST->getInstrInfo();
688   MFI = &MFunc.getFrameInfo();
689   TRI = ST->getRegisterInfo();
690   CfgSS = -1;
691 
692   unsigned NumVirtRegs = MRI->getNumVirtRegs();
693 
694   StackSlotForVirtReg.resize(NumVirtRegs);
695   MayLiveAcrossBlocks.clear();
696   // We will create register during config. *3 is to make sure
697   // the virtual register number doesn't exceed the size of
698   // the bit vector.
699   MayLiveAcrossBlocks.resize(NumVirtRegs * 3);
700   bool Change = false;
701   assert(MRI->isSSA());
702 
703   // Canonicalize the phi node first.
704   for (MachineBasicBlock &MBB : MFunc)
705     canonicalizePHIs(MBB);
706 
707   // Loop over all of the basic blocks in reverse post order and insert
708   // ldtilecfg for tile registers. The reserse post order is to facilitate
709   // PHI node convert.
710   ReversePostOrderTraversal<MachineFunction *> RPOT(MF);
711   for (MachineBasicBlock *MBB : RPOT) {
712     convertPHIs(*MBB);
713     Change |= configBasicBlock(*MBB);
714   }
715 
716   if (Change)
717     InitializeTileConfigStackSpace();
718 
719   StackSlotForVirtReg.clear();
720   return Change;
721 }
722 
createX86FastPreTileConfigPass()723 FunctionPass *llvm::createX86FastPreTileConfigPass() {
724   return new X86FastPreTileConfig();
725 }
726