xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86FastTileConfig.cpp (revision 5e801ac66d24704442eba426ed13c3effb8a34e7)
1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The ldtilecfg instruction
/// is inserted before the FastRegAllocation pass; however, at that time we
/// don't know the shape of each physical tile register, because register
/// allocation is not done yet. This pass runs after the register allocation
/// pass. It collects the shape information of each physical tile register
/// and stores the shape in the stack slot that is allocated for loading the
/// config into the tile config register.
17 //
18 //===----------------------------------------------------------------------===//
19 
#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/ErrorHandling.h"
#include <map>
33 
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "fasttileconfig"
37 
38 namespace {
39 
40 class X86FastTileConfig : public MachineFunctionPass {
41   // context
42   MachineFunction *MF = nullptr;
43   const X86Subtarget *ST = nullptr;
44   const TargetRegisterInfo *TRI = nullptr;
45   const TargetInstrInfo *TII = nullptr;
46   MachineRegisterInfo *MRI = nullptr;
47   X86MachineFunctionInfo *X86FI = nullptr;
48 
49   MachineInstr *getTileConfigPoint();
50   void tileConfig();
51 
52 public:
53   X86FastTileConfig() : MachineFunctionPass(ID) {}
54 
55   bool fastTileConfig();
56   bool isTileLoad(MachineInstr &MI);
57   bool isTileStore(MachineInstr &MI);
58   bool isAMXInstr(MachineInstr &MI);
59   void getTileStoreShape(MachineInstr &MI,
60                          SmallVector<MachineOperand *> &ShapedTiles);
61 
62   MachineInstr *getKeyAMXInstr(MachineInstr *MI);
63   void getTileShapesCfg(MachineInstr *MI,
64                         SmallVector<MachineOperand *> &ShapedTiles);
65   void getShapeCfgInstrs(MachineInstr *MI,
66                          std::map<unsigned, MachineInstr *> &RowCfgs,
67                          std::map<unsigned, MachineInstr *> &ColCfgs);
68 
69   /// Return the pass name.
70   StringRef getPassName() const override {
71     return "Fast Tile Register Configure";
72   }
73 
74   void materializeTileCfg(MachineInstr *MI);
75 
76   void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles,
77                       std::map<unsigned, MachineInstr *> &RowCfgs,
78                       std::map<unsigned, MachineInstr *> &ColCfgs);
79 
80   /// Perform register allocation.
81   bool runOnMachineFunction(MachineFunction &MFunc) override;
82 
83   MachineFunctionProperties getRequiredProperties() const override {
84     return MachineFunctionProperties().set(
85         MachineFunctionProperties::Property::NoPHIs);
86   }
87 
88   static char ID;
89 };
90 
91 } // end anonymous namespace
92 
// Storage for the static pass ID that the constructor hands to
// MachineFunctionPass.
char X86FastTileConfig::ID = 0;

// Pass registration under the DEBUG_TYPE name ("fasttileconfig"). No
// INITIALIZE_PASS_DEPENDENCY entries are listed between BEGIN and END.
INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
                      "Fast Tile Register Configure", false, false)
INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
                    "Fast Tile Register Configure", false, false)
