xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp (revision 1342eb5a832fa10e689a29faab3acb6054e4778c)
1 //===-- AMDGPURewriteAGPRCopyMFMA.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file \brief Try to replace MFMA instructions using VGPRs with MFMA
10 /// instructions using AGPRs. We expect MFMAs to be selected using VGPRs, and
11 /// only use AGPRs if it helps avoid spilling. In this case, the MFMA will have
12 /// copies between AGPRs and VGPRs and the AGPR variant of an MFMA pseudo. This
13 /// pass will attempt to delete the cross register bank copy and replace the
14 /// MFMA opcode.
15 ///
16 /// TODO:
17 ///  - Handle non-tied dst+src2 cases. We need to try to find a copy from an
18 ///    AGPR from src2, or reassign src2 to an available AGPR (which should work
19 ///    in the common case of a load).
20 ///
21 ///  - Handle multiple MFMA uses of the same register. e.g. chained MFMAs that
22 ///    can be rewritten as a set
23 ///
24 ///  - Update LiveIntervals incrementally instead of recomputing from scratch
25 ///
26 //===----------------------------------------------------------------------===//
27 
#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "SIMachineFunctionInfo.h"
#include "SIRegisterInfo.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/LiveStacks.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "amdgpu-rewrite-agpr-copy-mfma"
41 
42 namespace {
43 
44 class AMDGPURewriteAGPRCopyMFMAImpl {
45   const GCNSubtarget &ST;
46   const SIInstrInfo &TII;
47   const SIRegisterInfo &TRI;
48   MachineRegisterInfo &MRI;
49   VirtRegMap &VRM;
50   LiveRegMatrix ‎
51   LiveIntervals &LIS;
52 
53 public:
54   AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
55                                 LiveRegMatrix &LRM, LiveIntervals &LIS)
56       : ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
57         TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
58         LIS(LIS) {}
59 
60   /// Compute the register class constraints based on the uses of \p Reg,
61   /// excluding uses from \p ExceptMI. This should be nearly identical to
62   /// MachineRegisterInfo::recomputeRegClass.
63   const TargetRegisterClass *
64   recomputeRegClassExcept(Register Reg, const TargetRegisterClass *OldRC,
65                           const TargetRegisterClass *NewRC,
66                           const MachineInstr *ExceptMI) const;
67 
68   bool run(MachineFunction &MF) const;
69 };
70 
71 const TargetRegisterClass *
72 AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExcept(
73     Register Reg, const TargetRegisterClass *OldRC,
74     const TargetRegisterClass *NewRC, const MachineInstr *ExceptMI) const {
75 
76   // Accumulate constraints from all uses.
77   for (MachineOperand &MO : MRI.reg_nodbg_operands(Reg)) {
78     // Apply the effect of the given operand to NewRC.
79     MachineInstr *MI = MO.getParent();
80     if (MI == ExceptMI)
81       continue;
82 
83     unsigned OpNo = &MO - &MI->getOperand(0);
84     NewRC = MI->getRegClassConstraintEffect(OpNo, NewRC, &TII, &TRI);
85     if (!NewRC || NewRC == OldRC)
86       return nullptr;
87   }
88 
89   return NewRC;
90 }
91 
92 bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
93   // This only applies on subtargets that have a configurable AGPR vs. VGPR
94   // allocation.
95   if (!ST.hasGFX90AInsts())
96     return false;
97 
98   // Early exit if no AGPRs were assigned.
99   if (!LRM.isPhysRegUsed(AMDGPU::AGPR0))
100     return false;
101 
102   bool MadeChange = false;
103 
104   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
105     Register VReg = Register::index2VirtReg(I);
106     Register PhysReg = VRM.getPhys(VReg);
107     if (!PhysReg)
108       continue;
109 
110     // Find AV_* registers assigned to AGPRs.
111     const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
112     if (!TRI.isVectorSuperClass(VirtRegRC))
113       continue;
114 
115     const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
116     if (!TRI.isAGPRClass(AssignedRC))
117       continue;
118 
119     LiveInterval &LI = LIS.getInterval(VReg);
120 
121     // TODO: Test multiple uses
122     for (VNInfo *VNI : LI.vnis()) {
123       MachineInstr *DefMI = LIS.getInstructionFromIndex(VNI->def);
124 
125       // TODO: Handle SplitKit produced copy bundles for partially defined
126       // registers.
127       if (!DefMI || !DefMI->isFullCopy())
128         continue;
129 
130       Register CopySrcReg = DefMI->getOperand(1).getReg();
131       if (!CopySrcReg.isVirtual())
132         continue;
133 
134       LiveInterval &CopySrcLI = LIS.getInterval(CopySrcReg);
135       LiveQueryResult LRQ = CopySrcLI.Query(VNI->def.getRegSlot());
136       MachineInstr *CopySrcMI = LIS.getInstructionFromIndex(LRQ.valueIn()->def);
137       if (!CopySrcMI)
138         continue;
139 
140       int AGPROp = AMDGPU::getMFMASrcCVDstAGPROp(CopySrcMI->getOpcode());
141       if (AGPROp == -1)
142         continue;
143 
144       MachineOperand *Src2 =
145           TII.getNamedOperand(*CopySrcMI, AMDGPU::OpName::src2);
146 
147       // FIXME: getMinimalPhysRegClass returns a nonsense AV_* subclass instead
148       // of an AGPR or VGPR subclass, so we can't simply use the result on the
149       // assignment.
150 
151       LLVM_DEBUG({
152         Register Src2PhysReg = VRM.getPhys(Src2->getReg());
153         dbgs() << "Attempting to replace VGPR MFMA with AGPR version:"
154                << " Dst=[" << printReg(VReg) << " => "
155                << printReg(PhysReg, &TRI) << "], Src2=["
156                << printReg(Src2->getReg(), &TRI) << " => "
157                << printReg(Src2PhysReg, &TRI) << "]: " << *CopySrcMI;
158       });
159 
160       // If the inputs are tied and the same register, we can shortcut and
161       // directly replace the register.
162       if (Src2->getReg() != CopySrcReg) {
163         LLVM_DEBUG(
164             dbgs()
165             << "Replacing untied VGPR MFMAs with AGPR form not yet handled\n");
166         // TODO: Only handles the tied case for now. If the input operand is a
167         // different register, we need to also reassign it (either by looking
168         // for a compatible copy-from-AGPR, or by seeing if an available AGPR is
169         // compatible with all other uses.
170 
171         // If we can't reassign it, we'd need to introduce a different copy
172         // which is likely worse than the copy we'd be saving.
173         continue;
174       }
175 
176       const TargetRegisterClass *Src2VirtRegRC =
177           MRI.getRegClass(Src2->getReg());
178 
179       // We've found av = COPY (MFMA), and need to verify that we can trivially
180       // rewrite src2 to use the new AGPR. If we can't trivially replace it,
181       // we're going to induce as many copies as we would have emitted in the
182       // first place, as well as need to assign another register, and need to
183       // figure out where to put them. The live range splitting is smarter than
184       // anything we're doing here, so trust it did something reasonable.
185       const TargetRegisterClass *Src2ExceptRC = recomputeRegClassExcept(
186           Src2->getReg(), Src2VirtRegRC, VirtRegRC, CopySrcMI);
187       if (!Src2ExceptRC)
188         continue;
189 
190       const TargetRegisterClass *NewSrc2ConstraintRC =
191           TII.getRegClass(TII.get(AGPROp), Src2->getOperandNo(), &TRI, MF);
192 
193       // Try to constrain src2 to the replacement instruction candidate's
194       // register class.
195       const TargetRegisterClass *NewSrc2RC =
196           TRI.getCommonSubClass(Src2ExceptRC, NewSrc2ConstraintRC);
197       if (!NewSrc2RC) {
198         // TODO: This is ignoring ther rewritable uses. e.g. a rewritable MFMA
199         // using a rewritable MFMA can be rewritten as a pair.
200         LLVM_DEBUG(dbgs() << "Other uses of " << printReg(Src2->getReg(), &TRI)
201                           << " are incompatible with replacement class\n");
202         continue;
203       }
204 
205       MRI.setRegClass(VReg, AssignedRC);
206       MRI.setRegClass(Src2->getReg(), NewSrc2RC);
207 
208       CopySrcMI->setDesc(TII.get(AGPROp));
209 
210       // TODO: Is replacing too aggressive, fixup these instructions only?
211       MRI.replaceRegWith(CopySrcReg, VReg);
212 
213       LLVM_DEBUG(dbgs() << "Replaced VGPR MFMA with AGPR: " << *CopySrcMI);
214 
215       // We left behind an identity copy, so delete it.
216       LIS.RemoveMachineInstrFromMaps(*DefMI);
217       DefMI->eraseFromParent();
218 
219       LRM.unassign(CopySrcLI);
220 
221       // We don't need the liveness information anymore, so don't bother
222       // updating the intervals. Just delete the stale information.
223       // TODO: Is it worth preserving these?
224       LIS.removeInterval(CopySrcReg);
225       LIS.removeInterval(VReg);
226       LIS.createAndComputeVirtRegInterval(VReg);
227 
228       MadeChange = true;
229     }
230   }
231 
232   return MadeChange;
233 }
234 
235 class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
236 public:
237   static char ID;
238 
239   AMDGPURewriteAGPRCopyMFMALegacy() : MachineFunctionPass(ID) {
240     initializeAMDGPURewriteAGPRCopyMFMALegacyPass(
241         *PassRegistry::getPassRegistry());
242   }
243 
244   bool runOnMachineFunction(MachineFunction &MF) override;
245 
246   StringRef getPassName() const override {
247     return "AMDGPU Rewrite AGPR-Copy-MFMA";
248   }
249 
250   void getAnalysisUsage(AnalysisUsage &AU) const override {
251     AU.addRequired<LiveIntervalsWrapperPass>();
252     AU.addRequired<VirtRegMapWrapperLegacy>();
253     AU.addRequired<LiveRegMatrixWrapperLegacy>();
254 
255     AU.addPreserved<LiveIntervalsWrapperPass>();
256     AU.addPreserved<VirtRegMapWrapperLegacy>();
257     AU.addPreserved<LiveRegMatrixWrapperLegacy>();
258     AU.setPreservesAll();
259     MachineFunctionPass::getAnalysisUsage(AU);
260   }
261 };
262 
263 } // End anonymous namespace.
264 
265 INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
266                       "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
267 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
268 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
269 INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
270 INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
271                     "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
272 
273 char AMDGPURewriteAGPRCopyMFMALegacy::ID = 0;
274 
275 char &llvm::AMDGPURewriteAGPRCopyMFMALegacyID =
276     AMDGPURewriteAGPRCopyMFMALegacy::ID;
277 
278 bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
279     MachineFunction &MF) {
280   if (skipFunction(MF.getFunction()))
281     return false;
282 
283   auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
284   auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
285   auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
286 
287   AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS);
288   return Impl.run(MF);
289 }
290 
291 PreservedAnalyses
292 AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
293                                    MachineFunctionAnalysisManager &MFAM) {
294   VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(MF);
295   LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
296   LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
297 
298   AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS);
299   if (!Impl.run(MF))
300     return PreservedAnalyses::all();
301   auto PA = getMachineFunctionPassPreservedAnalyses();
302   PA.preserveSet<CFGAnalyses>();
303   return PA;
304 }
305