xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86FastPreTileConfig.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===-- X86FastPreTileConfig.cpp - Fast Tile Register Configure------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to preconfigure the shape of physical tile registers.
/// It inserts ldtilecfg ahead of each group of tile registers. The algorithm
/// walks the instructions of each basic block in reverse order. All the tile
/// registers that live out of the basic block are spilled and reloaded
/// before their users. It also checks the dependency of the shape to ensure
/// the shape is defined before ldtilecfg.
15 //
16 //===----------------------------------------------------------------------===//
17 
#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
34 
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "fastpretileconfig"
38 
39 STATISTIC(NumStores, "Number of stores added");
40 STATISTIC(NumLoads, "Number of loads added");
41 
42 namespace {
43 
44 class X86FastPreTileConfig : public MachineFunctionPass {
45   MachineFunction *MF = nullptr;
46   const X86Subtarget *ST = nullptr;
47   const TargetInstrInfo *TII = nullptr;
48   MachineRegisterInfo *MRI = nullptr;
49   X86MachineFunctionInfo *X86FI = nullptr;
50   MachineFrameInfo *MFI = nullptr;
51   const TargetRegisterInfo *TRI = nullptr;
52   MachineBasicBlock *MBB = nullptr;
53   int CfgSS = -1;
54   struct PHIInfo {
55     Register Row;
56     Register Col;
57     Register StackAddr;
58   };
59   DenseMap<MachineInstr *, struct PHIInfo> VisitedPHIs;
60 
61   /// Maps virtual regs to the frame index where these values are spilled.
62   IndexedMap<int, VirtReg2IndexFunctor> StackSlotForVirtReg;
63 
64   /// Has a bit set for tile virtual register for which it was determined
65   /// that it is alive across blocks.
66   BitVector MayLiveAcrossBlocks;
67 
68   int getStackSpaceFor(Register VirtReg);
69   void InitializeTileConfigStackSpace();
70   bool mayLiveOut(Register VirtReg, MachineInstr *CfgMI);
71   void spill(MachineBasicBlock::iterator Before, Register VirtReg, bool Kill);
72   void reload(MachineBasicBlock::iterator UseMI, Register VirtReg,
73               MachineOperand *RowMO, MachineOperand *ColMO);
74   void canonicalizePHIs(MachineBasicBlock &MBB);
75   void convertPHI(MachineBasicBlock *MBB, MachineInstr &PHI);
76   void convertPHIs(MachineBasicBlock &MBB);
77   bool configBasicBlock(MachineBasicBlock &MBB);
78 
79 public:
80   X86FastPreTileConfig() : MachineFunctionPass(ID), StackSlotForVirtReg(-1) {}
81 
82   /// Return the pass name.
83   StringRef getPassName() const override {
84     return "Fast Tile Register Preconfigure";
85   }
86 
87   /// Perform tile register configure.
88   bool runOnMachineFunction(MachineFunction &MFunc) override;
89 
90   static char ID;
91 };
92 
93 } // end anonymous namespace
94 
95 char X86FastPreTileConfig::ID = 0;
96 
97 INITIALIZE_PASS_BEGIN(X86FastPreTileConfig, DEBUG_TYPE,
98                       "Fast Tile Register Preconfigure", false, false)
99 INITIALIZE_PASS_END(X86FastPreTileConfig, DEBUG_TYPE,
100                     "Fast Tile Register Preconfigure", false, false)
101 
102 static bool dominates(MachineBasicBlock &MBB,
103                       MachineBasicBlock::const_iterator A,
104                       MachineBasicBlock::const_iterator B) {
105   auto MBBEnd = MBB.end();
106   if (B == MBBEnd)
107     return true;
108 
109   MachineBasicBlock::const_iterator I = MBB.begin();
110   for (; &*I != A && &*I != B; ++I)
111     ;
112 
113   return &*I == A;
114 }
115 
116 /// This allocates space for the specified virtual register to be held on the
117 /// stack.
118 int X86FastPreTileConfig::getStackSpaceFor(Register VirtReg) {
119   // Find the location Reg would belong...
120   int SS = StackSlotForVirtReg[VirtReg];
121   // Already has space allocated?
122   if (SS != -1)
123     return SS;
124 
125   // Allocate a new stack object for this spill location...
126   const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
127   unsigned Size = TRI->getSpillSize(RC);
128   Align Alignment = TRI->getSpillAlign(RC);
129   int FrameIdx = MFI->CreateSpillStackObject(Size, Alignment);
130 
131   // Assign the slot.
132   StackSlotForVirtReg[VirtReg] = FrameIdx;
133   return FrameIdx;
134 }
135 
136 /// Returns false if \p VirtReg is known to not live out of the current config.
137 /// If \p VirtReg live out of the current MBB, it must live out of the current
138 /// config
139 bool X86FastPreTileConfig::mayLiveOut(Register VirtReg, MachineInstr *CfgMI) {
140   if (MayLiveAcrossBlocks.test(Register::virtReg2Index(VirtReg)))
141     return true;
142 
143   for (const MachineInstr &UseInst : MRI->use_nodbg_instructions(VirtReg)) {
144     if (UseInst.getParent() != MBB) {
145       MayLiveAcrossBlocks.set(Register::virtReg2Index(VirtReg));
146       return true;
147     }
148 
149     // The use and def are in the same MBB. If the tile register is
150     // reconfigured, it is crobbered and we need to spill and reload
151     // tile register.
152     if (CfgMI) {
153       if (dominates(*MBB, *CfgMI, UseInst)) {
154         MayLiveAcrossBlocks.set(Register::virtReg2Index(VirtReg));
155         return true;
156       }
157     }
158   }
159 
160   return false;
161 }
162 
163 void X86FastPreTileConfig::InitializeTileConfigStackSpace() {
164   MachineBasicBlock &MBB = MF->front();
165   MachineInstr *MI = &*MBB.getFirstNonPHI();
166   DebugLoc DL;
167   if (ST->hasAVX512()) {
168     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
169     BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
170     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), CfgSS)
171         .addReg(Zmm);
172   } else if (ST->hasAVX2()) {
173     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
174     BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
175     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS)
176         .addReg(Ymm);
177     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS,
178                       32)
179         .addReg(Ymm);
180   } else {
181     assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
182     unsigned StoreOpc = ST->hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
183     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
184     BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
185     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS)
186         .addReg(Xmm);
187     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 16)
188         .addReg(Xmm);
189     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 32)
190         .addReg(Xmm);
191     addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 48)
192         .addReg(Xmm);
193   }
194   // Fill in the palette first.
195   addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), CfgSS)
196       .addImm(1);
197 }
198 
199 /// Insert spill instruction for \p AssignedReg before \p Before.
200 /// TODO: Update DBG_VALUEs with \p VirtReg operands with the stack slot.
201 void X86FastPreTileConfig::spill(MachineBasicBlock::iterator Before,
202                                  Register VirtReg, bool Kill) {
203   LLVM_DEBUG(dbgs() << "Spilling " << printReg(VirtReg, TRI) << " \n");
204   int FI = getStackSpaceFor(VirtReg);
205   LLVM_DEBUG(dbgs() << " to stack slot #" << FI << '\n');
206 
207   const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
208   // Don't need shape information for tile store, becasue it is adjacent to
209   // the tile def instruction.
210   TII->storeRegToStackSlot(*MBB, Before, VirtReg, Kill, FI, &RC, TRI,
211                            Register());
212   ++NumStores;
213 
214   // TODO: update DBG_VALUEs
215 }
216 
217 /// Insert reload instruction for \p PhysReg before \p Before.
218 void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,
219                                   Register OrigReg, MachineOperand *RowMO,
220                                   MachineOperand *ColMO) {
221   int FI = getStackSpaceFor(OrigReg);
222   const TargetRegisterClass &RC = *MRI->getRegClass(OrigReg);
223   Register TileReg;
224   // Fold copy to tileload
225   // BB1:
226   // spill src to s
227   //
228   // BB2:
229   // t = copy src
230   // -->
231   // t = tileload (s)
232   if (UseMI->isCopy())
233     TileReg = UseMI->getOperand(0).getReg();
234   else
235     TileReg = MRI->createVirtualRegister(&RC);
236   // Can't use TII->loadRegFromStackSlot(), because we need the shape
237   // information for reload.
238   // tileloadd (%sp, %idx), %tmm
239   unsigned Opc = X86::PTILELOADDV;
240   Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
241   // FIXME: MBB is not the parent of UseMI.
242   MachineInstr *NewMI = BuildMI(*UseMI->getParent(), UseMI, DebugLoc(),
243                                 TII->get(X86::MOV64ri), StrideReg)
244                             .addImm(64);
245   NewMI = addFrameReference(
246       BuildMI(*UseMI->getParent(), UseMI, DebugLoc(), TII->get(Opc), TileReg)
247           .addReg(RowMO->getReg())
248           .addReg(ColMO->getReg()),
249       FI);
250   MachineOperand &MO = NewMI->getOperand(5);
251   MO.setReg(StrideReg);
252   MO.setIsKill(true);
253   RowMO->setIsKill(false);
254   ColMO->setIsKill(false);
255   // Erase copy instruction after it is folded.
256   if (UseMI->isCopy()) {
257     UseMI->eraseFromParent();
258   } else {
259     // Replace the register in the user MI.
260     for (auto &MO : UseMI->operands()) {
261       if (MO.isReg() && MO.getReg() == OrigReg)
262         MO.setReg(TileReg);
263     }
264   }
265 
266   ++NumLoads;
267   LLVM_DEBUG(dbgs() << "Reloading " << printReg(OrigReg, TRI) << " into "
268                     << printReg(TileReg, TRI) << '\n');
269 }
270 
271 static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
272   // The instruction must have 3 operands: tile def, row, col.
273   if (MI.isDebugInstr() || MI.getNumOperands() < 3 || !MI.isPseudo())
274     return false;
275   MachineOperand &MO = MI.getOperand(0);
276 
277   if (MO.isReg()) {
278     Register Reg = MO.getReg();
279     // FIXME it may be used after Greedy RA and the physical
280     // register is not rewritten yet.
281     if (Reg.isVirtual() &&
282         MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
283       return true;
284     if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
285       return true;
286   }
287 
288   return false;
289 }
290 
291 static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) {
292   MachineInstr *MI = MRI->getVRegDef(TileReg);
293   if (isTileDef(MRI, *MI)) {
294     MachineOperand *RowMO = &MI->getOperand(1);
295     MachineOperand *ColMO = &MI->getOperand(2);
296     return ShapeT(RowMO, ColMO, MRI);
297   } else if (MI->isCopy()) {
298     TileReg = MI->getOperand(1).getReg();
299     return getShape(MRI, TileReg);
300   }
301 
302   // The def should not be PHI node, because we walk the MBB in reverse post
303   // order.
304   assert(MI->isPHI() && "Unexpected PHI when get shape.");
305   llvm_unreachable("Unexpected MI when get shape.");
306 }
307 
308 // BB0:
309 // spill t0 to s0
310 // BB1:
311 // spill t1 to s1
312 //
313 // BB2:
314 // t = phi [t0, bb0] [t1, bb1]
315 // -->
316 // row = phi [r0, bb0] [r1, bb1]
317 // col = phi [c0, bb0] [c1, bb1]
318 //   s = phi [s0, bb0] [s1, bb1]
319 //   t = tileload row, col, s
320 // The new instruction is inserted at the end of the phi node. The order
321 // of the original phi node is not ensured.
322 void X86FastPreTileConfig::convertPHI(MachineBasicBlock *MBB,
323                                       MachineInstr &PHI) {
324   // 1. Create instruction to get stack slot address of each incoming block.
325   // 2. Create PHI node for the stack address.
326   // 3. Create PHI node for shape. If one of the incoming shape is immediate
327   //    use the immediate and delete the PHI node.
328   // 4. Create tileload instruction from the stack address.
329   Register StackAddrReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
330   MachineInstrBuilder AddrPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
331                                         TII->get(X86::PHI), StackAddrReg);
332   Register RowReg = MRI->createVirtualRegister(&X86::GR16RegClass);
333   MachineInstrBuilder RowPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
334                                        TII->get(X86::PHI), RowReg);
335   Register ColReg = MRI->createVirtualRegister(&X86::GR16RegClass);
336   MachineInstrBuilder ColPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
337                                        TII->get(X86::PHI), ColReg);
338   // Record the mapping of phi node and its row/column information.
339   VisitedPHIs[&PHI] = {RowReg, ColReg, StackAddrReg};
340 
341   for (unsigned I = 1, E = PHI.getNumOperands(); I != E; I += 2) {
342     // Get the 2 incoming value of tile register and MBB.
343     Register InTileReg = PHI.getOperand(I).getReg();
344     // Mark it as liveout, so that it will be spilled when visit
345     // the incoming MBB. Otherwise since phi will be deleted, it
346     // would miss spill when visit incoming MBB.
347     MayLiveAcrossBlocks.set(Register::virtReg2Index(InTileReg));
348     MachineBasicBlock *InMBB = PHI.getOperand(I + 1).getMBB();
349 
350     MachineInstr *TileDefMI = MRI->getVRegDef(InTileReg);
351     MachineBasicBlock::iterator InsertPos;
352     if (TileDefMI->isPHI()) {
353       InsertPos = TileDefMI->getParent()->getFirstNonPHI();
354       if (VisitedPHIs.count(TileDefMI)) { // circular phi reference
355         //        def t1
356         //       /       \
357         //  def t2       t3 = phi(t1, t4) <--
358         //       \       /                  |
359         //      t4 = phi(t2, t3)-------------
360         //
361         // For each (row, column and stack address) append phi incoming value.
362         // Create r3 = phi(r1, r4)
363         // Create r4 = phi(r2, r3)
364         Register InRowReg = VisitedPHIs[TileDefMI].Row;
365         Register InColReg = VisitedPHIs[TileDefMI].Col;
366         Register InStackAddrReg = VisitedPHIs[TileDefMI].StackAddr;
367         RowPHI.addReg(InRowReg).addMBB(InMBB);
368         ColPHI.addReg(InColReg).addMBB(InMBB);
369         AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
370         continue;
371       } else {
372         // Recursively convert PHI to tileload
373         convertPHI(TileDefMI->getParent(), *TileDefMI);
374         // The PHI node is coverted to tileload instruction. Get the stack
375         // address from tileload operands.
376         MachineInstr *TileLoad = MRI->getVRegDef(InTileReg);
377         assert(TileLoad && TileLoad->getOpcode() == X86::PTILELOADDV);
378         Register InRowReg = TileLoad->getOperand(1).getReg();
379         Register InColReg = TileLoad->getOperand(2).getReg();
380         Register InStackAddrReg = TileLoad->getOperand(3).getReg();
381         RowPHI.addReg(InRowReg).addMBB(InMBB);
382         ColPHI.addReg(InColReg).addMBB(InMBB);
383         AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
384       }
385     } else {
386       InsertPos = TileDefMI->getIterator();
387 
388       // Fill the incoming operand of row/column phi instruction.
389       ShapeT Shape = getShape(MRI, InTileReg);
390       Shape.getRow()->setIsKill(false);
391       Shape.getCol()->setIsKill(false);
392       RowPHI.addReg(Shape.getRow()->getReg()).addMBB(InMBB);
393       ColPHI.addReg(Shape.getCol()->getReg()).addMBB(InMBB);
394 
395       // The incoming tile register live out of its def BB, it would be spilled.
396       // Create MI to get the spill stack slot address for the tile register
397       int FI = getStackSpaceFor(InTileReg);
398       Register InStackAddrReg =
399           MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
400       addOffset(BuildMI(*TileDefMI->getParent(), InsertPos, DebugLoc(),
401                         TII->get(X86::LEA64r), InStackAddrReg)
402                     .addFrameIndex(FI),
403                 0);
404       AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
405     }
406   }
407 
408   MachineBasicBlock::iterator InsertPos = MBB->getFirstNonPHI();
409   Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
410   BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::MOV64ri), StrideReg)
411       .addImm(64);
412   Register TileReg = PHI.getOperand(0).getReg();
413   MachineInstr *NewMI = addDirectMem(
414       BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::PTILELOADDV), TileReg)
415           .addReg(RowReg)
416           .addReg(ColReg),
417       StackAddrReg);
418   MachineOperand &MO = NewMI->getOperand(5);
419   MO.setReg(StrideReg);
420   MO.setIsKill(true);
421   PHI.eraseFromParent();
422   VisitedPHIs.erase(&PHI);
423 }
424 
425 static bool isTileRegDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
426   MachineOperand &MO = MI.getOperand(0);
427   if (MO.isReg() && MO.getReg().isVirtual() &&
428       MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID)
429     return true;
430   return false;
431 }
432 
433 void X86FastPreTileConfig::canonicalizePHIs(MachineBasicBlock &MBB) {
434   SmallVector<MachineInstr *, 8> PHIs;
435 
436   for (MachineInstr &MI : MBB) {
437     if (!MI.isPHI())
438       break;
439     if (!isTileRegDef(MRI, MI))
440       continue;
441     PHIs.push_back(&MI);
442   }
443   // Canonicalize the phi node first. One tile phi may depeneds previous
444   // phi node. For below case, we need convert %t4.
445   //
446   // BB0:
447   // %t3 = phi (t1 BB1, t2 BB0)
448   // %t4 = phi (t5 BB1, t3 BB0)
449   // -->
450   // %t3 = phi (t1 BB1, t2 BB0)
451   // %t4 = phi (t5 BB1, t2 BB0)
452   //
453   while (!PHIs.empty()) {
454     MachineInstr *PHI = PHIs.pop_back_val();
455 
456     // Find the operand that is incoming from the same MBB and the def
457     // is also phi node.
458     MachineOperand *InMO = nullptr;
459     MachineInstr *DefMI = nullptr;
460     for (unsigned I = 1, E = PHI->getNumOperands(); I != E; I += 2) {
461       Register InTileReg = PHI->getOperand(I).getReg();
462       MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
463       DefMI = MRI->getVRegDef(InTileReg);
464       if (InMBB != &MBB || !DefMI->isPHI())
465         continue;
466 
467       InMO = &PHI->getOperand(I);
468       break;
469     }
470     // If can't find such operand, do nothing.
471     if (!InMO)
472       continue;
473 
474     // Current phi node depends on previous phi node. Break the
475     // dependency.
476     Register DefTileReg;
477     for (unsigned I = 1, E = DefMI->getNumOperands(); I != E; I += 2) {
478       MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
479       if (InMBB != &MBB)
480         continue;
481       DefTileReg = DefMI->getOperand(I).getReg();
482       InMO->setReg(DefTileReg);
483       break;
484     }
485   }
486 }
487 
488 void X86FastPreTileConfig::convertPHIs(MachineBasicBlock &MBB) {
489   SmallVector<MachineInstr *, 8> PHIs;
490   for (MachineInstr &MI : MBB) {
491     if (!MI.isPHI())
492       break;
493     if (!isTileRegDef(MRI, MI))
494       continue;
495     PHIs.push_back(&MI);
496   }
497   while (!PHIs.empty()) {
498     MachineInstr *MI = PHIs.pop_back_val();
499     VisitedPHIs.clear();
500     convertPHI(&MBB, *MI);
501   }
502 }
503 
504 // PreTileConfig should configure the tile registers based on basic
505 // block.
506 bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
507   this->MBB = &MBB;
508   bool Change = false;
509   MachineInstr *LastShapeMI = nullptr;
510   MachineInstr *LastTileCfg = nullptr;
511   bool HasUnconfigTile = false;
512 
513   auto Config = [&](MachineInstr &Before) {
514     if (CfgSS == -1)
515       CfgSS = MFI->CreateStackObject(ST->getTileConfigSize(),
516                                      ST->getTileConfigAlignment(), false);
517     LastTileCfg = addFrameReference(
518         BuildMI(MBB, Before, DebugLoc(), TII->get(X86::PLDTILECFGV)), CfgSS);
519     LastShapeMI = nullptr;
520     Change = true;
521   };
522   auto HasTileOperand = [](MachineRegisterInfo *MRI, MachineInstr &MI) {
523     for (const MachineOperand &MO : MI.operands()) {
524       if (!MO.isReg())
525         continue;
526       Register Reg = MO.getReg();
527       if (Reg.isVirtual() &&
528           MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
529         return true;
530     }
531     return false;
532   };
533   for (MachineInstr &MI : reverse(MBB)) {
534     // We have transformed phi node before configuring BB.
535     if (MI.isPHI())
536       break;
537     // Don't collect the shape of used tile, the tile should be defined
538     // before the tile use. Spill and reload would happen if there is only
539     // tile use after ldtilecfg, so the shape can be collected from reload.
540     // Take below code for example. %t would be reloaded before tilestore
541     // call
542     // ....
543     // tilestore %r, %c, %t
544     // -->
545     // call
546     // ldtilecfg
547     // %t = tileload %r, %c
548     // tilestore %r, %c, %t
549     if (HasTileOperand(MRI, MI))
550       HasUnconfigTile = true;
551     // According to AMX ABI, all the tile registers including config register
552     // are volatile. Caller need to save/restore config register.
553     if (MI.isCall() && HasUnconfigTile) {
554       MachineBasicBlock::iterator I;
555       if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
556         I = ++LastShapeMI->getIterator();
557       else
558         I = ++MI.getIterator();
559       Config(*I);
560       HasUnconfigTile = false;
561       continue;
562     }
563     if (!isTileDef(MRI, MI))
564       continue;
565     //
566     //---------------------------------------------------------------------
567     // Don't handle COPY instruction. If the src and dst of the COPY can be
568     // in the same config in below case, we just check the shape of t0.
569     // def row0
570     // def col0
571     // ldtilecfg
572     // t0 = tielzero(row0, col0)
573     // t1 = copy t0
574     // ...
575     // If the src and dst of the COPY can NOT be in the same config in below
576     // case. Reload would be generated befor the copy instruction.
577     // def row0
578     // def col0
579     // t0 = tielzero(row0, col0)
580     // spill t0
581     // ...
582     // def row1
583     // def col1
584     // ldtilecfg
585     // t1 = tilezero(row1, col1)
586     // reload t0
587     // t1 = copy t0
588     //---------------------------------------------------------------------
589     //
590     // If MI dominate the last shape def instruction, we need insert
591     // ldtilecfg after LastShapeMI now. The config doesn't include
592     // current MI.
593     //   def row0
594     //   def col0
595     //   tilezero(row0, col0)  <- MI
596     //   def row1
597     //   def col1
598     //   ldtilecfg             <- insert
599     //   tilezero(row1, col1)
600     if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
601       Config(*(++LastShapeMI->getIterator()));
602     MachineOperand *RowMO = &MI.getOperand(1);
603     MachineOperand *ColMO = &MI.getOperand(2);
604     MachineInstr *RowMI = MRI->getVRegDef(RowMO->getReg());
605     MachineInstr *ColMI = MRI->getVRegDef(ColMO->getReg());
606     // If the shape is defined in current MBB, check the domination.
607     // FIXME how about loop?
608     if (RowMI->getParent() == &MBB) {
609       if (!LastShapeMI)
610         LastShapeMI = RowMI;
611       else if (dominates(MBB, LastShapeMI, RowMI))
612         LastShapeMI = RowMI;
613     }
614     if (ColMI->getParent() == &MBB) {
615       if (!LastShapeMI)
616         LastShapeMI = ColMI;
617       else if (dominates(MBB, LastShapeMI, ColMI))
618         LastShapeMI = ColMI;
619     }
620     // If there is user live out of the tilecfg, spill it and reload in
621     // before the user.
622     Register TileReg = MI.getOperand(0).getReg();
623     if (mayLiveOut(TileReg, LastTileCfg))
624       spill(++MI.getIterator(), TileReg, false);
625     for (MachineInstr &UseMI : MRI->use_instructions(TileReg)) {
626       if (UseMI.getParent() == &MBB) {
627         // check user should not across ldtilecfg
628         if (!LastTileCfg || !dominates(MBB, LastTileCfg, UseMI))
629           continue;
630         // reload befor UseMI
631         reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
632       } else {
633         // Don't reload for phi instruction, we handle phi reload separately.
634         // TODO: merge the reload for the same user MBB.
635         if (!UseMI.isPHI())
636           reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
637       }
638     }
639   }
640 
641   // Configure tile registers at the head of the MBB
642   if (HasUnconfigTile) {
643     MachineInstr *Before;
644     if (LastShapeMI == nullptr || LastShapeMI->isPHI())
645       Before = &*MBB.getFirstNonPHI();
646     else
647       Before = &*(++LastShapeMI->getIterator());
648 
649     Config(*Before);
650   }
651 
652   return Change;
653 }
654 
655 bool X86FastPreTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
656   X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
657   // Early exit in the common case of non-AMX code.
658   if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
659     return false;
660 
661   MF = &MFunc;
662   MRI = &MFunc.getRegInfo();
663   ST = &MFunc.getSubtarget<X86Subtarget>();
664   TII = ST->getInstrInfo();
665   MFI = &MFunc.getFrameInfo();
666   TRI = ST->getRegisterInfo();
667   CfgSS = -1;
668 
669   unsigned NumVirtRegs = MRI->getNumVirtRegs();
670 
671   StackSlotForVirtReg.resize(NumVirtRegs);
672   MayLiveAcrossBlocks.clear();
673   // We will create register during config. *3 is to make sure
674   // the virtual register number doesn't exceed the size of
675   // the bit vector.
676   MayLiveAcrossBlocks.resize(NumVirtRegs * 3);
677   bool Change = false;
678   assert(MRI->isSSA());
679 
680   // Canonicalize the phi node first.
681   for (MachineBasicBlock &MBB : MFunc)
682     canonicalizePHIs(MBB);
683 
684   // Loop over all of the basic blocks in reverse post order and insert
685   // ldtilecfg for tile registers. The reserse post order is to facilitate
686   // PHI node convert.
687   ReversePostOrderTraversal<MachineFunction *> RPOT(MF);
688   for (MachineBasicBlock *MBB : RPOT) {
689     convertPHIs(*MBB);
690     Change |= configBasicBlock(*MBB);
691   }
692 
693   if (Change)
694     InitializeTileConfigStackSpace();
695 
696   StackSlotForVirtReg.clear();
697   return Change;
698 }
699 
700 FunctionPass *llvm::createX86FastPreTileConfigPass() {
701   return new X86FastPreTileConfig();
702 }
703