//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable for all variable shapes. ldtilecfg will be inserted more than once
/// if we cannot find a dominating point for all AMX instructions.
///
/// The tile configuration register is caller saved according to the ABI. We
/// need to insert ldtilecfg again after a call instruction if the callee
/// clobbers any AMX registers.
///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
// //===----------------------------------------------------------------------===// #include "X86.h" #include "X86InstrBuilder.h" #include "X86MachineFunctionInfo.h" #include "X86RegisterInfo.h" #include "X86Subtarget.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineLoopInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/Passes.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/InitializePasses.h" using namespace llvm; #define DEBUG_TYPE "tile-pre-config" #define REPORT_CONFIG_FAIL \ report_fatal_error( \ MF.getName() + \ ": Failed to config tile register, please define the shape earlier"); namespace { struct MIRef { MachineInstr *MI = nullptr; MachineBasicBlock *MBB = nullptr; // A virtual position for instruction that will be inserted after MI. size_t Pos = 0; MIRef() = default; MIRef(MachineBasicBlock *MBB) : MBB(MBB) { for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); ++I, ++Pos) MI = &*I; } MIRef(MachineInstr *MI) : MI(MI), MBB(MI->getParent()), Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} MIRef(MachineInstr *MI, MachineBasicBlock *MBB) : MI(MI), MBB(MBB), Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) : MI(MI), MBB(MBB), Pos(Pos) {} operator bool() const { return MBB != nullptr; } bool operator==(const MIRef &RHS) const { return MI == RHS.MI && MBB == RHS.MBB; } bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } bool operator<(const MIRef &RHS) const { // Comparison between different BBs happens when inserting a MIRef into set. // So we compare MBB first to make the insertion happy. return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); } bool operator>(const MIRef &RHS) const { // Comparison between different BBs happens when inserting a MIRef into set. 
// So we compare MBB first to make the insertion happy. return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); } }; struct BBInfo { MIRef FirstAMX; MIRef LastCall; bool HasAMXRegLiveIn = false; bool TileCfgForbidden = false; bool NeedTileCfgLiveIn = false; }; class X86PreTileConfig : public MachineFunctionPass { MachineRegisterInfo *MRI; const MachineLoopInfo *MLI; SmallSet DefVisited; DenseMap BBVisitedInfo; DenseMap> ShapeBBs; /// Check if the callee will clobber AMX registers. bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { auto Iter = llvm::find_if( MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); }); if (Iter == MI.operands_end()) return false; UsableRegs.clearBitsInMask(Iter->getRegMask()); return !UsableRegs.none(); } /// Check if MI is AMX pseudo instruction. bool isAMXInstruction(MachineInstr &MI) { if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) return false; MachineOperand &MO = MI.getOperand(0); // We can simply check if it is AMX instruction by its def. // But we should exclude old API which uses physical registers. if (MO.isReg() && MO.getReg().isVirtual() && MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) { collectShapeInfo(MI); return true; } // PTILESTOREDV is the only exception that doesn't def a AMX register. return MI.getOpcode() == X86::PTILESTOREDV; } /// Check if it is an edge from loop bottom to loop head. bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) { if (!MLI->isLoopHeader(Header)) return false; auto *ML = MLI->getLoopFor(Header); if (ML->contains(Bottom) && ML->isLoopLatch(Bottom)) return true; return false; } /// Collect the shape def information for later use. void collectShapeInfo(MachineInstr &MI); /// Try to hoist shapes definded below AMX instructions. 
bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl &Shapes) { MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX); auto InsertPoint = FirstAMX.MI->getIterator(); for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { // Do not hoist instructions that access memory. if (I->MI->mayLoadOrStore()) return false; for (auto &MO : I->MI->operands()) { if (MO.isDef()) continue; // Do not hoist instructions if the sources' def under AMX instruction. // TODO: We can handle isMoveImmediate MI here. if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX) return false; // TODO: Maybe need more checks here. } MBB->insert(InsertPoint, I->MI->removeFromParent()); } // We only need to mark the last shape in the BB now. Shapes.clear(); Shapes.push_back(MIRef(&*--InsertPoint, MBB)); return true; } public: X86PreTileConfig() : MachineFunctionPass(ID) {} /// Return the pass name. StringRef getPassName() const override { return "Tile Register Pre-configure"; } /// X86PreTileConfig analysis usage. void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } /// Clear MF related structures. void releaseMemory() override { ShapeBBs.clear(); DefVisited.clear(); BBVisitedInfo.clear(); } /// Perform ldtilecfg instructions inserting. 
bool runOnMachineFunction(MachineFunction &MF) override; static char ID; }; } // end anonymous namespace char X86PreTileConfig::ID = 0; INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", "Tile Register Pre-configure", false, false) INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", "Tile Register Pre-configure", false, false) void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { MIRef MIR(MI, MBB); auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); if (I == ShapeBBs[MBB].end() || *I != MIR) ShapeBBs[MBB].insert(I, MIR); }; SmallVector WorkList( {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); while (!WorkList.empty()) { Register R = WorkList.pop_back_val(); MachineInstr *DefMI = MRI->getVRegDef(R); assert(DefMI && "R must has one define instruction"); MachineBasicBlock *DefMBB = DefMI->getParent(); if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) continue; if (DefMI->isPHI()) { for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. else WorkList.push_back(DefMI->getOperand(I).getReg()); } else { RecordShape(DefMI, DefMBB); } } } bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { const X86Subtarget &ST = MF.getSubtarget(); const TargetInstrInfo *TII = ST.getInstrInfo(); const TargetRegisterInfo *TRI = ST.getRegisterInfo(); const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); X86MachineFunctionInfo *X86FI = MF.getInfo(); BitVector AMXRegs(TRI->getNumRegs()); for (unsigned I = 0; I < RC->getNumRegs(); I++) AMXRegs.set(X86::TMM0 + I); // Iterate MF to collect information. 
MRI = &MF.getRegInfo(); MLI = &getAnalysis(); SmallSet CfgNeedInsert; SmallVector CfgLiveInBBs; for (auto &MBB : MF) { size_t Pos = 0; for (auto &MI : MBB) { ++Pos; if (isAMXInstruction(MI)) { // If there's call before the AMX, we need to reload tile config. if (BBVisitedInfo[&MBB].LastCall) CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall); else // Otherwise, we need tile config to live in this BB. BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; // Always record the first AMX in case there's shape def after it. if (!BBVisitedInfo[&MBB].FirstAMX) BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) { // Record the call only if the callee clobbers all AMX registers. BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); } } if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { if (&MBB == &MF.front()) CfgNeedInsert.insert(MIRef(&MBB)); else CfgLiveInBBs.push_back(&MBB); } if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) for (auto *Succ : MBB.successors()) if (!isLoopBackEdge(Succ, &MBB)) BBVisitedInfo[Succ].HasAMXRegLiveIn = true; } // Update NeedTileCfgLiveIn for predecessors. while (!CfgLiveInBBs.empty()) { MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); for (auto *Pred : MBB->predecessors()) { if (BBVisitedInfo[Pred].LastCall) { CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall); } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; if (Pred == &MF.front()) CfgNeedInsert.insert(MIRef(Pred)); else CfgLiveInBBs.push_back(Pred); } } } // There's no AMX instruction if we didn't find a tile config live in point. if (CfgNeedInsert.empty()) return false; X86FI->setHasVirtualTileReg(true); // Avoid to insert ldtilecfg before any shape defs. SmallVector WorkList; for (auto &I : ShapeBBs) { // TODO: We can hoist shapes across BBs here. 
if (BBVisitedInfo[I.first].HasAMXRegLiveIn) REPORT_CONFIG_FAIL if (BBVisitedInfo[I.first].FirstAMX && BBVisitedInfo[I.first].FirstAMX < I.second.back() && !hoistShapesInBB(I.first, I.second)) REPORT_CONFIG_FAIL WorkList.push_back(I.first); } while (!WorkList.empty()) { MachineBasicBlock *MBB = WorkList.pop_back_val(); for (auto *Pred : MBB->predecessors()) { if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) { BBVisitedInfo[Pred].TileCfgForbidden = true; WorkList.push_back(Pred); } } } DebugLoc DL; SmallSet VisitedOrInserted; int SS = MF.getFrameInfo().CreateStackObject( ST.getTileConfigSize(), ST.getTileConfigAlignment(), false); // Try to insert for the tile config live in points. for (const auto &I : CfgNeedInsert) { SmallSet InsertPoints; SmallVector WorkList({I}); while (!WorkList.empty()) { MIRef I = WorkList.pop_back_val(); if (!VisitedOrInserted.count(I)) { if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { // If the BB is all shapes reachable, stop sink and try to insert. InsertPoints.insert(I); } else { // Avoid the BB to be multi visited. VisitedOrInserted.insert(I); // Sink the inserting point along the chain with NeedTileCfgLiveIn = // true when MBB isn't all shapes reachable. for (auto *Succ : I.MBB->successors()) if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) WorkList.push_back(MIRef(Succ)); } } } // A given point might be forked due to shape conditions are not met. for (MIRef I : InsertPoints) { // Make sure we insert ldtilecfg after the last shape def in MBB. if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back()) I = ShapeBBs[I.MBB].back(); // There're chances the MBB is sunk more than once. Record it to avoid // multi insert. if (VisitedOrInserted.insert(I).second) { auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin(); addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)), SS); } } } // Zero stack slot. 
MachineBasicBlock &MBB = MF.front(); MachineInstr *MI = &*MBB.begin(); if (ST.hasAVX512()) { Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm) .addReg(Zmm, RegState::Undef) .addReg(Zmm, RegState::Undef); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS) .addReg(Zmm); } else if (ST.hasAVX2()) { Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm) .addReg(Ymm, RegState::Undef) .addReg(Ymm, RegState::Undef); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS) .addReg(Ymm); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32) .addReg(Ymm); } else { assert(ST.hasSSE2() && "AMX should assume SSE2 enabled"); Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm) .addReg(Xmm, RegState::Undef) .addReg(Xmm, RegState::Undef); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS) .addReg(Xmm); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16) .addReg(Xmm); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32) .addReg(Xmm); addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48) .addReg(Xmm); } // Fill in the palette first. addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1); return true; } FunctionPass *llvm::createX86PreTileConfigPass() { return new X86PreTileConfig(); }