//===-- X86PreTileConfig.cpp - Tile Register Pre-configure ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to config the shapes. Every AMX
/// instruction must be reachable from some ldtilecfg, so ldtilecfg is inserted
/// more than once if we cannot find a single dominating point for all AMX
/// instructions.
///
/// The config register is caller-saved according to the ABI, so we need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX register.
///
/// This pass computes all the points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions aren't
/// met.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");

namespace {

struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for the instruction that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  MIRef(MachineInstr *MI)
      : MI(MI), MBB(MI->getParent()),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  operator bool() const { return MBB != nullptr; }
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
  bool operator<(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into a
    // set. So we compare the MBB first to make the insertion happy.
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into a
    // set. So we compare the MBB first to make the insertion happy.
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};
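// Information tracked per machine basic block during the initial walk over
// the function: the first AMX instruction, the last call that clobbers AMX
// registers, and the reachability flags that drive ldtilecfg placement.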
struct BBInfo {
  MIRef FirstAMX;
  MIRef LastCall;
  bool HasAMXRegLiveIn = false;
  bool TileCfgForbidden = false;
  bool NeedTileCfgLiveIn = false;
};

class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI;
  const MachineLoopInfo *MLI;
  SmallSet<MachineInstr *, 8> DefVisited;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

  /// Check if the callee will clobber AMX registers.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return !UsableRegs.none();
  }

  /// Check if MI is an AMX pseudo instruction.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    MachineOperand &MO = MI.getOperand(0);
    // We can simply check if it is an AMX instruction by its def, but we
    // should exclude the old API, which uses physical registers.
    if (MO.isReg() && MO.getReg().isVirtual() &&
        MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
      collectShapeInfo(MI);
      return true;
    }
    // PTILESTOREDV is the only exception that doesn't def an AMX register.
    return MI.getOpcode() == X86::PTILESTOREDV;
  }

  /// Check if it is an edge from the loop bottom to the loop header.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    if (!MLI->isLoopHeader(Header))
      return false;
    auto *ML = MLI->getLoopFor(Header);
    if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
      return true;

    return false;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI);

  /// Try to hoist shapes defined below AMX instructions.
  bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
    auto InsertPoint = FirstAMX.MI->getIterator();
    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
      // Do not hoist instructions that access memory.
      if (I->MI->mayLoadOrStore())
        return false;
      for (auto &MO : I->MI->operands()) {
        if (MO.isDef())
          continue;
        // Do not hoist the instruction if any of its sources is defined below
        // the first AMX instruction.
        // TODO: We can handle isMoveImmediate MI here.
        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
          return false;
        // TODO: Maybe need more checks here.
      }
      MBB->insert(InsertPoint, I->MI->removeFromParent());
    }
    // We only need to mark the last shape in the BB now.
    Shapes.clear();
    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
    return true;
  }

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF-related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Insert the ldtilecfg instructions.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
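// Walk up the def chains of the row/column operands of an AMX instruction,
// recording every non-immediate def as a shape def of its block. PHIs are
// looked through, except across loop back-edges, where the PHI itself is
// recorded as the shape def.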
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    MIRef MIR(MI, MBB);
    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
    if (I == ShapeBBs[MBB].end() || *I != MIR)
      ShapeBBs[MBB].insert(I, MIR);
  };

  SmallVector<Register, 8> WorkList(
      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
  while (!WorkList.empty()) {
    Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    assert(DefMI && "R must have a defining instruction");
    MachineBasicBlock *DefMBB = DefMI->getParent();
    if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;
    if (DefMI->isPHI()) {
      for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
        if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
          RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
        else
          WorkList.push_back(DefMI->getOperand(I).getReg());
    } else {
      RecordShape(DefMI, DefMBB);
    }
  }
}
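// The insertion algorithm, in brief:
//   1. Walk each BB once, recording its first AMX instruction, its last
//      AMX-clobbering call and the shape defs feeding its AMX instructions.
//   2. Propagate NeedTileCfgLiveIn backwards through predecessors to find
//      every point where a tile config must be materialized.
//   3. Sink each such point below the shape defs (hoisting shapes within a
//      BB where legal) and insert an ldtilecfg there.
//   4. Back the ldtilecfg with a zero-initialized stack slot whose palette
//      byte is set to 1.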
bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();

  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate over MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's a call before the AMX, we need to reload the tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need the tile config to be live in to this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's a shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers any AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }

  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live-in point.
  if (CfgNeedInsert.empty())
    return false;
  X86FI->setHasVirtualTileReg(true);

  // Avoid inserting ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> WorkList;
  for (auto &I : ShapeBBs) {
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
      REPORT_CONFIG_FAIL
    if (BBVisitedInfo[I.first].FirstAMX &&
        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
        !hoistShapesInBB(I.first, I.second))
      REPORT_CONFIG_FAIL
    WorkList.push_back(I.first);
  }
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live-in points.
  for (const auto &I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If all shapes are reachable in this BB, stop sinking and try to
          // insert here.
          InsertPoints.insert(I);
        } else {
          // Avoid visiting the BB more than once.
          VisitedOrInserted.insert(I);
          // Sink the insertion point along the chain of BBs with
          // NeedTileCfgLiveIn == true when not all shapes are reachable in
          // MBB.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked because the shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in the MBB.
      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
        I = ShapeBBs[I.MBB].back();
      // The same MBB may be reached more than once while sinking. Record it
      // to avoid duplicate insertion.
      if (VisitedOrInserted.insert(I).second) {
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
                          SS);
      }
    }
  }

  // Zero the config stack slot.
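  // The tile config area is 64 bytes, so it is cleared with the widest vector
  // store available: one ZMM store, two YMM stores, or four XMM stores.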
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette (byte 0 of the config area).
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}