1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. In X86PreTileConfig pass 11 /// the pldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after egister allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/CodeGen/LiveIntervals.h" 26 #include "llvm/CodeGen/MachineDominators.h" 27 #include "llvm/CodeGen/MachineFrameInfo.h" 28 #include "llvm/CodeGen/MachineFunctionPass.h" 29 #include "llvm/CodeGen/MachineInstr.h" 30 #include "llvm/CodeGen/MachineRegisterInfo.h" 31 #include "llvm/CodeGen/Passes.h" 32 #include "llvm/CodeGen/TargetInstrInfo.h" 33 #include "llvm/CodeGen/TargetRegisterInfo.h" 34 #include "llvm/CodeGen/TileShapeInfo.h" 35 #include "llvm/CodeGen/VirtRegMap.h" 36 #include "llvm/InitializePasses.h" 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "tile-config" 41 42 namespace { 43 44 class X86TileConfig : public MachineFunctionPass { 45 // context 46 MachineFunction *MF = nullptr; 47 const X86Subtarget *ST = nullptr; 48 const TargetRegisterInfo *TRI; 49 const TargetInstrInfo *TII; 50 MachineDominatorTree *DomTree = nullptr; 51 MachineRegisterInfo *MRI = nullptr; 52 VirtRegMap *VRM = nullptr; 53 LiveIntervals *LIS = nullptr; 54 55 MachineInstr *getTileConfigPoint(); 56 void tileConfig(); 57 58 public: 59 X86TileConfig() : MachineFunctionPass(ID) {} 60 61 /// Return the pass name. 62 StringRef getPassName() const override { return "Tile Register Configure"; } 63 64 /// X86TileConfig analysis usage. 65 void getAnalysisUsage(AnalysisUsage &AU) const override; 66 67 /// Perform register allocation. 68 bool runOnMachineFunction(MachineFunction &mf) override; 69 70 MachineFunctionProperties getRequiredProperties() const override { 71 return MachineFunctionProperties().set( 72 MachineFunctionProperties::Property::NoPHIs); 73 } 74 75 static char ID; 76 }; 77 78 } // end anonymous namespace 79 80 char X86TileConfig::ID = 0; 81 82 INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure", 83 false, false) 84 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 85 INITIALIZE_PASS_DEPENDENCY(VirtRegMap) 86 INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure", 87 false, false) 88 89 void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const { 90 AU.addRequired<MachineDominatorTree>(); 91 AU.addRequired<LiveIntervals>(); 92 AU.addPreserved<SlotIndexes>(); 93 AU.addRequired<VirtRegMap>(); 94 AU.setPreservesAll(); 95 MachineFunctionPass::getAnalysisUsage(AU); 96 } 97 98 static unsigned getTilePhysRegIndex(Register PhysReg) { 99 assert((PhysReg >= X86::TMM0 && X86::TMM0 <= X86::TMM7) && 100 "Tile register number is invalid"); 101 return (PhysReg - X86::TMM0); 102 } 103 104 static MachineInstr * 105 storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, 106 Register SrcReg, unsigned BitSize, int FrameIdx, int Offset, 107 const TargetInstrInfo *TII, const TargetRegisterClass *RC, 108 const TargetRegisterInfo *TRI) { 109 110 unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit; 111 unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr; 112 if (BitSize == TRI->getRegSizeInBits(*RC)) 113 SubIdx = 0; 114 MachineInstr *NewMI = 115 addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx, 116 Offset) 117 .addReg(SrcReg, 0, SubIdx); 118 return NewMI; 119 } 120 121 static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB, 122 MachineBasicBlock::iterator MI, 123 int64_t Imm, unsigned BitSize, 124 int FrameIdx, int Offset, 125 const TargetInstrInfo *TII) { 126 unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi; 127 return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), 128 FrameIdx, Offset) 129 .addImm(Imm); 130 } 131 132 MachineInstr *X86TileConfig::getTileConfigPoint() { 133 for (MachineBasicBlock &MBB : *MF) { 134 135 // Traverse the basic block. 136 for (MachineInstr &MI : MBB) 137 // Refer X86PreTileConfig.cpp. 138 // We only support one tile config for now. 139 if (MI.getOpcode() == X86::PLDTILECFG) 140 return &MI; 141 } 142 143 return nullptr; 144 } 145 146 void X86TileConfig::tileConfig() { 147 MachineInstr *MI = getTileConfigPoint(); 148 if (!MI) 149 return; 150 MachineBasicBlock *MBB = MI->getParent(); 151 int SS = MI->getOperand(1).getIndex(); 152 BitVector PhysRegs(TRI->getNumRegs()); 153 154 // Fill in the palette first. 155 auto *NewMI = storeImmToStackSlot(*MBB, *MI, 1, 8, SS, 0, TII); 156 LIS->InsertMachineInstrInMaps(*NewMI); 157 // Fill in the shape of each tile physical register. 158 for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) { 159 Register VirtReg = Register::index2VirtReg(i); 160 if (MRI->reg_nodbg_empty(VirtReg)) 161 continue; 162 const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg); 163 if (RC.getID() != X86::TILERegClassID) 164 continue; 165 Register PhysReg = VRM->getPhys(VirtReg); 166 if (PhysRegs.test(PhysReg)) 167 continue; 168 PhysRegs.set(PhysReg); 169 ShapeT Shape = VRM->getShape(VirtReg); 170 Register RowReg = Shape.getRow()->getReg(); 171 Register ColReg = Shape.getCol()->getReg(); 172 173 // Here is the data format for the tile config. 174 // 0 palette 175 // 1 start_row 176 // 2-15 reserved, must be zero 177 // 16-17 tile0.colsb Tile 0 bytes per row. 178 // 18-19 tile1.colsb Tile 1 bytes per row. 179 // 20-21 tile2.colsb Tile 2 bytes per row. 180 // ... (sequence continues) 181 // 30-31 tile7.colsb Tile 7 bytes per row. 182 // 32-47 reserved, must be zero 183 // 48 tile0.rows Tile 0 rows. 184 // 49 tile1.rows Tile 1 rows. 185 // 50 tile2.rows Tile 2 rows. 186 // ... (sequence continues) 187 // 55 tile7.rows Tile 7 rows. 188 // 56-63 reserved, must be zero 189 unsigned Index = getTilePhysRegIndex(PhysReg); 190 int RowOffset = 48 + Index; 191 int ColOffset = 16 + Index * 2; 192 193 unsigned BitSize = 8; 194 for (const auto &Pair : {std::make_pair(RowReg, RowOffset), 195 std::make_pair(ColReg, ColOffset)}) { 196 int64_t Imm; 197 int ImmCount = 0; 198 // All def must be the same value, otherwise it is invalid MIs. 199 // Immediate is prefered. 200 for (const MachineOperand &MO : MRI->def_operands(Pair.first)) { 201 const auto *Inst = MO.getParent(); 202 if (Inst->isMoveImmediate()) { 203 ImmCount++; 204 Imm = Inst->getOperand(1).getImm(); 205 break; 206 } 207 } 208 auto StoreConfig = [&](int Offset) { 209 MachineInstr *NewMI = nullptr; 210 if (ImmCount) 211 NewMI = storeImmToStackSlot(*MBB, *MI, Imm, BitSize, SS, Offset, TII); 212 else { 213 const TargetRegisterClass *RC = MRI->getRegClass(Pair.first); 214 NewMI = storeRegToStackSlot(*MBB, *MI, Pair.first, BitSize, SS, 215 Offset, TII, RC, TRI); 216 } 217 SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI); 218 if (!ImmCount) { 219 // Extend the live interval. 220 SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()}; 221 LiveInterval &Int = LIS->getInterval(Pair.first); 222 LIS->extendToIndices(Int, EndPoints); 223 } 224 }; 225 StoreConfig(Pair.second); 226 BitSize += 8; 227 } 228 } 229 } 230 231 bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) { 232 MF = &mf; 233 MRI = &mf.getRegInfo(); 234 ST = &mf.getSubtarget<X86Subtarget>(); 235 TRI = ST->getRegisterInfo(); 236 TII = mf.getSubtarget().getInstrInfo(); 237 DomTree = &getAnalysis<MachineDominatorTree>(); 238 VRM = &getAnalysis<VirtRegMap>(); 239 LIS = &getAnalysis<LiveIntervals>(); 240 241 if (VRM->isShapeMapEmpty()) 242 return false; 243 244 tileConfig(); 245 return true; 246 } 247 248 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); } 249