xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86PreTileConfig.cpp (revision 7c20397b724a55001c2054fa133a768e9d06eb1c)
1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of AMX registers
/// are encoded in the 1st and 2nd machine operands of AMX pseudo instructions.
12 ///
/// The ldtilecfg instruction is used to configure the shapes. It must be placed
/// after the definitions of all variable shapes. ldtilecfg is inserted more
/// than once if we cannot find a single dominating point for all AMX instructions.
16 ///
/// The tile configuration register is caller-saved according to the ABI. We
/// need to insert ldtilecfg again after a call instruction if the callee
/// clobbers any AMX registers.
20 ///
/// This pass computes all the points where ldtilecfg needs to be inserted and
/// inserts it. It reports an error if the reachability conditions aren't met.
23 //
24 //===----------------------------------------------------------------------===//
25 
26 #include "X86.h"
27 #include "X86InstrBuilder.h"
28 #include "X86MachineFunctionInfo.h"
29 #include "X86RegisterInfo.h"
30 #include "X86Subtarget.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/MachineInstr.h"
33 #include "llvm/CodeGen/MachineLoopInfo.h"
34 #include "llvm/CodeGen/MachineRegisterInfo.h"
35 #include "llvm/CodeGen/Passes.h"
36 #include "llvm/CodeGen/TargetInstrInfo.h"
37 #include "llvm/CodeGen/TargetRegisterInfo.h"
38 #include "llvm/InitializePasses.h"
39 
40 using namespace llvm;
41 
#define DEBUG_TYPE "tile-pre-config"
// Report a fatal error naming the current function when no valid tile
// configuration point can be found. The expansion is a complete statement
// (it already ends in ';'), so use sites write it without a trailing
// semicolon. It refers to a variable named `MF` that must be in scope.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");
47 
48 namespace {
49 
50 struct MIRef {
51   MachineInstr *MI = nullptr;
52   MachineBasicBlock *MBB = nullptr;
53   // A virtual position for instruction that will be inserted after MI.
54   size_t Pos = 0;
55   MIRef() = default;
56   MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
57     for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
58          ++I, ++Pos)
59       MI = &*I;
60   }
61   MIRef(MachineInstr *MI)
62       : MI(MI), MBB(MI->getParent()),
63         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
64   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
65       : MI(MI), MBB(MBB),
66         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
67   MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
68       : MI(MI), MBB(MBB), Pos(Pos) {}
69   operator bool() const { return MBB != nullptr; }
70   bool operator==(const MIRef &RHS) const {
71     return MI == RHS.MI && MBB == RHS.MBB;
72   }
73   bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
74   bool operator<(const MIRef &RHS) const {
75     // Comparison between different BBs happens when inserting a MIRef into set.
76     // So we compare MBB first to make the insertion happy.
77     return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
78   }
79   bool operator>(const MIRef &RHS) const {
80     // Comparison between different BBs happens when inserting a MIRef into set.
81     // So we compare MBB first to make the insertion happy.
82     return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
83   }
84 };
85 
86 struct BBInfo {
87   MIRef FirstAMX;
88   MIRef LastCall;
89   bool HasAMXRegLiveIn = false;
90   bool TileCfgForbidden = false;
91   bool NeedTileCfgLiveIn = false;
92 };
93 
94 class X86PreTileConfig : public MachineFunctionPass {
95   MachineRegisterInfo *MRI;
96   const MachineLoopInfo *MLI;
97   SmallSet<MachineInstr *, 8> DefVisited;
98   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
99   DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
100 
101   /// Check if the callee will clobber AMX registers.
102   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
103     auto Iter = llvm::find_if(
104         MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
105     if (Iter == MI.operands_end())
106       return false;
107     UsableRegs.clearBitsInMask(Iter->getRegMask());
108     return !UsableRegs.none();
109   }
110 
111   /// Check if MI is AMX pseudo instruction.
112   bool isAMXInstruction(MachineInstr &MI) {
113     if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
114       return false;
115     MachineOperand &MO = MI.getOperand(0);
116     // We can simply check if it is AMX instruction by its def.
117     // But we should exclude old API which uses physical registers.
118     if (MO.isReg() && MO.getReg().isVirtual() &&
119         MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
120       collectShapeInfo(MI);
121       return true;
122     }
123     // PTILESTOREDV is the only exception that doesn't def a AMX register.
124     return MI.getOpcode() == X86::PTILESTOREDV;
125   }
126 
127   /// Check if it is an edge from loop bottom to loop head.
128   bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
129     if (!MLI->isLoopHeader(Header))
130       return false;
131     auto *ML = MLI->getLoopFor(Header);
132     if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
133       return true;
134 
135     return false;
136   }
137 
138   /// Collect the shape def information for later use.
139   void collectShapeInfo(MachineInstr &MI);
140 
141   /// Try to hoist shapes definded below AMX instructions.
142   bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
143     MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
144     auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
145     auto InsertPoint = FirstAMX.MI->getIterator();
146     for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
147       // Do not hoist instructions that access memory.
148       if (I->MI->mayLoadOrStore())
149         return false;
150       for (auto &MO : I->MI->operands()) {
151         if (MO.isDef())
152           continue;
153         // Do not hoist instructions if the sources' def under AMX instruction.
154         // TODO: We can handle isMoveImmediate MI here.
155         if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
156           return false;
157         // TODO: Maybe need more checks here.
158       }
159       MBB->insert(InsertPoint, I->MI->removeFromParent());
160     }
161     // We only need to mark the last shape in the BB now.
162     Shapes.clear();
163     Shapes.push_back(MIRef(&*--InsertPoint, MBB));
164     return true;
165   }
166 
167 public:
168   X86PreTileConfig() : MachineFunctionPass(ID) {}
169 
170   /// Return the pass name.
171   StringRef getPassName() const override {
172     return "Tile Register Pre-configure";
173   }
174 
175   /// X86PreTileConfig analysis usage.
176   void getAnalysisUsage(AnalysisUsage &AU) const override {
177     AU.setPreservesAll();
178     AU.addRequired<MachineLoopInfo>();
179     MachineFunctionPass::getAnalysisUsage(AU);
180   }
181 
182   /// Clear MF related structures.
183   void releaseMemory() override {
184     ShapeBBs.clear();
185     DefVisited.clear();
186     BBVisitedInfo.clear();
187   }
188 
189   /// Perform ldtilecfg instructions inserting.
190   bool runOnMachineFunction(MachineFunction &MF) override;
191 
192   static char ID;
193 };
194 
195 } // end anonymous namespace
196 
// Pass identifier; the legacy pass manager identifies the pass by the address
// of this variable, not by its value.
char X86PreTileConfig::ID = 0;

