xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86FastTileConfig.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The ldtilecfg instruction
/// is inserted before the FastRegAllocation pass runs, but at that point the
/// shape of each physical tile register is not known, because register
/// allocation has not been done yet. This pass runs after register
/// allocation. It collects the shape information of each physical tile
/// register and stores it in the stack slot from which the tile
/// configuration is loaded into the tile config register.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86Subtarget.h"
24 #include "llvm/CodeGen/MachineFrameInfo.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineInstr.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/Passes.h"
29 #include "llvm/CodeGen/TargetInstrInfo.h"
30 #include "llvm/CodeGen/TargetRegisterInfo.h"
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "fasttileconfig"
35 
36 namespace {
37 
38 class X86FastTileConfig : public MachineFunctionPass {
39   // context
40   MachineFunction *MF = nullptr;
41   const TargetInstrInfo *TII = nullptr;
42   MachineRegisterInfo *MRI = nullptr;
43   const TargetRegisterInfo *TRI = nullptr;
44   X86MachineFunctionInfo *X86FI = nullptr;
45 
46   bool configBasicBlock(MachineBasicBlock &MBB);
47 
48 public:
49   X86FastTileConfig() : MachineFunctionPass(ID) {}
50 
51   /// Return the pass name.
52   StringRef getPassName() const override {
53     return "Fast Tile Register Configure";
54   }
55 
56   void getAnalysisUsage(AnalysisUsage &AU) const override {
57     AU.setPreservesAll();
58     MachineFunctionPass::getAnalysisUsage(AU);
59   }
60 
61   /// Perform register allocation.
62   bool runOnMachineFunction(MachineFunction &MFunc) override;
63 
64   MachineFunctionProperties getRequiredProperties() const override {
65     return MachineFunctionProperties().setNoPHIs();
66   }
67 
68   static char ID;
69 };
70 
71 } // end anonymous namespace
72 
73 char X86FastTileConfig::ID = 0;
74 
75 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
76                       "Fast Tile Register Configure", false, false)
77 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
78                     "Fast Tile Register Configure", false, false)
79 
80 static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) {
81   // There is no phi instruction after register allocation.
82   assert(MI.isPHI() == false);
83   // The instruction must have 3 operands: tile def, row, col.
84   // It should be AMX pseudo instruction that have shape operand.
85   if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 ||
86       !MI.isPseudo())
87     return 0;
88   MachineOperand &MO = MI.getOperand(0);
89 
90   if (MO.isReg()) {
91     Register Reg = MO.getReg();
92     // FIXME: It may be used after Greedy RA and the physical
93     // register is not rewritten yet.
94     if (Reg.isVirtual()) {
95       if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
96         return 1;
97       if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID)
98         return 2;
99     }
100     if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
101       return 1;
102     if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
103       return 2;
104   }
105 
106   return 0;
107 }
108 
109 static unsigned getTMMIndex(Register Reg) {
110   if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
111     return Reg - X86::TMM0;
112   if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
113     return (Reg - X86::TMM0_TMM1) * 2;
114   llvm_unreachable("Invalid Tmm Reg!");
115 }
116 
117 // PreTileConfig should configure the tile registers based on basic
118 // block.
119 bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
120   bool Change = false;
121   SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos;
122   for (MachineInstr &MI : reverse(MBB)) {
123     unsigned DefNum = getNumDefTiles(MRI, MI);
124     if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV)
125       continue;
126     // AMX instructions that define tile register.
127     if (MI.getOpcode() != X86::PLDTILECFGV) {
128       MachineOperand &Row = MI.getOperand(1);
129       unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg());
130       for (unsigned I = 0; I < DefNum; I++) {
131         MachineOperand &Col = MI.getOperand(2 + I);
132         ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)});
133       }
134     } else { // PLDTILECFGV
135       // Rewrite the shape information to memory. Stack slot should have
136       // been initialized to zero in pre config.
137       int SS = MI.getOperand(0).getIndex(); // tile config stack slot.
138       for (auto &ShapeInfo : ShapeInfos) {
139         DebugLoc DL;
140         unsigned TMMIdx = ShapeInfo.first;
141         Register RowReg = ShapeInfo.second.getRow()->getReg();
142         Register ColReg = ShapeInfo.second.getCol()->getReg();
143         // Here is the data format for the tile config.
144         // 0      palette
145         // 1      start_row
146         // 2-15   reserved, must be zero
147         // 16-17  tile0.colsb Tile 0 bytes per row.
148         // 18-19  tile1.colsb Tile 1 bytes per row.
149         // 20-21  tile2.colsb Tile 2 bytes per row.
150         // ... (sequence continues)
151         // 30-31  tile7.colsb Tile 7 bytes per row.
152         // 32-47  reserved, must be zero
153         // 48     tile0.rows Tile 0 rows.
154         // 49     tile1.rows Tile 1 rows.
155         // 50     tile2.rows Tile 2 rows.
156         // ... (sequence continues)
157         // 55     tile7.rows Tile 7 rows.
158         // 56-63  reserved, must be zero
159         int RowOffset = 48 + TMMIdx;
160         int ColOffset = 16 + TMMIdx * 2;
161 
162         Register SubRowReg = TRI->getSubReg(RowReg, X86::sub_8bit);
163         BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), SubRowReg);
164         MachineInstrBuilder StoreRow =
165             BuildMI(MBB, MI, DL, TII->get(X86::MOV8mr));
166         addFrameReference(StoreRow, SS, RowOffset).addReg(SubRowReg);
167 
168         MachineInstrBuilder StoreCol =
169             BuildMI(MBB, MI, DL, TII->get(X86::MOV16mr));
170         addFrameReference(StoreCol, SS, ColOffset).addReg(ColReg);
171       }
172       ShapeInfos.clear();
173       Change = true;
174     }
175   }
176 
177   return Change;
178 }
179 
180 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
181   X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
182   // Early exit in the common case of non-AMX code.
183   if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
184     return false;
185 
186   MF = &MFunc;
187   MRI = &MFunc.getRegInfo();
188   const TargetSubtargetInfo *ST = &MFunc.getSubtarget<X86Subtarget>();
189   TRI = ST->getRegisterInfo();
190   TII = MFunc.getSubtarget().getInstrInfo();
191   bool Change = false;
192 
193   // Loop over all of the basic blocks, eliminating virtual register references
194   for (MachineBasicBlock &MBB : MFunc)
195     Change |= configBasicBlock(MBB);
196 
197   return Change;
198 }
199 
200 FunctionPass *llvm::createX86FastTileConfigPass() {
201   return new X86FastTileConfig();
202 }
203