xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp (revision 05427f4639bcf2703329a9be9d25ec09bb782742)
1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains the AArch64 / Cortex-A57 specific register allocation
9 // constraints for use by the PBQP register allocator.
10 //
11 // It is essentially a transcription of what is contained in
12 // AArch64A57FPLoadBalancing, which tries to use a balanced
13 // mix of odd and even D-registers when performing a critical sequence of
14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15 //===----------------------------------------------------------------------===//
16 
17 #include "AArch64PBQPRegAlloc.h"
18 #include "AArch64.h"
19 #include "AArch64InstrInfo.h"
20 #include "AArch64RegisterInfo.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include "llvm/CodeGen/MachineBasicBlock.h"
23 #include "llvm/CodeGen/MachineFunction.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/RegAllocPBQP.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ErrorHandling.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "aarch64-pbqp"
31 
32 using namespace llvm;
33 
34 namespace {
35 
36 bool isOdd(unsigned reg) {
37   switch (reg) {
38   default:
39     llvm_unreachable("Register is not from the expected class !");
40   case AArch64::S1:
41   case AArch64::S3:
42   case AArch64::S5:
43   case AArch64::S7:
44   case AArch64::S9:
45   case AArch64::S11:
46   case AArch64::S13:
47   case AArch64::S15:
48   case AArch64::S17:
49   case AArch64::S19:
50   case AArch64::S21:
51   case AArch64::S23:
52   case AArch64::S25:
53   case AArch64::S27:
54   case AArch64::S29:
55   case AArch64::S31:
56   case AArch64::D1:
57   case AArch64::D3:
58   case AArch64::D5:
59   case AArch64::D7:
60   case AArch64::D9:
61   case AArch64::D11:
62   case AArch64::D13:
63   case AArch64::D15:
64   case AArch64::D17:
65   case AArch64::D19:
66   case AArch64::D21:
67   case AArch64::D23:
68   case AArch64::D25:
69   case AArch64::D27:
70   case AArch64::D29:
71   case AArch64::D31:
72   case AArch64::Q1:
73   case AArch64::Q3:
74   case AArch64::Q5:
75   case AArch64::Q7:
76   case AArch64::Q9:
77   case AArch64::Q11:
78   case AArch64::Q13:
79   case AArch64::Q15:
80   case AArch64::Q17:
81   case AArch64::Q19:
82   case AArch64::Q21:
83   case AArch64::Q23:
84   case AArch64::Q25:
85   case AArch64::Q27:
86   case AArch64::Q29:
87   case AArch64::Q31:
88     return true;
89   case AArch64::S0:
90   case AArch64::S2:
91   case AArch64::S4:
92   case AArch64::S6:
93   case AArch64::S8:
94   case AArch64::S10:
95   case AArch64::S12:
96   case AArch64::S14:
97   case AArch64::S16:
98   case AArch64::S18:
99   case AArch64::S20:
100   case AArch64::S22:
101   case AArch64::S24:
102   case AArch64::S26:
103   case AArch64::S28:
104   case AArch64::S30:
105   case AArch64::D0:
106   case AArch64::D2:
107   case AArch64::D4:
108   case AArch64::D6:
109   case AArch64::D8:
110   case AArch64::D10:
111   case AArch64::D12:
112   case AArch64::D14:
113   case AArch64::D16:
114   case AArch64::D18:
115   case AArch64::D20:
116   case AArch64::D22:
117   case AArch64::D24:
118   case AArch64::D26:
119   case AArch64::D28:
120   case AArch64::D30:
121   case AArch64::Q0:
122   case AArch64::Q2:
123   case AArch64::Q4:
124   case AArch64::Q6:
125   case AArch64::Q8:
126   case AArch64::Q10:
127   case AArch64::Q12:
128   case AArch64::Q14:
129   case AArch64::Q16:
130   case AArch64::Q18:
131   case AArch64::Q20:
132   case AArch64::Q22:
133   case AArch64::Q24:
134   case AArch64::Q26:
135   case AArch64::Q28:
136   case AArch64::Q30:
137     return false;
138 
139   }
140 }
141 
142 bool haveSameParity(unsigned reg1, unsigned reg2) {
143   assert(AArch64InstrInfo::isFpOrNEON(reg1) &&
144          "Expecting an FP register for reg1");
145   assert(AArch64InstrInfo::isFpOrNEON(reg2) &&
146          "Expecting an FP register for reg2");
147 
148   return isOdd(reg1) == isOdd(reg2);
149 }
150 
151 }
152 
153 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
154                                                  unsigned Ra) {
155   if (Rd == Ra)
156     return false;
157 
158   LiveIntervals &LIs = G.getMetadata().LIS;
159 
160   if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) {
161     LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
162                       << Register::isPhysicalRegister(Rd) << '\n');
163     LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
164                       << Register::isPhysicalRegister(Ra) << '\n');
165     return false;
166   }
167 
168   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
169   PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
170 
171   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
172     &G.getNodeMetadata(node1).getAllowedRegs();
173   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
174     &G.getNodeMetadata(node2).getAllowedRegs();
175 
176   PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
177 
178   // The edge does not exist. Create one with the appropriate interference
179   // costs.
180   if (edge == G.invalidEdgeId()) {
181     const LiveInterval &ld = LIs.getInterval(Rd);
182     const LiveInterval &la = LIs.getInterval(Ra);
183     bool livesOverlap = ld.overlaps(la);
184 
185     PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
186                                  vRaAllowed->size() + 1, 0);
187     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
188       unsigned pRd = (*vRdAllowed)[i];
189       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
190         unsigned pRa = (*vRaAllowed)[j];
191         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
192           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
193         else
194           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
195       }
196     }
197     G.addEdge(node1, node2, std::move(costs));
198     return true;
199   }
200 
201   if (G.getEdgeNode1Id(edge) == node2) {
202     std::swap(node1, node2);
203     std::swap(vRdAllowed, vRaAllowed);
204   }
205 
206   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
207   PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
208   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
209     unsigned pRd = (*vRdAllowed)[i];
210 
211     // Get the maximum cost (excluding unallocatable reg) for same parity
212     // registers
213     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
214     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
215       unsigned pRa = (*vRaAllowed)[j];
216       if (haveSameParity(pRd, pRa))
217         if (costs[i + 1][j + 1] !=
218                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
219             costs[i + 1][j + 1] > sameParityMax)
220           sameParityMax = costs[i + 1][j + 1];
221     }
222 
223     // Ensure all registers with a different parity have a higher cost
224     // than sameParityMax
225     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
226       unsigned pRa = (*vRaAllowed)[j];
227       if (!haveSameParity(pRd, pRa))
228         if (sameParityMax > costs[i + 1][j + 1])
229           costs[i + 1][j + 1] = sameParityMax + 1.0;
230     }
231   }
232   G.updateEdgeCosts(edge, std::move(costs));
233 
234   return true;
235 }
236 
237 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
238                                                  unsigned Ra) {
239   LiveIntervals &LIs = G.getMetadata().LIS;
240 
241   // Do some Chain management
242   if (Chains.count(Ra)) {
243     if (Rd != Ra) {
244       LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
245                         << " to " << printReg(Rd, TRI) << '\n';);
246       Chains.remove(Ra);
247       Chains.insert(Rd);
248     }
249   } else {
250     LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
251                       << '\n';);
252     Chains.insert(Rd);
253   }
254 
255   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
256 
257   const LiveInterval &ld = LIs.getInterval(Rd);
258   for (auto r : Chains) {
259     // Skip self
260     if (r == Rd)
261       continue;
262 
263     const LiveInterval &lr = LIs.getInterval(r);
264     if (ld.overlaps(lr)) {
265       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
266         &G.getNodeMetadata(node1).getAllowedRegs();
267 
268       PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
269       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
270         &G.getNodeMetadata(node2).getAllowedRegs();
271 
272       PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
273       assert(edge != G.invalidEdgeId() &&
274              "PBQP error ! The edge should exist !");
275 
276       LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
277 
278       if (G.getEdgeNode1Id(edge) == node2) {
279         std::swap(node1, node2);
280         std::swap(vRdAllowed, vRrAllowed);
281       }
282 
283       // Enforce that cost is higher with all other Chains of the same parity
284       PBQP::Matrix costs(G.getEdgeCosts(edge));
285       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
286         unsigned pRd = (*vRdAllowed)[i];
287 
288         // Get the maximum cost (excluding unallocatable reg) for all other
289         // parity registers
290         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
291         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
292           unsigned pRa = (*vRrAllowed)[j];
293           if (!haveSameParity(pRd, pRa))
294             if (costs[i + 1][j + 1] !=
295                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
296                 costs[i + 1][j + 1] > sameParityMax)
297               sameParityMax = costs[i + 1][j + 1];
298         }
299 
300         // Ensure all registers with same parity have a higher cost
301         // than sameParityMax
302         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
303           unsigned pRa = (*vRrAllowed)[j];
304           if (haveSameParity(pRd, pRa))
305             if (sameParityMax > costs[i + 1][j + 1])
306               costs[i + 1][j + 1] = sameParityMax + 1.0;
307         }
308       }
309       G.updateEdgeCosts(edge, std::move(costs));
310     }
311   }
312 }
313 
314 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
315                                 const MachineInstr &MI) {
316   const LiveInterval &LI = LIs.getInterval(reg);
317   SlotIndex SI = LIs.getInstructionIndex(MI);
318   return LI.expiredAt(SI);
319 }
320 
321 void A57ChainingConstraint::apply(PBQPRAGraph &G) {
322   const MachineFunction &MF = G.getMetadata().MF;
323   LiveIntervals &LIs = G.getMetadata().LIS;
324 
325   TRI = MF.getSubtarget().getRegisterInfo();
326   LLVM_DEBUG(MF.dump());
327 
328   for (const auto &MBB: MF) {
329     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
330 
331     for (const auto &MI: MBB) {
332 
333       // Forget Chains which have expired
334       for (auto r : Chains) {
335         SmallVector<unsigned, 8> toDel;
336         if(regJustKilledBefore(LIs, r, MI)) {
337           LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
338                      MI.print(dbgs()););
339           toDel.push_back(r);
340         }
341 
342         while (!toDel.empty()) {
343           Chains.remove(toDel.back());
344           toDel.pop_back();
345         }
346       }
347 
348       switch (MI.getOpcode()) {
349       case AArch64::FMSUBSrrr:
350       case AArch64::FMADDSrrr:
351       case AArch64::FNMSUBSrrr:
352       case AArch64::FNMADDSrrr:
353       case AArch64::FMSUBDrrr:
354       case AArch64::FMADDDrrr:
355       case AArch64::FNMSUBDrrr:
356       case AArch64::FNMADDDrrr: {
357         Register Rd = MI.getOperand(0).getReg();
358         Register Ra = MI.getOperand(3).getReg();
359 
360         if (addIntraChainConstraint(G, Rd, Ra))
361           addInterChainConstraint(G, Rd, Ra);
362         break;
363       }
364 
365       case AArch64::FMLAv2f32:
366       case AArch64::FMLSv2f32: {
367         Register Rd = MI.getOperand(0).getReg();
368         addInterChainConstraint(G, Rd, Rd);
369         break;
370       }
371 
372       default:
373         break;
374       }
375     }
376   }
377 }
378