//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of AMX registers
/// are encoded in the 1st and 2nd machine operands of AMX pseudo instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable from the definitions of all variable shapes. ldtilecfg will be
/// inserted more than once if we cannot find a dominating point for all AMX
/// instructions.
///
/// The tile configuration is caller-saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions
/// aren't met.
23 // 24 //===----------------------------------------------------------------------===// 25 26 #include "X86.h" 27 #include "X86InstrBuilder.h" 28 #include "X86MachineFunctionInfo.h" 29 #include "X86RegisterInfo.h" 30 #include "X86Subtarget.h" 31 #include "llvm/ADT/SmallSet.h" 32 #include "llvm/CodeGen/MachineFunctionPass.h" 33 #include "llvm/CodeGen/MachineInstr.h" 34 #include "llvm/CodeGen/MachineLoopInfo.h" 35 #include "llvm/CodeGen/MachineModuleInfo.h" 36 #include "llvm/CodeGen/MachineRegisterInfo.h" 37 #include "llvm/CodeGen/Passes.h" 38 #include "llvm/CodeGen/TargetInstrInfo.h" 39 #include "llvm/CodeGen/TargetRegisterInfo.h" 40 #include "llvm/InitializePasses.h" 41 42 using namespace llvm; 43 44 #define DEBUG_TYPE "tile-pre-config" 45 46 static void emitErrorMsg(MachineFunction &MF) { 47 LLVMContext &Context = MF.getMMI().getModule()->getContext(); 48 Context.emitError( 49 MF.getName() + 50 ": Failed to config tile register, please define the shape earlier"); 51 } 52 53 namespace { 54 55 struct MIRef { 56 MachineInstr *MI = nullptr; 57 MachineBasicBlock *MBB = nullptr; 58 // A virtual position for instruction that will be inserted after MI. 
59 size_t Pos = 0; 60 MIRef() = default; 61 MIRef(MachineBasicBlock *MBB) : MBB(MBB) { 62 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); 63 ++I, ++Pos) 64 MI = &*I; 65 } 66 MIRef(MachineInstr *MI) 67 : MI(MI), MBB(MI->getParent()), 68 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} 69 MIRef(MachineInstr *MI, MachineBasicBlock *MBB) 70 : MI(MI), MBB(MBB), 71 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} 72 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) 73 : MI(MI), MBB(MBB), Pos(Pos) {} 74 operator bool() const { return MBB != nullptr; } 75 bool operator==(const MIRef &RHS) const { 76 return MI == RHS.MI && MBB == RHS.MBB; 77 } 78 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } 79 bool operator<(const MIRef &RHS) const { 80 // Comparison between different BBs happens when inserting a MIRef into set. 81 // So we compare MBB first to make the insertion happy. 82 return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); 83 } 84 bool operator>(const MIRef &RHS) const { 85 // Comparison between different BBs happens when inserting a MIRef into set. 86 // So we compare MBB first to make the insertion happy. 87 return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); 88 } 89 }; 90 91 struct BBInfo { 92 MIRef FirstAMX; 93 MIRef LastCall; 94 bool HasAMXRegLiveIn = false; 95 bool TileCfgForbidden = false; 96 bool NeedTileCfgLiveIn = false; 97 }; 98 99 class X86PreTileConfig : public MachineFunctionPass { 100 MachineRegisterInfo *MRI = nullptr; 101 const MachineLoopInfo *MLI = nullptr; 102 SmallSet<MachineInstr *, 8> DefVisited; 103 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; 104 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; 105 106 /// Check if the callee will clobber AMX registers. 
107 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { 108 auto Iter = llvm::find_if( 109 MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); }); 110 if (Iter == MI.operands_end()) 111 return false; 112 UsableRegs.clearBitsInMask(Iter->getRegMask()); 113 return !UsableRegs.none(); 114 } 115 116 /// Check if MI is AMX pseudo instruction. 117 bool isAMXInstruction(MachineInstr &MI) { 118 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) 119 return false; 120 MachineOperand &MO = MI.getOperand(0); 121 // We can simply check if it is AMX instruction by its def. 122 // But we should exclude old API which uses physical registers. 123 if (MO.isReg() && MO.getReg().isVirtual() && 124 MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) { 125 collectShapeInfo(MI); 126 return true; 127 } 128 // PTILESTOREDV is the only exception that doesn't def a AMX register. 129 return MI.getOpcode() == X86::PTILESTOREDV; 130 } 131 132 /// Check if it is an edge from loop bottom to loop head. 133 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) { 134 if (!MLI->isLoopHeader(Header)) 135 return false; 136 auto *ML = MLI->getLoopFor(Header); 137 if (ML->contains(Bottom) && ML->isLoopLatch(Bottom)) 138 return true; 139 140 return false; 141 } 142 143 /// Collect the shape def information for later use. 144 void collectShapeInfo(MachineInstr &MI); 145 146 /// Try to hoist shapes definded below AMX instructions. 147 bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { 148 MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; 149 auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX); 150 auto InsertPoint = FirstAMX.MI->getIterator(); 151 for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { 152 // Do not hoist instructions that access memory. 
153 if (I->MI->mayLoadOrStore()) 154 return false; 155 for (auto &MO : I->MI->operands()) { 156 if (MO.isDef()) 157 continue; 158 // Do not hoist instructions if the sources' def under AMX instruction. 159 // TODO: We can handle isMoveImmediate MI here. 160 if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX) 161 return false; 162 // TODO: Maybe need more checks here. 163 } 164 MBB->insert(InsertPoint, I->MI->removeFromParent()); 165 } 166 // We only need to mark the last shape in the BB now. 167 Shapes.clear(); 168 Shapes.push_back(MIRef(&*--InsertPoint, MBB)); 169 return true; 170 } 171 172 public: 173 X86PreTileConfig() : MachineFunctionPass(ID) {} 174 175 /// Return the pass name. 176 StringRef getPassName() const override { 177 return "Tile Register Pre-configure"; 178 } 179 180 /// X86PreTileConfig analysis usage. 181 void getAnalysisUsage(AnalysisUsage &AU) const override { 182 AU.setPreservesAll(); 183 AU.addRequired<MachineLoopInfo>(); 184 MachineFunctionPass::getAnalysisUsage(AU); 185 } 186 187 /// Clear MF related structures. 188 void releaseMemory() override { 189 ShapeBBs.clear(); 190 DefVisited.clear(); 191 BBVisitedInfo.clear(); 192 } 193 194 /// Perform ldtilecfg instructions inserting. 
195 bool runOnMachineFunction(MachineFunction &MF) override; 196 197 static char ID; 198 }; 199 200 } // end anonymous namespace 201 202 char X86PreTileConfig::ID = 0; 203 204 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", 205 "Tile Register Pre-configure", false, false) 206 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 207 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", 208 "Tile Register Pre-configure", false, false) 209 210 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { 211 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { 212 MIRef MIR(MI, MBB); 213 auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); 214 if (I == ShapeBBs[MBB].end() || *I != MIR) 215 ShapeBBs[MBB].insert(I, MIR); 216 }; 217 218 SmallVector<Register, 8> WorkList( 219 {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); 220 while (!WorkList.empty()) { 221 Register R = WorkList.pop_back_val(); 222 MachineInstr *DefMI = MRI->getVRegDef(R); 223 assert(DefMI && "R must has one define instruction"); 224 MachineBasicBlock *DefMBB = DefMI->getParent(); 225 if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) 226 continue; 227 if (DefMI->isPHI()) { 228 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) 229 if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) 230 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. 
231 else 232 WorkList.push_back(DefMI->getOperand(I).getReg()); 233 } else { 234 RecordShape(DefMI, DefMBB); 235 } 236 } 237 } 238 239 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { 240 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); 241 const TargetInstrInfo *TII = ST.getInstrInfo(); 242 const TargetRegisterInfo *TRI = ST.getRegisterInfo(); 243 const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); 244 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); 245 246 BitVector AMXRegs(TRI->getNumRegs()); 247 for (unsigned I = 0; I < RC->getNumRegs(); I++) 248 AMXRegs.set(X86::TMM0 + I); 249 250 // Iterate MF to collect information. 251 MRI = &MF.getRegInfo(); 252 MLI = &getAnalysis<MachineLoopInfo>(); 253 SmallSet<MIRef, 8> CfgNeedInsert; 254 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs; 255 for (auto &MBB : MF) { 256 size_t Pos = 0; 257 for (auto &MI : MBB) { 258 ++Pos; 259 if (isAMXInstruction(MI)) { 260 // If there's call before the AMX, we need to reload tile config. 261 if (BBVisitedInfo[&MBB].LastCall) 262 CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall); 263 else // Otherwise, we need tile config to live in this BB. 264 BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; 265 // Always record the first AMX in case there's shape def after it. 266 if (!BBVisitedInfo[&MBB].FirstAMX) 267 BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); 268 } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) { 269 // Record the call only if the callee clobbers all AMX registers. 
270 BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); 271 } 272 } 273 if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { 274 if (&MBB == &MF.front()) 275 CfgNeedInsert.insert(MIRef(&MBB)); 276 else 277 CfgLiveInBBs.push_back(&MBB); 278 } 279 if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) 280 for (auto *Succ : MBB.successors()) 281 if (!isLoopBackEdge(Succ, &MBB)) 282 BBVisitedInfo[Succ].HasAMXRegLiveIn = true; 283 } 284 285 // Update NeedTileCfgLiveIn for predecessors. 286 while (!CfgLiveInBBs.empty()) { 287 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); 288 for (auto *Pred : MBB->predecessors()) { 289 if (BBVisitedInfo[Pred].LastCall) { 290 CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall); 291 } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { 292 BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; 293 if (Pred == &MF.front()) 294 CfgNeedInsert.insert(MIRef(Pred)); 295 else 296 CfgLiveInBBs.push_back(Pred); 297 } 298 } 299 } 300 301 // There's no AMX instruction if we didn't find a tile config live in point. 302 if (CfgNeedInsert.empty()) 303 return false; 304 X86FI->setHasVirtualTileReg(true); 305 306 // Avoid to insert ldtilecfg before any shape defs. 307 SmallVector<MachineBasicBlock *, 8> WorkList; 308 for (auto &I : ShapeBBs) { 309 // TODO: We can hoist shapes across BBs here. 310 if (BBVisitedInfo[I.first].HasAMXRegLiveIn) { 311 // We are not able to config tile registers since the shape to config 312 // is not defined yet. Emit error message and continue. The function 313 // would not config tile registers. 
314 emitErrorMsg(MF); 315 return false; 316 } 317 if (BBVisitedInfo[I.first].FirstAMX && 318 BBVisitedInfo[I.first].FirstAMX < I.second.back() && 319 !hoistShapesInBB(I.first, I.second)) { 320 emitErrorMsg(MF); 321 return false; 322 } 323 WorkList.push_back(I.first); 324 } 325 while (!WorkList.empty()) { 326 MachineBasicBlock *MBB = WorkList.pop_back_val(); 327 for (auto *Pred : MBB->predecessors()) { 328 if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) { 329 BBVisitedInfo[Pred].TileCfgForbidden = true; 330 WorkList.push_back(Pred); 331 } 332 } 333 } 334 335 DebugLoc DL; 336 SmallSet<MIRef, 8> VisitedOrInserted; 337 int SS = MF.getFrameInfo().CreateStackObject( 338 ST.getTileConfigSize(), ST.getTileConfigAlignment(), false); 339 340 // Try to insert for the tile config live in points. 341 for (const auto &I : CfgNeedInsert) { 342 SmallSet<MIRef, 8> InsertPoints; 343 SmallVector<MIRef, 8> WorkList({I}); 344 while (!WorkList.empty()) { 345 MIRef I = WorkList.pop_back_val(); 346 if (!VisitedOrInserted.count(I)) { 347 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { 348 // If the BB is all shapes reachable, stop sink and try to insert. 349 InsertPoints.insert(I); 350 } else { 351 // Avoid the BB to be multi visited. 352 VisitedOrInserted.insert(I); 353 // Sink the inserting point along the chain with NeedTileCfgLiveIn = 354 // true when MBB isn't all shapes reachable. 355 for (auto *Succ : I.MBB->successors()) 356 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) 357 WorkList.push_back(MIRef(Succ)); 358 } 359 } 360 } 361 362 // A given point might be forked due to shape conditions are not met. 363 for (MIRef I : InsertPoints) { 364 // Make sure we insert ldtilecfg after the last shape def in MBB. 365 if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back()) 366 I = ShapeBBs[I.MBB].back(); 367 // There're chances the MBB is sunk more than once. Record it to avoid 368 // multi insert. 369 if (VisitedOrInserted.insert(I).second) { 370 auto II = I.MI ? 
I.MI->getIterator() : I.MBB->instr_begin(); 371 addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::PLDTILECFGV)), 372 SS); 373 } 374 } 375 } 376 377 // Zero stack slot. 378 MachineBasicBlock &MBB = MF.front(); 379 MachineInstr *MI = &*MBB.begin(); 380 if (ST.hasAVX512()) { 381 Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); 382 BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm); 383 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS) 384 .addReg(Zmm); 385 } else if (ST.hasAVX2()) { 386 Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); 387 BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm); 388 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS) 389 .addReg(Ymm); 390 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32) 391 .addReg(Ymm); 392 } else { 393 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled"); 394 unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr; 395 Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); 396 BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm); 397 addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS).addReg(Xmm); 398 addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 16) 399 .addReg(Xmm); 400 addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 32) 401 .addReg(Xmm); 402 addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 48) 403 .addReg(Xmm); 404 } 405 // Fill in the palette first. 406 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1); 407 408 return true; 409 } 410 411 FunctionPass *llvm::createX86PreTileConfigPass() { 412 return new X86PreTileConfig(); 413 } 414