1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. Before FastRegAllocation pass 11 /// the ldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after register allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/CodeGen/MachineFrameInfo.h" 26 #include "llvm/CodeGen/MachineFunctionPass.h" 27 #include "llvm/CodeGen/MachineInstr.h" 28 #include "llvm/CodeGen/MachineRegisterInfo.h" 29 #include "llvm/CodeGen/Passes.h" 30 #include "llvm/CodeGen/TargetInstrInfo.h" 31 #include "llvm/CodeGen/TargetRegisterInfo.h" 32 #include "llvm/InitializePasses.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "fasttileconfig" 37 38 namespace { 39 40 class X86FastTileConfig : public MachineFunctionPass { 41 // context 42 MachineFunction *MF = nullptr; 43 const X86Subtarget *ST = nullptr; 44 const TargetRegisterInfo *TRI = nullptr; 45 const TargetInstrInfo *TII = nullptr; 46 MachineRegisterInfo *MRI = nullptr; 47 48 MachineInstr *getTileConfigPoint(); 49 void tileConfig(); 50 51 public: 52 X86FastTileConfig() : MachineFunctionPass(ID) {} 53 54 bool fastTileConfig(); 55 bool isTileLoad(MachineInstr &MI); 56 bool isTileStore(MachineInstr &MI); 57 bool isAMXInstr(MachineInstr &MI); 58 void getTileStoreShape(MachineInstr &MI, 59 SmallVector<MachineOperand *> &ShapedTiles); 60 61 MachineInstr *getKeyAMXInstr(MachineInstr *MI); 62 void getTileShapesCfg(MachineInstr *MI, 63 SmallVector<MachineOperand *> &ShapedTiles); 64 void getShapeCfgInstrs(MachineInstr *MI, 65 std::map<unsigned, MachineInstr *> &RowCfgs, 66 std::map<unsigned, MachineInstr *> &ColCfgs); 67 68 /// Return the pass name. 69 StringRef getPassName() const override { 70 return "Fast Tile Register Configure"; 71 } 72 73 void materializeTileCfg(MachineInstr *MI); 74 75 void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles, 76 std::map<unsigned, MachineInstr *> &RowCfgs, 77 std::map<unsigned, MachineInstr *> &ColCfgs); 78 79 /// Perform register allocation. 80 bool runOnMachineFunction(MachineFunction &MFunc) override; 81 82 MachineFunctionProperties getRequiredProperties() const override { 83 return MachineFunctionProperties().set( 84 MachineFunctionProperties::Property::NoPHIs); 85 } 86 87 static char ID; 88 }; 89 90 } // end anonymous namespace 91 92 char X86FastTileConfig::ID = 0; 93 94 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, 95 "Fast Tile Register Configure", false, false) 96 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, 97 "Fast Tile Register Configure", false, false) 98 99 static bool isTilePhysReg(MachineOperand &Op) { 100 if (!Op.isReg()) 101 return false; 102 103 Register Reg = Op.getReg(); 104 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 105 return true; 106 return false; 107 } 108 109 static unsigned getTilePhysRegIdx(MachineOperand *Op) { 110 assert(isTilePhysReg(*Op) && "Tile Operand is invalid"); 111 return Op->getReg() - X86::TMM0; 112 } 113 114 static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) { 115 unsigned Offset = 48 + TIdx; 116 MI->getOperand(3).ChangeToImmediate(Offset); 117 } 118 119 static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) { 120 unsigned Offset = 16 + TIdx * 2; 121 MI->getOperand(3).ChangeToImmediate(Offset); 122 } 123 124 bool X86FastTileConfig::isTileLoad(MachineInstr &MI) { 125 return MI.getOpcode() == X86::PTILELOADDV || 126 MI.getOpcode() == X86::PTILELOADDT1V; 127 } 128 bool X86FastTileConfig::isTileStore(MachineInstr &MI) { 129 return MI.getOpcode() == X86::PTILESTOREDV; 130 } 131 bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) { 132 // TODO: May need to handle some special nontile amx instrucion. 133 if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr()) 134 return false; 135 136 for (MachineOperand &MO : MI.operands()) 137 if (isTilePhysReg(MO)) 138 return true; 139 140 return false; 141 } 142 143 MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) { 144 auto Cfg = MachineBasicBlock::iterator(MI); 145 MachineBasicBlock *MBB = MI->getParent(); 146 MachineInstr *KeyMI = nullptr; 147 int KeyAMXNum = 0; 148 149 for (auto II = Cfg; II != MBB->end(); II++) { 150 if (isTileLoad(*II)) { 151 KeyMI = &*II; 152 continue; 153 } 154 155 if (isTileStore(*II)) { 156 assert(KeyMI && "Key AMX Should be found before!"); 157 break; 158 } 159 160 if (isAMXInstr(*II)) { 161 assert((KeyAMXNum == 0) && "Too many Key AMX instruction!"); 162 KeyAMXNum++; 163 KeyMI = &*II; 164 } 165 } 166 assert(KeyMI && "There must be an AMX instruction."); 167 return KeyMI; 168 } 169 170 // Orderly get the tiles in key amx instruction, uses before defs. 171 void X86FastTileConfig::getTileShapesCfg( 172 MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) { 173 MachineInstr *KeyMI = getKeyAMXInstr(CfgMI); 174 175 SmallVector<MachineOperand *> DefTiles; 176 for (MachineOperand &MO : KeyMI->operands()) { 177 if (!isTilePhysReg(MO)) 178 continue; 179 if (MO.isDef()) 180 DefTiles.push_back(&MO); 181 else 182 ShapedTiles.push_back(&MO); 183 } 184 ShapedTiles.append(DefTiles); 185 } 186 187 // We pre-config the shapes at position named with "amx.tmm.N.shape.row* and 188 // amx.shape.N.col*" at pass "Pre AMX Tile Config". 189 // The 'N' implies the order of tiles in key amx intrinsic. 190 void X86FastTileConfig::getShapeCfgInstrs( 191 MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs, 192 std::map<unsigned, MachineInstr *> &ColCfgs) { 193 auto Cfg = MachineBasicBlock::iterator(MI); 194 MachineBasicBlock *MBB = MI->getParent(); 195 196 for (auto II = Cfg; II != MBB->begin(); II--) { 197 if (isAMXInstr(*II) || II->isTerminator() || II->isCall()) 198 break; 199 if (!II->mayStore() || !II->hasOneMemOperand()) 200 continue; 201 const Value *MemPtr = II->memoperands()[0]->getValue(); 202 if (!MemPtr) 203 continue; 204 205 StringRef Name = MemPtr->getName(); 206 if (!Name.startswith("amx.tmm.")) 207 continue; 208 209 // Get the 'N'th tile shape config in key amx instruction. 210 auto N = Name.find(".shape"); 211 StringRef STileIdx = Name.slice(8, N); 212 unsigned Idx; 213 STileIdx.getAsInteger(10, Idx); 214 215 // And related them with their store instructions. 216 if (Name.contains("row")) 217 RowCfgs[Idx] = &*II; 218 else if (Name.contains("col")) 219 ColCfgs[Idx] = &*II; 220 else 221 llvm_unreachable("Invalid tile shape info!"); 222 } 223 assert((RowCfgs.size() == ColCfgs.size()) && 224 "The number of tile row and col must be equal!"); 225 } 226 227 // Here is the data format for the tile config. 228 // 0 palette = 1 now. 229 // 1 start_row = 0 now. 230 // 2-15 reserved, must be zero 231 // 16-17 tile0.colsb Tile 0 bytes per row. 232 // 18-19 tile1.colsb Tile 1 bytes per row. 233 // 20-21 tile2.colsb Tile 2 bytes per row. 234 // ... (sequence continues) 235 // 30-31 tile7.colsb Tile 7 bytes per row. 236 // 32-47 reserved, must be zero 237 // 48 tile0.rows Tile 0 rows. 238 // 49 tile1.rows Tile 1 rows. 239 // 50 tile2.rows Tile 2 rows. 240 // ... (sequence continues) 241 // 55 tile7.rows Tile 7 rows. 242 // 56-63 reserved, must be zero 243 void X86FastTileConfig::rewriteTileCfg( 244 SmallVector<MachineOperand *> &ShapedTiles, 245 std::map<unsigned, MachineInstr *> &RowCfgs, 246 std::map<unsigned, MachineInstr *> &ColCfgs) { 247 assert((RowCfgs.size() == ShapedTiles.size()) && 248 "The number of tile shapes not equal with the number of tiles!"); 249 250 // Orderly get the tiles and adjust the shape config. 251 for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) { 252 MachineOperand *MO = ShapedTiles[I]; 253 unsigned TmmIdx = getTilePhysRegIdx(MO); 254 if (I == TmmIdx) 255 continue; 256 adjustRowCfg(TmmIdx, RowCfgs[I]); 257 adjustColCfg(TmmIdx, ColCfgs[I]); 258 } 259 } 260 261 // We have already preconfig the shapes before fast register allocation at 262 // X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register 263 // allocation, the shapes pre-written before may not rightly corresponding 264 // to the correct tmm registers, so we need adjust them. 265 void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) { 266 SmallVector<MachineOperand *> ShapedTiles; 267 std::map<unsigned, MachineInstr *> RowCfgs; 268 std::map<unsigned, MachineInstr *> ColCfgs; 269 270 // Orderly keep the tile uses and def in ShapedTiles; 271 getTileShapesCfg(CfgMI, ShapedTiles); 272 assert(ShapedTiles.size() && "Not find shapes config!"); 273 274 getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs); 275 276 rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs); 277 } 278 279 bool X86FastTileConfig::fastTileConfig() { 280 bool Changed = false; 281 282 for (MachineBasicBlock &MBB : *MF) { 283 SmallVector<MachineInstr *, 2> CFGs; 284 for (MachineInstr &MI : MBB) 285 if (MI.getOpcode() == X86::PLDTILECFGV) 286 CFGs.push_back(&MI); 287 for (auto *MI : CFGs) 288 materializeTileCfg(MI); 289 if (!CFGs.empty()) 290 Changed = true; 291 } 292 return Changed; 293 } 294 295 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { 296 MF = &MFunc; 297 MRI = &MFunc.getRegInfo(); 298 ST = &MFunc.getSubtarget<X86Subtarget>(); 299 TRI = ST->getRegisterInfo(); 300 TII = MFunc.getSubtarget().getInstrInfo(); 301 302 return fastTileConfig(); 303 } 304 305 FunctionPass *llvm::createX86FastTileConfigPass() { 306 return new X86FastTileConfig(); 307 } 308