99 
100 static bool isTilePhysReg(MachineOperand &Op) {
101   if (!Op.isReg())
102     return false;
103 
104   Register Reg = Op.getReg();
105   if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
106     return true;
107   return false;
108 }
109 
110 static unsigned getTilePhysRegIdx(MachineOperand *Op) {
111   assert(isTilePhysReg(*Op) && "Tile Operand is invalid");
112   return Op->getReg() - X86::TMM0;
113 }
114 
115 static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) {
116   unsigned Offset = 48 + TIdx;
117   MI->getOperand(3).ChangeToImmediate(Offset);
118 }
119 
120 static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) {
121   unsigned Offset = 16 + TIdx * 2;
122   MI->getOperand(3).ChangeToImmediate(Offset);
123 }
124 
125 bool X86FastTileConfig::isTileLoad(MachineInstr &MI) {
126   return MI.getOpcode() == X86::PTILELOADDV ||
127          MI.getOpcode() == X86::PTILELOADDT1V;
128 }
129 bool X86FastTileConfig::isTileStore(MachineInstr &MI) {
130   return MI.getOpcode() == X86::PTILESTOREDV;
131 }
132 bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) {
133   // TODO: May need to handle some special nontile amx instrucion.
134   if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr())
135     return false;
136 
137   for (MachineOperand &MO : MI.operands())
138     if (isTilePhysReg(MO))
139       return true;
140 
141   return false;
142 }
143 
144 MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) {
145   auto Cfg = MachineBasicBlock::iterator(MI);
146   MachineBasicBlock *MBB = MI->getParent();
147   MachineInstr *KeyMI = nullptr;
148   int KeyAMXNum = 0;
149 
150   for (auto II = Cfg; II != MBB->end(); II++) {
151     if (isTileLoad(*II)) {
152       KeyMI = &*II;
153       continue;
154     }
155 
156     if (isTileStore(*II)) {
157       assert(KeyMI && "Key AMX Should be found before!");
158       break;
159     }
160 
161     if (isAMXInstr(*II)) {
162       assert((KeyAMXNum == 0) && "Too many Key AMX instruction!");
163       KeyAMXNum++;
164       KeyMI = &*II;
165     }
166   }
167   assert(KeyMI && "There must be an AMX instruction.");
168   return KeyMI;
169 }
170 
171 // Orderly get the tiles in key amx instruction, uses before defs.
172 void X86FastTileConfig::getTileShapesCfg(
173     MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) {
174   MachineInstr *KeyMI = getKeyAMXInstr(CfgMI);
175 
176   SmallVector<MachineOperand *> DefTiles;
177   for (MachineOperand &MO : KeyMI->operands()) {
178     if (!isTilePhysReg(MO))
179       continue;
180     if (MO.isDef())
181       DefTiles.push_back(&MO);
182     else
183       ShapedTiles.push_back(&MO);
184   }
185   ShapedTiles.append(DefTiles);
186 }
187 
188 // We pre-config the shapes at position named with "amx.tmm.N.shape.row* and
189 // amx.shape.N.col*" at pass "Pre AMX Tile Config".
190 // The 'N' implies the order of tiles in key amx intrinsic.
191 void X86FastTileConfig::getShapeCfgInstrs(
192     MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs,
193     std::map<unsigned, MachineInstr *> &ColCfgs) {
194   auto Cfg = MachineBasicBlock::iterator(MI);
195   MachineBasicBlock *MBB = MI->getParent();
196 
197   for (auto II = Cfg; II != MBB->begin(); II--) {
198     if (isAMXInstr(*II) || II->isTerminator() || II->isCall())
199       break;
200     if (!II->mayStore() || !II->hasOneMemOperand())
201       continue;
202     const Value *MemPtr = II->memoperands()[0]->getValue();
203     if (!MemPtr)
204       continue;
205 
206     StringRef Name = MemPtr->getName();
207     if (!Name.startswith("amx.tmm."))
208       continue;
209 
210     // Get the 'N'th tile shape config in key amx instruction.
211     auto N = Name.find(".shape");
212     StringRef STileIdx = Name.slice(8, N);
213     unsigned Idx;
214     STileIdx.getAsInteger(10, Idx);
215 
216     // And related them with their store instructions.
217     if (Name.contains("row"))
218       RowCfgs[Idx] = &*II;
219     else if (Name.contains("col"))
220       ColCfgs[Idx] = &*II;
221     else
222       llvm_unreachable("Invalid tile shape info!");
223   }
224   assert((RowCfgs.size() == ColCfgs.size()) &&
225          "The number of tile row and col must be equal!");
226 }
227 
228 // Here is the data format for the tile config.
229 // 0      palette   = 1 now.
230 // 1      start_row = 0 now.
231 // 2-15   reserved, must be zero
232 // 16-17  tile0.colsb Tile 0 bytes per row.
233 // 18-19  tile1.colsb Tile 1 bytes per row.
234 // 20-21  tile2.colsb Tile 2 bytes per row.
235 // ... (sequence continues)
236 // 30-31  tile7.colsb Tile 7 bytes per row.
237 // 32-47  reserved, must be zero
238 // 48     tile0.rows Tile 0 rows.
239 // 49     tile1.rows Tile 1 rows.
240 // 50     tile2.rows Tile 2 rows.
241 // ... (sequence continues)
242 // 55     tile7.rows Tile 7 rows.
243 // 56-63  reserved, must be zero
244 void X86FastTileConfig::rewriteTileCfg(
245     SmallVector<MachineOperand *> &ShapedTiles,
246     std::map<unsigned, MachineInstr *> &RowCfgs,
247     std::map<unsigned, MachineInstr *> &ColCfgs) {
248   assert((RowCfgs.size() == ShapedTiles.size()) &&
249          "The number of tile shapes not equal with the number of tiles!");
250 
251   // Orderly get the tiles and adjust the shape config.
252   for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) {
253     MachineOperand *MO = ShapedTiles[I];
254     unsigned TmmIdx = getTilePhysRegIdx(MO);
255     if (I == TmmIdx)
256       continue;
257     adjustRowCfg(TmmIdx, RowCfgs[I]);
258     adjustColCfg(TmmIdx, ColCfgs[I]);
259   }
260 }
261 
262 // We have already preconfig the shapes before fast register allocation at
263 // X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register
264 // allocation, the shapes pre-written before may not rightly corresponding
265 // to the correct tmm registers, so we need adjust them.
266 void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) {
267   SmallVector<MachineOperand *> ShapedTiles;
268   std::map<unsigned, MachineInstr *> RowCfgs;
269   std::map<unsigned, MachineInstr *> ColCfgs;
270 
271   // Orderly keep the tile uses and def in ShapedTiles;
272   getTileShapesCfg(CfgMI, ShapedTiles);
273   assert(ShapedTiles.size() && "Not find shapes config!");
274 
275   getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs);
276 
277   rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs);
278 }
279 
280 bool X86FastTileConfig::fastTileConfig() {
281   bool Changed = false;
282 
283   for (MachineBasicBlock &MBB : *MF) {
284     SmallVector<MachineInstr *, 2> CFGs;
285     for (MachineInstr &MI : MBB)
286       if (MI.getOpcode() == X86::PLDTILECFGV)
287         CFGs.push_back(&MI);
288     for (auto *MI : CFGs)
289       materializeTileCfg(MI);
290     if (!CFGs.empty())
291       Changed = true;
292   }
293   if (Changed)
294     X86FI->setHasVirtualTileReg(true);
295   return Changed;
296 }
297 
298 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
299   MF = &MFunc;
300   MRI = &MFunc.getRegInfo();
301   ST = &MFunc.getSubtarget<X86Subtarget>();
302   TRI = ST->getRegisterInfo();
303   TII = MFunc.getSubtarget().getInstrInfo();
304   X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
305 
306   return fastTileConfig();
307 }
308 
309 FunctionPass *llvm::createX86FastTileConfigPass() {
310   return new X86FastTileConfig();
311 }
312