1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. Before FastRegAllocation pass 11 /// the ldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after register allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/CodeGen/MachineFrameInfo.h" 26 #include "llvm/CodeGen/MachineFunctionPass.h" 27 #include "llvm/CodeGen/MachineInstr.h" 28 #include "llvm/CodeGen/MachineRegisterInfo.h" 29 #include "llvm/CodeGen/Passes.h" 30 #include "llvm/CodeGen/TargetInstrInfo.h" 31 #include "llvm/CodeGen/TargetRegisterInfo.h" 32 #include "llvm/InitializePasses.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "fasttileconfig" 37 38 namespace { 39 40 class X86FastTileConfig : public MachineFunctionPass { 41 // context 42 MachineFunction *MF = nullptr; 43 const X86Subtarget *ST = nullptr; 44 const TargetRegisterInfo *TRI = nullptr; 45 const TargetInstrInfo *TII = nullptr; 46 MachineRegisterInfo *MRI = nullptr; 47 X86MachineFunctionInfo *X86FI = nullptr; 48 49 MachineInstr *getTileConfigPoint(); 50 void tileConfig(); 51 52 public: 53 X86FastTileConfig() : MachineFunctionPass(ID) {} 54 55 bool fastTileConfig(); 56 bool isTileLoad(MachineInstr &MI); 57 bool isTileStore(MachineInstr &MI); 58 bool isAMXInstr(MachineInstr &MI); 59 void getTileStoreShape(MachineInstr &MI, 60 SmallVector<MachineOperand *> &ShapedTiles); 61 62 MachineInstr *getKeyAMXInstr(MachineInstr *MI); 63 void getTileShapesCfg(MachineInstr *MI, 64 SmallVector<MachineOperand *> &ShapedTiles); 65 void getShapeCfgInstrs(MachineInstr *MI, 66 std::map<unsigned, MachineInstr *> &RowCfgs, 67 std::map<unsigned, MachineInstr *> &ColCfgs); 68 69 /// Return the pass name. 70 StringRef getPassName() const override { 71 return "Fast Tile Register Configure"; 72 } 73 74 void materializeTileCfg(MachineInstr *MI); 75 76 void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles, 77 std::map<unsigned, MachineInstr *> &RowCfgs, 78 std::map<unsigned, MachineInstr *> &ColCfgs); 79 80 /// Perform register allocation. 81 bool runOnMachineFunction(MachineFunction &MFunc) override; 82 83 MachineFunctionProperties getRequiredProperties() const override { 84 return MachineFunctionProperties().set( 85 MachineFunctionProperties::Property::NoPHIs); 86 } 87 88 static char ID; 89 }; 90 91 } // end anonymous namespace 92 93 char X86FastTileConfig::ID = 0; 94 95 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, 96 "Fast Tile Register Configure", false, false) 97 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, 98 "Fast Tile Register Configure", false, false) 99 100 static bool isTilePhysReg(MachineOperand &Op) { 101 if (!Op.isReg()) 102 return false; 103 104 Register Reg = Op.getReg(); 105 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 106 return true; 107 return false; 108 } 109 110 static unsigned getTilePhysRegIdx(MachineOperand *Op) { 111 assert(isTilePhysReg(*Op) && "Tile Operand is invalid"); 112 return Op->getReg() - X86::TMM0; 113 } 114 115 static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) { 116 unsigned Offset = 48 + TIdx; 117 MI->getOperand(3).ChangeToImmediate(Offset); 118 } 119 120 static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) { 121 unsigned Offset = 16 + TIdx * 2; 122 MI->getOperand(3).ChangeToImmediate(Offset); 123 } 124 125 bool X86FastTileConfig::isTileLoad(MachineInstr &MI) { 126 return MI.getOpcode() == X86::PTILELOADDV || 127 MI.getOpcode() == X86::PTILELOADDT1V; 128 } 129 bool X86FastTileConfig::isTileStore(MachineInstr &MI) { 130 return MI.getOpcode() == X86::PTILESTOREDV; 131 } 132 bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) { 133 // TODO: May need to handle some special nontile amx instrucion. 134 if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr()) 135 return false; 136 137 for (MachineOperand &MO : MI.operands()) 138 if (isTilePhysReg(MO)) 139 return true; 140 141 return false; 142 } 143 144 MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) { 145 auto Cfg = MachineBasicBlock::iterator(MI); 146 MachineBasicBlock *MBB = MI->getParent(); 147 MachineInstr *KeyMI = nullptr; 148 int KeyAMXNum = 0; 149 150 for (auto II = Cfg; II != MBB->end(); II++) { 151 if (isTileLoad(*II)) { 152 KeyMI = &*II; 153 continue; 154 } 155 156 if (isTileStore(*II)) { 157 assert(KeyMI && "Key AMX Should be found before!"); 158 break; 159 } 160 161 if (isAMXInstr(*II)) { 162 assert((KeyAMXNum == 0) && "Too many Key AMX instruction!"); 163 KeyAMXNum++; 164 KeyMI = &*II; 165 } 166 } 167 assert(KeyMI && "There must be an AMX instruction."); 168 return KeyMI; 169 } 170 171 // Orderly get the tiles in key amx instruction, uses before defs. 172 void X86FastTileConfig::getTileShapesCfg( 173 MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) { 174 MachineInstr *KeyMI = getKeyAMXInstr(CfgMI); 175 176 SmallVector<MachineOperand *> DefTiles; 177 for (MachineOperand &MO : KeyMI->operands()) { 178 if (!isTilePhysReg(MO)) 179 continue; 180 if (MO.isDef()) 181 DefTiles.push_back(&MO); 182 else 183 ShapedTiles.push_back(&MO); 184 } 185 ShapedTiles.append(DefTiles); 186 } 187 188 // We pre-config the shapes at position named with "amx.tmm.N.shape.row* and 189 // amx.shape.N.col*" at pass "Pre AMX Tile Config". 190 // The 'N' implies the order of tiles in key amx intrinsic. 191 void X86FastTileConfig::getShapeCfgInstrs( 192 MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs, 193 std::map<unsigned, MachineInstr *> &ColCfgs) { 194 auto Cfg = MachineBasicBlock::iterator(MI); 195 MachineBasicBlock *MBB = MI->getParent(); 196 197 for (auto II = Cfg; II != MBB->begin(); II--) { 198 if (isAMXInstr(*II) || II->isTerminator() || II->isCall()) 199 break; 200 if (!II->mayStore() || !II->hasOneMemOperand()) 201 continue; 202 const Value *MemPtr = II->memoperands()[0]->getValue(); 203 if (!MemPtr) 204 continue; 205 206 StringRef Name = MemPtr->getName(); 207 if (!Name.startswith("amx.tmm.")) 208 continue; 209 210 // Get the 'N'th tile shape config in key amx instruction. 211 auto N = Name.find(".shape"); 212 StringRef STileIdx = Name.slice(8, N); 213 unsigned Idx; 214 STileIdx.getAsInteger(10, Idx); 215 216 // And related them with their store instructions. 217 if (Name.contains("row")) 218 RowCfgs[Idx] = &*II; 219 else if (Name.contains("col")) 220 ColCfgs[Idx] = &*II; 221 else 222 llvm_unreachable("Invalid tile shape info!"); 223 } 224 assert((RowCfgs.size() == ColCfgs.size()) && 225 "The number of tile row and col must be equal!"); 226 } 227 228 // Here is the data format for the tile config. 229 // 0 palette = 1 now. 230 // 1 start_row = 0 now. 231 // 2-15 reserved, must be zero 232 // 16-17 tile0.colsb Tile 0 bytes per row. 233 // 18-19 tile1.colsb Tile 1 bytes per row. 234 // 20-21 tile2.colsb Tile 2 bytes per row. 235 // ... (sequence continues) 236 // 30-31 tile7.colsb Tile 7 bytes per row. 237 // 32-47 reserved, must be zero 238 // 48 tile0.rows Tile 0 rows. 239 // 49 tile1.rows Tile 1 rows. 240 // 50 tile2.rows Tile 2 rows. 241 // ... (sequence continues) 242 // 55 tile7.rows Tile 7 rows. 243 // 56-63 reserved, must be zero 244 void X86FastTileConfig::rewriteTileCfg( 245 SmallVector<MachineOperand *> &ShapedTiles, 246 std::map<unsigned, MachineInstr *> &RowCfgs, 247 std::map<unsigned, MachineInstr *> &ColCfgs) { 248 assert((RowCfgs.size() == ShapedTiles.size()) && 249 "The number of tile shapes not equal with the number of tiles!"); 250 251 // Orderly get the tiles and adjust the shape config. 252 for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) { 253 MachineOperand *MO = ShapedTiles[I]; 254 unsigned TmmIdx = getTilePhysRegIdx(MO); 255 if (I == TmmIdx) 256 continue; 257 adjustRowCfg(TmmIdx, RowCfgs[I]); 258 adjustColCfg(TmmIdx, ColCfgs[I]); 259 } 260 } 261 262 // We have already preconfig the shapes before fast register allocation at 263 // X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register 264 // allocation, the shapes pre-written before may not rightly corresponding 265 // to the correct tmm registers, so we need adjust them. 266 void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) { 267 SmallVector<MachineOperand *> ShapedTiles; 268 std::map<unsigned, MachineInstr *> RowCfgs; 269 std::map<unsigned, MachineInstr *> ColCfgs; 270 271 // Orderly keep the tile uses and def in ShapedTiles; 272 getTileShapesCfg(CfgMI, ShapedTiles); 273 assert(ShapedTiles.size() && "Not find shapes config!"); 274 275 getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs); 276 277 rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs); 278 } 279 280 bool X86FastTileConfig::fastTileConfig() { 281 bool Changed = false; 282 283 for (MachineBasicBlock &MBB : *MF) { 284 SmallVector<MachineInstr *, 2> CFGs; 285 for (MachineInstr &MI : MBB) 286 if (MI.getOpcode() == X86::PLDTILECFGV) 287 CFGs.push_back(&MI); 288 for (auto *MI : CFGs) 289 materializeTileCfg(MI); 290 if (!CFGs.empty()) 291 Changed = true; 292 } 293 if (Changed) 294 X86FI->setHasVirtualTileReg(true); 295 return Changed; 296 } 297 298 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { 299 MF = &MFunc; 300 MRI = &MFunc.getRegInfo(); 301 ST = &MFunc.getSubtarget<X86Subtarget>(); 302 TRI = ST->getRegisterInfo(); 303 TII = MFunc.getSubtarget().getInstrInfo(); 304 X86FI = MFunc.getInfo<X86MachineFunctionInfo>(); 305 306 return fastTileConfig(); 307 } 308 309 FunctionPass *llvm::createX86FastTileConfigPass() { 310 return new X86FastTileConfig(); 311 } 312