xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86TileConfig.cpp (revision 2e3507c25e42292b45a5482e116d278f5515d04d)
1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file
/// Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction; however, at that point we don't know
/// the shape of each physical tile register, because register allocation is
/// not done yet. This pass runs after the register allocation pass. It
/// collects the shape information of each physical tile register and stores
/// the shape in the stack slot that is allocated for loading the tile config
/// register.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86RegisterInfo.h"
24 #include "X86Subtarget.h"
25 #include "llvm/CodeGen/LiveIntervals.h"
26 #include "llvm/CodeGen/MachineFrameInfo.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineInstr.h"
29 #include "llvm/CodeGen/MachineRegisterInfo.h"
30 #include "llvm/CodeGen/Passes.h"
31 #include "llvm/CodeGen/TargetInstrInfo.h"
32 #include "llvm/CodeGen/TargetRegisterInfo.h"
33 #include "llvm/CodeGen/TileShapeInfo.h"
34 #include "llvm/CodeGen/VirtRegMap.h"
35 #include "llvm/InitializePasses.h"
36 
37 using namespace llvm;
38 
39 #define DEBUG_TYPE "tileconfig"
40 
41 namespace {
42 
43 struct X86TileConfig : public MachineFunctionPass {
44 
45   X86TileConfig() : MachineFunctionPass(ID) {}
46 
47   /// Return the pass name.
48   StringRef getPassName() const override { return "Tile Register Configure"; }
49 
50   /// X86TileConfig analysis usage.
51   void getAnalysisUsage(AnalysisUsage &AU) const override {
52     AU.setPreservesAll();
53     AU.addRequired<VirtRegMap>();
54     AU.addRequired<LiveIntervals>();
55     MachineFunctionPass::getAnalysisUsage(AU);
56   }
57 
58   /// Perform register allocation.
59   bool runOnMachineFunction(MachineFunction &mf) override;
60 
61   MachineFunctionProperties getRequiredProperties() const override {
62     return MachineFunctionProperties().set(
63         MachineFunctionProperties::Property::NoPHIs);
64   }
65 
66   static char ID;
67 };
68 
69 } // end anonymous namespace
70 
71 char X86TileConfig::ID = 0;
72 
73 INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure",
74                       false, false)
75 INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
76 INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
77                     false)
78 
79 bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
80   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
81   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
82   const TargetInstrInfo *TII = ST.getInstrInfo();
83   MachineRegisterInfo &MRI = MF.getRegInfo();
84   LiveIntervals &LIS = getAnalysis<LiveIntervals>();
85   VirtRegMap &VRM = getAnalysis<VirtRegMap>();
86 
87   if (VRM.isShapeMapEmpty())
88     return false;
89 
90   int SS = INT_MAX;
91   for (MachineBasicBlock &MBB : MF) {
92     for (MachineInstr &MI : MBB) {
93       if (MI.getOpcode() == X86::PLDTILECFGV) {
94         SS = MI.getOperand(0).getIndex();
95         break;
96       }
97     }
98     if (SS != INT_MAX)
99       break;
100   }
101   // Didn't find PLDTILECFGV, just return false;
102   if (SS == INT_MAX)
103     return false;
104 
105   // Try to find a point to insert MIs for constant shapes.
106   // Here we are leveraging the palette id inserted in PreRA pass.
107   unsigned ConstPos = 0;
108   MachineInstr *ConstMI = nullptr;
109   for (MachineInstr &MI : MF.front()) {
110     if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) {
111       ConstMI = &MI;
112       break;
113     }
114     ++ConstPos;
115   }
116   assert(ConstMI && "Cannot find an insertion point");
117 
118   unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
119   SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
120   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
121     Register VirtReg = Register::index2VirtReg(I);
122     if (MRI.reg_nodbg_empty(VirtReg))
123       continue;
124     if (MRI.getRegClass(VirtReg)->getID() != X86::TILERegClassID)
125       continue;
126     if (VRM.getPhys(VirtReg) == VirtRegMap::NO_PHYS_REG)
127       continue;
128     unsigned Index = VRM.getPhys(VirtReg) - X86::TMM0;
129     if (!Phys2Virt[Index])
130       Phys2Virt[Index] = VirtReg;
131   }
132 
133   // Fill in the shape of each tile physical register.
134   for (unsigned I = 0; I < AMXRegNum; ++I) {
135     if (!Phys2Virt[I])
136       continue;
137     DebugLoc DL;
138     bool IsRow = true;
139     MachineInstr *NewMI = nullptr;
140     ShapeT Shape = VRM.getShape(Phys2Virt[I]);
141     for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
142       // Here is the data format for the tile config.
143       // 0      palette
144       // 1      start_row
145       // 2-15   reserved, must be zero
146       // 16-17  tile0.colsb Tile 0 bytes per row.
147       // 18-19  tile1.colsb Tile 1 bytes per row.
148       // 20-21  tile2.colsb Tile 2 bytes per row.
149       // ... (sequence continues)
150       // 30-31  tile7.colsb Tile 7 bytes per row.
151       // 32-47  reserved, must be zero
152       // 48     tile0.rows Tile 0 rows.
153       // 49     tile1.rows Tile 1 rows.
154       // 50     tile2.rows Tile 2 rows.
155       // ... (sequence continues)
156       // 55     tile7.rows Tile 7 rows.
157       // 56-63  reserved, must be zero
158       int64_t Imm = INT64_MAX;
159       int Offset = IsRow ? 48 + I : 16 + I * 2;
160       for (auto &DefMI : MRI.def_instructions(R)) {
161         MachineBasicBlock &MBB = *DefMI.getParent();
162         if (DefMI.isMoveImmediate()) {
163           if (Imm != INT64_MAX) {
164             // FIXME: We should handle this case in future.
165             assert(Imm == DefMI.getOperand(1).getImm() &&
166                    "Cannot initialize with different shapes");
167             continue;
168           }
169           Imm = DefMI.getOperand(1).getImm();
170           NewMI = addFrameReference(
171                       BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
172                               TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)),
173                       SS, Offset)
174                       .addImm(Imm);
175           ConstMI = NewMI;
176           LIS.InsertMachineInstrInMaps(*NewMI);
177         } else {
178           unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
179           unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R));
180           if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
181             SubIdx = 0;
182           auto Iter = DefMI.getIterator();
183           if (&MBB == &MF.front() &&
184               (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos)
185             Iter = ConstMI->getIterator();
186           NewMI = addFrameReference(
187                       BuildMI(MBB, ++Iter, DL,
188                               TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)),
189                       SS, Offset)
190                       .addReg(R, 0, SubIdx);
191           SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI);
192           LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()});
193         }
194       }
195       IsRow = false;
196     }
197   }
198   return true;
199 }
200 
201 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
202