1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains the AArch64 / Cortex-A57 specific register allocation
9 // constraints for use by the PBQP register allocator.
10 //
11 // It is essentially a transcription of what is contained in
12 // AArch64A57FPLoadBalancing, which tries to use a balanced
13 // mix of odd and even D-registers when performing a critical sequence of
14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15 //===----------------------------------------------------------------------===//
16
17 #include "AArch64PBQPRegAlloc.h"
18 #include "AArch64.h"
19 #include "AArch64InstrInfo.h"
20 #include "AArch64RegisterInfo.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include "llvm/CodeGen/MachineBasicBlock.h"
23 #include "llvm/CodeGen/MachineFunction.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/RegAllocPBQP.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ErrorHandling.h"
28 #include "llvm/Support/raw_ostream.h"
29
30 #define DEBUG_TYPE "aarch64-pbqp"
31
32 using namespace llvm;
33
34 namespace {
35
isOdd(unsigned reg)36 bool isOdd(unsigned reg) {
37 switch (reg) {
38 default:
39 llvm_unreachable("Register is not from the expected class !");
40 case AArch64::S1:
41 case AArch64::S3:
42 case AArch64::S5:
43 case AArch64::S7:
44 case AArch64::S9:
45 case AArch64::S11:
46 case AArch64::S13:
47 case AArch64::S15:
48 case AArch64::S17:
49 case AArch64::S19:
50 case AArch64::S21:
51 case AArch64::S23:
52 case AArch64::S25:
53 case AArch64::S27:
54 case AArch64::S29:
55 case AArch64::S31:
56 case AArch64::D1:
57 case AArch64::D3:
58 case AArch64::D5:
59 case AArch64::D7:
60 case AArch64::D9:
61 case AArch64::D11:
62 case AArch64::D13:
63 case AArch64::D15:
64 case AArch64::D17:
65 case AArch64::D19:
66 case AArch64::D21:
67 case AArch64::D23:
68 case AArch64::D25:
69 case AArch64::D27:
70 case AArch64::D29:
71 case AArch64::D31:
72 case AArch64::Q1:
73 case AArch64::Q3:
74 case AArch64::Q5:
75 case AArch64::Q7:
76 case AArch64::Q9:
77 case AArch64::Q11:
78 case AArch64::Q13:
79 case AArch64::Q15:
80 case AArch64::Q17:
81 case AArch64::Q19:
82 case AArch64::Q21:
83 case AArch64::Q23:
84 case AArch64::Q25:
85 case AArch64::Q27:
86 case AArch64::Q29:
87 case AArch64::Q31:
88 return true;
89 case AArch64::S0:
90 case AArch64::S2:
91 case AArch64::S4:
92 case AArch64::S6:
93 case AArch64::S8:
94 case AArch64::S10:
95 case AArch64::S12:
96 case AArch64::S14:
97 case AArch64::S16:
98 case AArch64::S18:
99 case AArch64::S20:
100 case AArch64::S22:
101 case AArch64::S24:
102 case AArch64::S26:
103 case AArch64::S28:
104 case AArch64::S30:
105 case AArch64::D0:
106 case AArch64::D2:
107 case AArch64::D4:
108 case AArch64::D6:
109 case AArch64::D8:
110 case AArch64::D10:
111 case AArch64::D12:
112 case AArch64::D14:
113 case AArch64::D16:
114 case AArch64::D18:
115 case AArch64::D20:
116 case AArch64::D22:
117 case AArch64::D24:
118 case AArch64::D26:
119 case AArch64::D28:
120 case AArch64::D30:
121 case AArch64::Q0:
122 case AArch64::Q2:
123 case AArch64::Q4:
124 case AArch64::Q6:
125 case AArch64::Q8:
126 case AArch64::Q10:
127 case AArch64::Q12:
128 case AArch64::Q14:
129 case AArch64::Q16:
130 case AArch64::Q18:
131 case AArch64::Q20:
132 case AArch64::Q22:
133 case AArch64::Q24:
134 case AArch64::Q26:
135 case AArch64::Q28:
136 case AArch64::Q30:
137 return false;
138
139 }
140 }
141
haveSameParity(unsigned reg1,unsigned reg2)142 bool haveSameParity(unsigned reg1, unsigned reg2) {
143 assert(AArch64InstrInfo::isFpOrNEON(reg1) &&
144 "Expecting an FP register for reg1");
145 assert(AArch64InstrInfo::isFpOrNEON(reg2) &&
146 "Expecting an FP register for reg2");
147
148 return isOdd(reg1) == isOdd(reg2);
149 }
150
151 }
152
addIntraChainConstraint(PBQPRAGraph & G,unsigned Rd,unsigned Ra)153 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
154 unsigned Ra) {
155 if (Rd == Ra)
156 return false;
157
158 LiveIntervals &LIs = G.getMetadata().LIS;
159
160 if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) {
161 LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
162 << Register::isPhysicalRegister(Rd) << '\n');
163 LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
164 << Register::isPhysicalRegister(Ra) << '\n');
165 return false;
166 }
167
168 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
169 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
170
171 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
172 &G.getNodeMetadata(node1).getAllowedRegs();
173 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
174 &G.getNodeMetadata(node2).getAllowedRegs();
175
176 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
177
178 // The edge does not exist. Create one with the appropriate interference
179 // costs.
180 if (edge == G.invalidEdgeId()) {
181 const LiveInterval &ld = LIs.getInterval(Rd);
182 const LiveInterval &la = LIs.getInterval(Ra);
183 bool livesOverlap = ld.overlaps(la);
184
185 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
186 vRaAllowed->size() + 1, 0);
187 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
188 unsigned pRd = (*vRdAllowed)[i];
189 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
190 unsigned pRa = (*vRaAllowed)[j];
191 if (livesOverlap && TRI->regsOverlap(pRd, pRa))
192 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
193 else
194 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
195 }
196 }
197 G.addEdge(node1, node2, std::move(costs));
198 return true;
199 }
200
201 if (G.getEdgeNode1Id(edge) == node2) {
202 std::swap(node1, node2);
203 std::swap(vRdAllowed, vRaAllowed);
204 }
205
206 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
207 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
208 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
209 unsigned pRd = (*vRdAllowed)[i];
210
211 // Get the maximum cost (excluding unallocatable reg) for same parity
212 // registers
213 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
214 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
215 unsigned pRa = (*vRaAllowed)[j];
216 if (haveSameParity(pRd, pRa))
217 if (costs[i + 1][j + 1] !=
218 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
219 costs[i + 1][j + 1] > sameParityMax)
220 sameParityMax = costs[i + 1][j + 1];
221 }
222
223 // Ensure all registers with a different parity have a higher cost
224 // than sameParityMax
225 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
226 unsigned pRa = (*vRaAllowed)[j];
227 if (!haveSameParity(pRd, pRa))
228 if (sameParityMax > costs[i + 1][j + 1])
229 costs[i + 1][j + 1] = sameParityMax + 1.0;
230 }
231 }
232 G.updateEdgeCosts(edge, std::move(costs));
233
234 return true;
235 }
236
addInterChainConstraint(PBQPRAGraph & G,unsigned Rd,unsigned Ra)237 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
238 unsigned Ra) {
239 LiveIntervals &LIs = G.getMetadata().LIS;
240
241 // Do some Chain management
242 if (Chains.count(Ra)) {
243 if (Rd != Ra) {
244 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
245 << " to " << printReg(Rd, TRI) << '\n';);
246 Chains.remove(Ra);
247 Chains.insert(Rd);
248 }
249 } else {
250 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
251 << '\n';);
252 Chains.insert(Rd);
253 }
254
255 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
256
257 const LiveInterval &ld = LIs.getInterval(Rd);
258 for (auto r : Chains) {
259 // Skip self
260 if (r == Rd)
261 continue;
262
263 const LiveInterval &lr = LIs.getInterval(r);
264 if (ld.overlaps(lr)) {
265 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
266 &G.getNodeMetadata(node1).getAllowedRegs();
267
268 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
269 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
270 &G.getNodeMetadata(node2).getAllowedRegs();
271
272 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
273 assert(edge != G.invalidEdgeId() &&
274 "PBQP error ! The edge should exist !");
275
276 LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
277
278 if (G.getEdgeNode1Id(edge) == node2) {
279 std::swap(node1, node2);
280 std::swap(vRdAllowed, vRrAllowed);
281 }
282
283 // Enforce that cost is higher with all other Chains of the same parity
284 PBQP::Matrix costs(G.getEdgeCosts(edge));
285 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
286 unsigned pRd = (*vRdAllowed)[i];
287
288 // Get the maximum cost (excluding unallocatable reg) for all other
289 // parity registers
290 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
291 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
292 unsigned pRa = (*vRrAllowed)[j];
293 if (!haveSameParity(pRd, pRa))
294 if (costs[i + 1][j + 1] !=
295 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
296 costs[i + 1][j + 1] > sameParityMax)
297 sameParityMax = costs[i + 1][j + 1];
298 }
299
300 // Ensure all registers with same parity have a higher cost
301 // than sameParityMax
302 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
303 unsigned pRa = (*vRrAllowed)[j];
304 if (haveSameParity(pRd, pRa))
305 if (sameParityMax > costs[i + 1][j + 1])
306 costs[i + 1][j + 1] = sameParityMax + 1.0;
307 }
308 }
309 G.updateEdgeCosts(edge, std::move(costs));
310 }
311 }
312 }
313
regJustKilledBefore(const LiveIntervals & LIs,unsigned reg,const MachineInstr & MI)314 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
315 const MachineInstr &MI) {
316 const LiveInterval &LI = LIs.getInterval(reg);
317 SlotIndex SI = LIs.getInstructionIndex(MI);
318 return LI.expiredAt(SI);
319 }
320
apply(PBQPRAGraph & G)321 void A57ChainingConstraint::apply(PBQPRAGraph &G) {
322 const MachineFunction &MF = G.getMetadata().MF;
323 LiveIntervals &LIs = G.getMetadata().LIS;
324
325 TRI = MF.getSubtarget().getRegisterInfo();
326 LLVM_DEBUG(MF.dump());
327
328 for (const auto &MBB: MF) {
329 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
330
331 for (const auto &MI: MBB) {
332
333 // Forget Chains which have expired
334 for (auto r : Chains) {
335 SmallVector<unsigned, 8> toDel;
336 if(regJustKilledBefore(LIs, r, MI)) {
337 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
338 MI.print(dbgs()););
339 toDel.push_back(r);
340 }
341
342 while (!toDel.empty()) {
343 Chains.remove(toDel.back());
344 toDel.pop_back();
345 }
346 }
347
348 switch (MI.getOpcode()) {
349 case AArch64::FMSUBSrrr:
350 case AArch64::FMADDSrrr:
351 case AArch64::FNMSUBSrrr:
352 case AArch64::FNMADDSrrr:
353 case AArch64::FMSUBDrrr:
354 case AArch64::FMADDDrrr:
355 case AArch64::FNMSUBDrrr:
356 case AArch64::FNMADDDrrr: {
357 Register Rd = MI.getOperand(0).getReg();
358 Register Ra = MI.getOperand(3).getReg();
359
360 if (addIntraChainConstraint(G, Rd, Ra))
361 addInterChainConstraint(G, Rd, Ra);
362 break;
363 }
364
365 case AArch64::FMLAv2f32:
366 case AArch64::FMLSv2f32: {
367 Register Rd = MI.getOperand(0).getReg();
368 addInterChainConstraint(G, Rd, Rd);
369 break;
370 }
371
372 default:
373 break;
374 }
375 }
376 }
377 }
378