1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. Before FastRegAllocation pass 11 /// the ldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after register allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86Subtarget.h" 24 #include "llvm/CodeGen/MachineFrameInfo.h" 25 #include "llvm/CodeGen/MachineFunctionPass.h" 26 #include "llvm/CodeGen/MachineInstr.h" 27 #include "llvm/CodeGen/MachineRegisterInfo.h" 28 #include "llvm/CodeGen/Passes.h" 29 #include "llvm/CodeGen/TargetInstrInfo.h" 30 #include "llvm/CodeGen/TargetRegisterInfo.h" 31 32 using namespace llvm; 33 34 #define DEBUG_TYPE "fasttileconfig" 35 36 namespace { 37 38 class X86FastTileConfig : public MachineFunctionPass { 39 // context 40 MachineFunction *MF = nullptr; 41 const TargetInstrInfo *TII = nullptr; 42 MachineRegisterInfo *MRI = nullptr; 43 const TargetRegisterInfo *TRI = nullptr; 44 X86MachineFunctionInfo *X86FI = nullptr; 45 46 bool configBasicBlock(MachineBasicBlock &MBB); 47 48 public: 49 X86FastTileConfig() : MachineFunctionPass(ID) {} 50 51 /// Return the pass name. 52 StringRef getPassName() const override { 53 return "Fast Tile Register Configure"; 54 } 55 56 void getAnalysisUsage(AnalysisUsage &AU) const override { 57 AU.setPreservesAll(); 58 MachineFunctionPass::getAnalysisUsage(AU); 59 } 60 61 /// Perform register allocation. 62 bool runOnMachineFunction(MachineFunction &MFunc) override; 63 64 MachineFunctionProperties getRequiredProperties() const override { 65 return MachineFunctionProperties().setNoPHIs(); 66 } 67 68 static char ID; 69 }; 70 71 } // end anonymous namespace 72 73 char X86FastTileConfig::ID = 0; 74 75 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, 76 "Fast Tile Register Configure", false, false) 77 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, 78 "Fast Tile Register Configure", false, false) 79 80 static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { 81 // There is no phi instruction after register allocation. 82 assert(MI.isPHI() == false); 83 // The instruction must have 3 operands: tile def, row, col. 84 // It should be AMX pseudo instruction that have shape operand. 85 if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 || 86 !MI.isPseudo()) 87 return 0; 88 MachineOperand &MO = MI.getOperand(0); 89 90 if (MO.isReg()) { 91 Register Reg = MO.getReg(); 92 // FIXME: It may be used after Greedy RA and the physical 93 // register is not rewritten yet. 94 if (Reg.isVirtual()) { 95 if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID) 96 return 1; 97 if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID) 98 return 2; 99 } 100 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 101 return 1; 102 if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) 103 return 2; 104 } 105 106 return 0; 107 } 108 109 static unsigned getTMMIndex(Register Reg) { 110 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 111 return Reg - X86::TMM0; 112 if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) 113 return (Reg - X86::TMM0_TMM1) * 2; 114 llvm_unreachable("Invalid Tmm Reg!"); 115 } 116 117 // PreTileConfig should configure the tile registers based on basic 118 // block. 119 bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) { 120 bool Change = false; 121 SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos; 122 for (MachineInstr &MI : reverse(MBB)) { 123 unsigned DefNum = getNumDefTiles(MRI, MI); 124 if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV) 125 continue; 126 // AMX instructions that define tile register. 127 if (MI.getOpcode() != X86::PLDTILECFGV) { 128 MachineOperand &Row = MI.getOperand(1); 129 unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg()); 130 for (unsigned I = 0; I < DefNum; I++) { 131 MachineOperand &Col = MI.getOperand(2 + I); 132 ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)}); 133 } 134 } else { // PLDTILECFGV 135 // Rewrite the shape information to memory. Stack slot should have 136 // been initialized to zero in pre config. 137 int SS = MI.getOperand(0).getIndex(); // tile config stack slot. 138 for (auto &ShapeInfo : ShapeInfos) { 139 DebugLoc DL; 140 unsigned TMMIdx = ShapeInfo.first; 141 Register RowReg = ShapeInfo.second.getRow()->getReg(); 142 Register ColReg = ShapeInfo.second.getCol()->getReg(); 143 // Here is the data format for the tile config. 144 // 0 palette 145 // 1 start_row 146 // 2-15 reserved, must be zero 147 // 16-17 tile0.colsb Tile 0 bytes per row. 148 // 18-19 tile1.colsb Tile 1 bytes per row. 149 // 20-21 tile2.colsb Tile 2 bytes per row. 150 // ... (sequence continues) 151 // 30-31 tile7.colsb Tile 7 bytes per row. 152 // 32-47 reserved, must be zero 153 // 48 tile0.rows Tile 0 rows. 154 // 49 tile1.rows Tile 1 rows. 155 // 50 tile2.rows Tile 2 rows. 156 // ... (sequence continues) 157 // 55 tile7.rows Tile 7 rows. 158 // 56-63 reserved, must be zero 159 int RowOffset = 48 + TMMIdx; 160 int ColOffset = 16 + TMMIdx * 2; 161 162 Register SubRowReg = TRI->getSubReg(RowReg, X86::sub_8bit); 163 BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), SubRowReg); 164 MachineInstrBuilder StoreRow = 165 BuildMI(MBB, MI, DL, TII->get(X86::MOV8mr)); 166 addFrameReference(StoreRow, SS, RowOffset).addReg(SubRowReg); 167 168 MachineInstrBuilder StoreCol = 169 BuildMI(MBB, MI, DL, TII->get(X86::MOV16mr)); 170 addFrameReference(StoreCol, SS, ColOffset).addReg(ColReg); 171 } 172 ShapeInfos.clear(); 173 Change = true; 174 } 175 } 176 177 return Change; 178 } 179 180 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { 181 X86FI = MFunc.getInfo<X86MachineFunctionInfo>(); 182 // Early exit in the common case of non-AMX code. 183 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA) 184 return false; 185 186 MF = &MFunc; 187 MRI = &MFunc.getRegInfo(); 188 const TargetSubtargetInfo *ST = &MFunc.getSubtarget<X86Subtarget>(); 189 TRI = ST->getRegisterInfo(); 190 TII = MFunc.getSubtarget().getInstrInfo(); 191 bool Change = false; 192 193 // Loop over all of the basic blocks, eliminating virtual register references 194 for (MachineBasicBlock &MBB : MFunc) 195 Change |= configBasicBlock(MBB); 196 197 return Change; 198 } 199 200 FunctionPass *llvm::createX86FastTileConfigPass() { 201 return new X86FastTileConfig(); 202 } 203