// Register the pass under the command-line name "tilepreconfig" and record its
// dependency on MachineLoopInfo (also requested in getAnalysisUsage). The two
// trailing `false` arguments mark it as neither CFG-only nor an analysis.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
204 
205 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
206   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
207     MIRef MIR(MI, MBB);
208     auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
209     if (I == ShapeBBs[MBB].end() || *I != MIR)
210       ShapeBBs[MBB].insert(I, MIR);
211   };
212 
213   SmallVector<Register, 8> WorkList(
214       {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
215   while (!WorkList.empty()) {
216     Register R = WorkList.pop_back_val();
217     MachineInstr *DefMI = MRI->getVRegDef(R);
218     assert(DefMI && "R must has one define instruction");
219     MachineBasicBlock *DefMBB = DefMI->getParent();
220     if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
221       continue;
222     if (DefMI->isPHI()) {
223       for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
224         if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
225           RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
226         else
227           WorkList.push_back(DefMI->getOperand(I).getReg());
228     } else {
229       RecordShape(DefMI, DefMBB);
230     }
231   }
232 }
233 
234 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
235   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
236   const TargetInstrInfo *TII = ST.getInstrInfo();
237   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
238   const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
239   X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
240 
241   BitVector AMXRegs(TRI->getNumRegs());
242   for (unsigned I = 0; I < RC->getNumRegs(); I++)
243     AMXRegs.set(X86::TMM0 + I);
244 
245   // Iterate MF to collect information.
246   MRI = &MF.getRegInfo();
247   MLI = &getAnalysis<MachineLoopInfo>();
248   SmallSet<MIRef, 8> CfgNeedInsert;
249   SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
250   for (auto &MBB : MF) {
251     size_t Pos = 0;
252     for (auto &MI : MBB) {
253       ++Pos;
254       if (isAMXInstruction(MI)) {
255         // If there's call before the AMX, we need to reload tile config.
256         if (BBVisitedInfo[&MBB].LastCall)
257           CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
258         else // Otherwise, we need tile config to live in this BB.
259           BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
260         // Always record the first AMX in case there's shape def after it.
261         if (!BBVisitedInfo[&MBB].FirstAMX)
262           BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
263       } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
264         // Record the call only if the callee clobbers all AMX registers.
265         BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
266       }
267     }
268     if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
269       if (&MBB == &MF.front())
270         CfgNeedInsert.insert(MIRef(&MBB));
271       else
272         CfgLiveInBBs.push_back(&MBB);
273     }
274     if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
275       for (auto *Succ : MBB.successors())
276         if (!isLoopBackEdge(Succ, &MBB))
277           BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
278   }
279 
280   // Update NeedTileCfgLiveIn for predecessors.
281   while (!CfgLiveInBBs.empty()) {
282     MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
283     for (auto *Pred : MBB->predecessors()) {
284       if (BBVisitedInfo[Pred].LastCall) {
285         CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
286       } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
287         BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
288         if (Pred == &MF.front())
289           CfgNeedInsert.insert(MIRef(Pred));
290         else
291           CfgLiveInBBs.push_back(Pred);
292       }
293     }
294   }
295 
296   // There's no AMX instruction if we didn't find a tile config live in point.
297   if (CfgNeedInsert.empty())
298     return false;
299   X86FI->setHasVirtualTileReg(true);
300 
301   // Avoid to insert ldtilecfg before any shape defs.
302   SmallVector<MachineBasicBlock *, 8> WorkList;
303   for (auto &I : ShapeBBs) {
304     // TODO: We can hoist shapes across BBs here.
305     if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
306       REPORT_CONFIG_FAIL
307     if (BBVisitedInfo[I.first].FirstAMX &&
308         BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
309         !hoistShapesInBB(I.first, I.second))
310       REPORT_CONFIG_FAIL
311     WorkList.push_back(I.first);
312   }
313   while (!WorkList.empty()) {
314     MachineBasicBlock *MBB = WorkList.pop_back_val();
315     for (auto *Pred : MBB->predecessors()) {
316       if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
317         BBVisitedInfo[Pred].TileCfgForbidden = true;
318         WorkList.push_back(Pred);
319       }
320     }
321   }
322 
323   DebugLoc DL;
324   SmallSet<MIRef, 8> VisitedOrInserted;
325   int SS = MF.getFrameInfo().CreateStackObject(
326       ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
327 
328   // Try to insert for the tile config live in points.
329   for (const auto &I : CfgNeedInsert) {
330     SmallSet<MIRef, 8> InsertPoints;
331     SmallVector<MIRef, 8> WorkList({I});
332     while (!WorkList.empty()) {
333       MIRef I = WorkList.pop_back_val();
334       if (!VisitedOrInserted.count(I)) {
335         if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
336           // If the BB is all shapes reachable, stop sink and try to insert.
337           InsertPoints.insert(I);
338         } else {
339           // Avoid the BB to be multi visited.
340           VisitedOrInserted.insert(I);
341           // Sink the inserting point along the chain with NeedTileCfgLiveIn =
342           // true when MBB isn't all shapes reachable.
343           for (auto *Succ : I.MBB->successors())
344             if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
345               WorkList.push_back(MIRef(Succ));
346         }
347       }
348     }
349 
350     // A given point might be forked due to shape conditions are not met.
351     for (MIRef I : InsertPoints) {
352       // Make sure we insert ldtilecfg after the last shape def in MBB.
353       if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
354         I = ShapeBBs[I.MBB].back();
355       // There're chances the MBB is sunk more than once. Record it to avoid
356       // multi insert.
357       if (VisitedOrInserted.insert(I).second) {
358         auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
359         addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
360                           SS);
361       }
362     }
363   }
364 
365   // Zero stack slot.
366   MachineBasicBlock &MBB = MF.front();
367   MachineInstr *MI = &*MBB.begin();
368   if (ST.hasAVX512()) {
369     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
370     BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
371         .addReg(Zmm, RegState::Undef)
372         .addReg(Zmm, RegState::Undef);
373     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
374         .addReg(Zmm);
375   } else if (ST.hasAVX2()) {
376     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
377     BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
378         .addReg(Ymm, RegState::Undef)
379         .addReg(Ymm, RegState::Undef);
380     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
381         .addReg(Ymm);
382     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
383         .addReg(Ymm);
384   } else {
385     assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
386     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
387     BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
388         .addReg(Xmm, RegState::Undef)
389         .addReg(Xmm, RegState::Undef);
390     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
391         .addReg(Xmm);
392     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
393         .addReg(Xmm);
394     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
395         .addReg(Xmm);
396     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
397         .addReg(Xmm);
398   }
399   // Fill in the palette first.
400   addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
401 
402   return true;
403 }
404 
405 FunctionPass *llvm::createX86PreTileConfigPass() {
406   return new X86PreTileConfig();
407 }
408