xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp (revision 924226fba12cc9a228c73b956e1b7fa24c60b055)
1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains the AArch64 / Cortex-A57 specific register allocation
9 // constraints for use by the PBQP register allocator.
10 //
11 // It is essentially a transcription of what is contained in
12 // AArch64A57FPLoadBalancing, which tries to use a balanced
13 // mix of odd and even D-registers when performing a critical sequence of
14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15 //===----------------------------------------------------------------------===//
16 
17 #include "AArch64PBQPRegAlloc.h"
18 #include "AArch64.h"
19 #include "AArch64RegisterInfo.h"
20 #include "llvm/CodeGen/LiveIntervals.h"
21 #include "llvm/CodeGen/MachineBasicBlock.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/RegAllocPBQP.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 #define DEBUG_TYPE "aarch64-pbqp"
30 
31 using namespace llvm;
32 
33 namespace {
34 
35 #ifndef NDEBUG
36 bool isFPReg(unsigned reg) {
37   return AArch64::FPR32RegClass.contains(reg) ||
38          AArch64::FPR64RegClass.contains(reg) ||
39          AArch64::FPR128RegClass.contains(reg);
40 }
41 #endif
42 
43 bool isOdd(unsigned reg) {
44   switch (reg) {
45   default:
46     llvm_unreachable("Register is not from the expected class !");
47   case AArch64::S1:
48   case AArch64::S3:
49   case AArch64::S5:
50   case AArch64::S7:
51   case AArch64::S9:
52   case AArch64::S11:
53   case AArch64::S13:
54   case AArch64::S15:
55   case AArch64::S17:
56   case AArch64::S19:
57   case AArch64::S21:
58   case AArch64::S23:
59   case AArch64::S25:
60   case AArch64::S27:
61   case AArch64::S29:
62   case AArch64::S31:
63   case AArch64::D1:
64   case AArch64::D3:
65   case AArch64::D5:
66   case AArch64::D7:
67   case AArch64::D9:
68   case AArch64::D11:
69   case AArch64::D13:
70   case AArch64::D15:
71   case AArch64::D17:
72   case AArch64::D19:
73   case AArch64::D21:
74   case AArch64::D23:
75   case AArch64::D25:
76   case AArch64::D27:
77   case AArch64::D29:
78   case AArch64::D31:
79   case AArch64::Q1:
80   case AArch64::Q3:
81   case AArch64::Q5:
82   case AArch64::Q7:
83   case AArch64::Q9:
84   case AArch64::Q11:
85   case AArch64::Q13:
86   case AArch64::Q15:
87   case AArch64::Q17:
88   case AArch64::Q19:
89   case AArch64::Q21:
90   case AArch64::Q23:
91   case AArch64::Q25:
92   case AArch64::Q27:
93   case AArch64::Q29:
94   case AArch64::Q31:
95     return true;
96   case AArch64::S0:
97   case AArch64::S2:
98   case AArch64::S4:
99   case AArch64::S6:
100   case AArch64::S8:
101   case AArch64::S10:
102   case AArch64::S12:
103   case AArch64::S14:
104   case AArch64::S16:
105   case AArch64::S18:
106   case AArch64::S20:
107   case AArch64::S22:
108   case AArch64::S24:
109   case AArch64::S26:
110   case AArch64::S28:
111   case AArch64::S30:
112   case AArch64::D0:
113   case AArch64::D2:
114   case AArch64::D4:
115   case AArch64::D6:
116   case AArch64::D8:
117   case AArch64::D10:
118   case AArch64::D12:
119   case AArch64::D14:
120   case AArch64::D16:
121   case AArch64::D18:
122   case AArch64::D20:
123   case AArch64::D22:
124   case AArch64::D24:
125   case AArch64::D26:
126   case AArch64::D28:
127   case AArch64::D30:
128   case AArch64::Q0:
129   case AArch64::Q2:
130   case AArch64::Q4:
131   case AArch64::Q6:
132   case AArch64::Q8:
133   case AArch64::Q10:
134   case AArch64::Q12:
135   case AArch64::Q14:
136   case AArch64::Q16:
137   case AArch64::Q18:
138   case AArch64::Q20:
139   case AArch64::Q22:
140   case AArch64::Q24:
141   case AArch64::Q26:
142   case AArch64::Q28:
143   case AArch64::Q30:
144     return false;
145 
146   }
147 }
148 
149 bool haveSameParity(unsigned reg1, unsigned reg2) {
150   assert(isFPReg(reg1) && "Expecting an FP register for reg1");
151   assert(isFPReg(reg2) && "Expecting an FP register for reg2");
152 
153   return isOdd(reg1) == isOdd(reg2);
154 }
155 
156 }
157 
158 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
159                                                  unsigned Ra) {
160   if (Rd == Ra)
161     return false;
162 
163   LiveIntervals &LIs = G.getMetadata().LIS;
164 
165   if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) {
166     LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
167                       << Register::isPhysicalRegister(Rd) << '\n');
168     LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
169                       << Register::isPhysicalRegister(Ra) << '\n');
170     return false;
171   }
172 
173   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
174   PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
175 
176   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
177     &G.getNodeMetadata(node1).getAllowedRegs();
178   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
179     &G.getNodeMetadata(node2).getAllowedRegs();
180 
181   PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
182 
183   // The edge does not exist. Create one with the appropriate interference
184   // costs.
185   if (edge == G.invalidEdgeId()) {
186     const LiveInterval &ld = LIs.getInterval(Rd);
187     const LiveInterval &la = LIs.getInterval(Ra);
188     bool livesOverlap = ld.overlaps(la);
189 
190     PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
191                                  vRaAllowed->size() + 1, 0);
192     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
193       unsigned pRd = (*vRdAllowed)[i];
194       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
195         unsigned pRa = (*vRaAllowed)[j];
196         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
197           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
198         else
199           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
200       }
201     }
202     G.addEdge(node1, node2, std::move(costs));
203     return true;
204   }
205 
206   if (G.getEdgeNode1Id(edge) == node2) {
207     std::swap(node1, node2);
208     std::swap(vRdAllowed, vRaAllowed);
209   }
210 
211   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
212   PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
213   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
214     unsigned pRd = (*vRdAllowed)[i];
215 
216     // Get the maximum cost (excluding unallocatable reg) for same parity
217     // registers
218     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
219     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
220       unsigned pRa = (*vRaAllowed)[j];
221       if (haveSameParity(pRd, pRa))
222         if (costs[i + 1][j + 1] !=
223                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
224             costs[i + 1][j + 1] > sameParityMax)
225           sameParityMax = costs[i + 1][j + 1];
226     }
227 
228     // Ensure all registers with a different parity have a higher cost
229     // than sameParityMax
230     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
231       unsigned pRa = (*vRaAllowed)[j];
232       if (!haveSameParity(pRd, pRa))
233         if (sameParityMax > costs[i + 1][j + 1])
234           costs[i + 1][j + 1] = sameParityMax + 1.0;
235     }
236   }
237   G.updateEdgeCosts(edge, std::move(costs));
238 
239   return true;
240 }
241 
242 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
243                                                  unsigned Ra) {
244   LiveIntervals &LIs = G.getMetadata().LIS;
245 
246   // Do some Chain management
247   if (Chains.count(Ra)) {
248     if (Rd != Ra) {
249       LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
250                         << " to " << printReg(Rd, TRI) << '\n';);
251       Chains.remove(Ra);
252       Chains.insert(Rd);
253     }
254   } else {
255     LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
256                       << '\n';);
257     Chains.insert(Rd);
258   }
259 
260   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
261 
262   const LiveInterval &ld = LIs.getInterval(Rd);
263   for (auto r : Chains) {
264     // Skip self
265     if (r == Rd)
266       continue;
267 
268     const LiveInterval &lr = LIs.getInterval(r);
269     if (ld.overlaps(lr)) {
270       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
271         &G.getNodeMetadata(node1).getAllowedRegs();
272 
273       PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
274       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
275         &G.getNodeMetadata(node2).getAllowedRegs();
276 
277       PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
278       assert(edge != G.invalidEdgeId() &&
279              "PBQP error ! The edge should exist !");
280 
281       LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
282 
283       if (G.getEdgeNode1Id(edge) == node2) {
284         std::swap(node1, node2);
285         std::swap(vRdAllowed, vRrAllowed);
286       }
287 
288       // Enforce that cost is higher with all other Chains of the same parity
289       PBQP::Matrix costs(G.getEdgeCosts(edge));
290       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
291         unsigned pRd = (*vRdAllowed)[i];
292 
293         // Get the maximum cost (excluding unallocatable reg) for all other
294         // parity registers
295         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
296         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
297           unsigned pRa = (*vRrAllowed)[j];
298           if (!haveSameParity(pRd, pRa))
299             if (costs[i + 1][j + 1] !=
300                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
301                 costs[i + 1][j + 1] > sameParityMax)
302               sameParityMax = costs[i + 1][j + 1];
303         }
304 
305         // Ensure all registers with same parity have a higher cost
306         // than sameParityMax
307         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
308           unsigned pRa = (*vRrAllowed)[j];
309           if (haveSameParity(pRd, pRa))
310             if (sameParityMax > costs[i + 1][j + 1])
311               costs[i + 1][j + 1] = sameParityMax + 1.0;
312         }
313       }
314       G.updateEdgeCosts(edge, std::move(costs));
315     }
316   }
317 }
318 
319 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
320                                 const MachineInstr &MI) {
321   const LiveInterval &LI = LIs.getInterval(reg);
322   SlotIndex SI = LIs.getInstructionIndex(MI);
323   return LI.expiredAt(SI);
324 }
325 
326 void A57ChainingConstraint::apply(PBQPRAGraph &G) {
327   const MachineFunction &MF = G.getMetadata().MF;
328   LiveIntervals &LIs = G.getMetadata().LIS;
329 
330   TRI = MF.getSubtarget().getRegisterInfo();
331   LLVM_DEBUG(MF.dump());
332 
333   for (const auto &MBB: MF) {
334     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
335 
336     for (const auto &MI: MBB) {
337 
338       // Forget Chains which have expired
339       for (auto r : Chains) {
340         SmallVector<unsigned, 8> toDel;
341         if(regJustKilledBefore(LIs, r, MI)) {
342           LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
343                      MI.print(dbgs()););
344           toDel.push_back(r);
345         }
346 
347         while (!toDel.empty()) {
348           Chains.remove(toDel.back());
349           toDel.pop_back();
350         }
351       }
352 
353       switch (MI.getOpcode()) {
354       case AArch64::FMSUBSrrr:
355       case AArch64::FMADDSrrr:
356       case AArch64::FNMSUBSrrr:
357       case AArch64::FNMADDSrrr:
358       case AArch64::FMSUBDrrr:
359       case AArch64::FMADDDrrr:
360       case AArch64::FNMSUBDrrr:
361       case AArch64::FNMADDDrrr: {
362         Register Rd = MI.getOperand(0).getReg();
363         Register Ra = MI.getOperand(3).getReg();
364 
365         if (addIntraChainConstraint(G, Rd, Ra))
366           addInterChainConstraint(G, Rd, Ra);
367         break;
368       }
369 
370       case AArch64::FMLAv2f32:
371       case AArch64::FMLSv2f32: {
372         Register Rd = MI.getOperand(0).getReg();
373         addInterChainConstraint(G, Rd, Rd);
374         break;
375       }
376 
377       default:
378         break;
379       }
380     }
381   }
382 }
383