1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. In X86PreTileConfig pass 11 /// the pldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after egister allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/CodeGen/LiveIntervals.h" 26 #include "llvm/CodeGen/MachineFrameInfo.h" 27 #include "llvm/CodeGen/MachineFunctionPass.h" 28 #include "llvm/CodeGen/MachineInstr.h" 29 #include "llvm/CodeGen/MachineRegisterInfo.h" 30 #include "llvm/CodeGen/Passes.h" 31 #include "llvm/CodeGen/TargetInstrInfo.h" 32 #include "llvm/CodeGen/TargetRegisterInfo.h" 33 #include "llvm/CodeGen/TileShapeInfo.h" 34 #include "llvm/CodeGen/VirtRegMap.h" 35 #include "llvm/InitializePasses.h" 36 37 using namespace llvm; 38 39 #define DEBUG_TYPE "tileconfig" 40 41 namespace { 42 43 struct X86TileConfig : public MachineFunctionPass { 44 45 X86TileConfig() : MachineFunctionPass(ID) {} 46 47 /// Return the pass name. 48 StringRef getPassName() const override { return "Tile Register Configure"; } 49 50 /// X86TileConfig analysis usage. 51 void getAnalysisUsage(AnalysisUsage &AU) const override { 52 AU.setPreservesAll(); 53 AU.addRequired<VirtRegMap>(); 54 AU.addRequired<LiveIntervalsWrapperPass>(); 55 MachineFunctionPass::getAnalysisUsage(AU); 56 } 57 58 /// Perform register allocation. 59 bool runOnMachineFunction(MachineFunction &mf) override; 60 61 MachineFunctionProperties getRequiredProperties() const override { 62 return MachineFunctionProperties().set( 63 MachineFunctionProperties::Property::NoPHIs); 64 } 65 66 static char ID; 67 }; 68 69 } // end anonymous namespace 70 71 char X86TileConfig::ID = 0; 72 73 INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", 74 false, false) 75 INITIALIZE_PASS_DEPENDENCY(VirtRegMap) 76 INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false, 77 false) 78 79 bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { 80 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); 81 // Early exit in the common case of non-AMX code. 82 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA) 83 return false; 84 85 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); 86 const TargetRegisterInfo *TRI = ST.getRegisterInfo(); 87 const TargetInstrInfo *TII = ST.getInstrInfo(); 88 MachineRegisterInfo &MRI = MF.getRegInfo(); 89 LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); 90 VirtRegMap &VRM = getAnalysis<VirtRegMap>(); 91 92 if (VRM.isShapeMapEmpty()) 93 return false; 94 95 int SS = INT_MAX; 96 for (MachineBasicBlock &MBB : MF) { 97 for (MachineInstr &MI : MBB) { 98 if (MI.getOpcode() == X86::PLDTILECFGV) { 99 SS = MI.getOperand(0).getIndex(); 100 break; 101 } 102 } 103 if (SS != INT_MAX) 104 break; 105 } 106 // Didn't find PLDTILECFGV, just return false; 107 if (SS == INT_MAX) 108 return false; 109 110 // Try to find a point to insert MIs for constant shapes. 111 // Here we are leveraging the palette id inserted in PreRA pass. 112 unsigned ConstPos = 0; 113 MachineInstr *ConstMI = nullptr; 114 for (MachineInstr &MI : MF.front()) { 115 if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) { 116 ConstMI = &MI; 117 break; 118 } 119 ++ConstPos; 120 } 121 assert(ConstMI && "Cannot find an insertion point"); 122 123 unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs(); 124 SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0); 125 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 126 Register VirtReg = Register::index2VirtReg(I); 127 if (MRI.reg_nodbg_empty(VirtReg)) 128 continue; 129 if (MRI.getRegClass(VirtReg)->getID() != X86::TILERegClassID) 130 continue; 131 if (VRM.getPhys(VirtReg) == VirtRegMap::NO_PHYS_REG) 132 continue; 133 unsigned Index = VRM.getPhys(VirtReg) - X86::TMM0; 134 if (!Phys2Virt[Index]) 135 Phys2Virt[Index] = VirtReg; 136 } 137 138 // Fill in the shape of each tile physical register. 139 for (unsigned I = 0; I < AMXRegNum; ++I) { 140 if (!Phys2Virt[I]) 141 continue; 142 DebugLoc DL; 143 bool IsRow = true; 144 MachineInstr *NewMI = nullptr; 145 ShapeT Shape = VRM.getShape(Phys2Virt[I]); 146 for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) { 147 // Here is the data format for the tile config. 148 // 0 palette 149 // 1 start_row 150 // 2-15 reserved, must be zero 151 // 16-17 tile0.colsb Tile 0 bytes per row. 152 // 18-19 tile1.colsb Tile 1 bytes per row. 153 // 20-21 tile2.colsb Tile 2 bytes per row. 154 // ... (sequence continues) 155 // 30-31 tile7.colsb Tile 7 bytes per row. 156 // 32-47 reserved, must be zero 157 // 48 tile0.rows Tile 0 rows. 158 // 49 tile1.rows Tile 1 rows. 159 // 50 tile2.rows Tile 2 rows. 160 // ... (sequence continues) 161 // 55 tile7.rows Tile 7 rows. 162 // 56-63 reserved, must be zero 163 int64_t Imm = INT64_MAX; 164 int Offset = IsRow ? 48 + I : 16 + I * 2; 165 for (auto &DefMI : MRI.def_instructions(R)) { 166 MachineBasicBlock &MBB = *DefMI.getParent(); 167 if (DefMI.isMoveImmediate()) { 168 if (Imm != INT64_MAX) { 169 // FIXME: We should handle this case in future. 170 assert(Imm == DefMI.getOperand(1).getImm() && 171 "Cannot initialize with different shapes"); 172 continue; 173 } 174 Imm = DefMI.getOperand(1).getImm(); 175 NewMI = addFrameReference( 176 BuildMI(MF.front(), ++ConstMI->getIterator(), DL, 177 TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)), 178 SS, Offset) 179 .addImm(Imm); 180 ConstMI = NewMI; 181 LIS.InsertMachineInstrInMaps(*NewMI); 182 } else { 183 unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit; 184 unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R)); 185 if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16)) 186 SubIdx = 0; 187 auto Iter = DefMI.getIterator(); 188 if (&MBB == &MF.front() && 189 (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos) 190 Iter = ConstMI->getIterator(); 191 NewMI = addFrameReference( 192 BuildMI(MBB, ++Iter, DL, 193 TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)), 194 SS, Offset) 195 .addReg(R, 0, SubIdx); 196 SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI); 197 LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()}); 198 } 199 } 200 IsRow = false; 201 } 202 } 203 return true; 204 } 205 206 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); } 207