1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file contains the AArch64 / Cortex-A57 specific register allocation 9 // constraints for use by the PBQP register allocator. 10 // 11 // It is essentially a transcription of what is contained in 12 // AArch64A57FPLoadBalancing, which tries to use a balanced 13 // mix of odd and even D-registers when performing a critical sequence of 14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. 15 //===----------------------------------------------------------------------===// 16 17 #include "AArch64PBQPRegAlloc.h" 18 #include "AArch64.h" 19 #include "AArch64InstrInfo.h" 20 #include "AArch64RegisterInfo.h" 21 #include "llvm/CodeGen/LiveIntervals.h" 22 #include "llvm/CodeGen/MachineBasicBlock.h" 23 #include "llvm/CodeGen/MachineFunction.h" 24 #include "llvm/CodeGen/MachineRegisterInfo.h" 25 #include "llvm/CodeGen/RegAllocPBQP.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/ErrorHandling.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #define DEBUG_TYPE "aarch64-pbqp" 31 32 using namespace llvm; 33 34 namespace { 35 36 bool isOdd(unsigned reg) { 37 switch (reg) { 38 default: 39 llvm_unreachable("Register is not from the expected class !"); 40 case AArch64::S1: 41 case AArch64::S3: 42 case AArch64::S5: 43 case AArch64::S7: 44 case AArch64::S9: 45 case AArch64::S11: 46 case AArch64::S13: 47 case AArch64::S15: 48 case AArch64::S17: 49 case AArch64::S19: 50 case AArch64::S21: 51 case AArch64::S23: 52 case AArch64::S25: 53 case AArch64::S27: 54 case AArch64::S29: 55 case AArch64::S31: 56 case AArch64::D1: 57 case AArch64::D3: 58 case AArch64::D5: 59 case AArch64::D7: 60 case AArch64::D9: 61 case AArch64::D11: 62 case AArch64::D13: 63 case AArch64::D15: 64 case AArch64::D17: 65 case AArch64::D19: 66 case AArch64::D21: 67 case AArch64::D23: 68 case AArch64::D25: 69 case AArch64::D27: 70 case AArch64::D29: 71 case AArch64::D31: 72 case AArch64::Q1: 73 case AArch64::Q3: 74 case AArch64::Q5: 75 case AArch64::Q7: 76 case AArch64::Q9: 77 case AArch64::Q11: 78 case AArch64::Q13: 79 case AArch64::Q15: 80 case AArch64::Q17: 81 case AArch64::Q19: 82 case AArch64::Q21: 83 case AArch64::Q23: 84 case AArch64::Q25: 85 case AArch64::Q27: 86 case AArch64::Q29: 87 case AArch64::Q31: 88 return true; 89 case AArch64::S0: 90 case AArch64::S2: 91 case AArch64::S4: 92 case AArch64::S6: 93 case AArch64::S8: 94 case AArch64::S10: 95 case AArch64::S12: 96 case AArch64::S14: 97 case AArch64::S16: 98 case AArch64::S18: 99 case AArch64::S20: 100 case AArch64::S22: 101 case AArch64::S24: 102 case AArch64::S26: 103 case AArch64::S28: 104 case AArch64::S30: 105 case AArch64::D0: 106 case AArch64::D2: 107 case AArch64::D4: 108 case AArch64::D6: 109 case AArch64::D8: 110 case AArch64::D10: 111 case AArch64::D12: 112 case AArch64::D14: 113 case AArch64::D16: 114 case AArch64::D18: 115 case AArch64::D20: 116 case AArch64::D22: 117 case AArch64::D24: 118 case AArch64::D26: 119 case AArch64::D28: 120 case AArch64::D30: 121 case AArch64::Q0: 122 case AArch64::Q2: 123 case AArch64::Q4: 124 case AArch64::Q6: 125 case AArch64::Q8: 126 case AArch64::Q10: 127 case AArch64::Q12: 128 case AArch64::Q14: 129 case AArch64::Q16: 130 case AArch64::Q18: 131 case AArch64::Q20: 132 case AArch64::Q22: 133 case AArch64::Q24: 134 case AArch64::Q26: 135 case AArch64::Q28: 136 case AArch64::Q30: 137 return false; 138 139 } 140 } 141 142 bool haveSameParity(unsigned reg1, unsigned reg2) { 143 assert(AArch64InstrInfo::isFpOrNEON(reg1) && 144 "Expecting an FP register for reg1"); 145 assert(AArch64InstrInfo::isFpOrNEON(reg2) && 146 "Expecting an FP register for reg2"); 147 148 return isOdd(reg1) == isOdd(reg2); 149 } 150 151 } 152 153 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, 154 unsigned Ra) { 155 if (Rd == Ra) 156 return false; 157 158 LiveIntervals &LIs = G.getMetadata().LIS; 159 160 if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) { 161 LLVM_DEBUG(dbgs() << "Rd is a physical reg:" 162 << Register::isPhysicalRegister(Rd) << '\n'); 163 LLVM_DEBUG(dbgs() << "Ra is a physical reg:" 164 << Register::isPhysicalRegister(Ra) << '\n'); 165 return false; 166 } 167 168 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 169 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra); 170 171 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 172 &G.getNodeMetadata(node1).getAllowedRegs(); 173 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = 174 &G.getNodeMetadata(node2).getAllowedRegs(); 175 176 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 177 178 // The edge does not exist. Create one with the appropriate interference 179 // costs. 180 if (edge == G.invalidEdgeId()) { 181 const LiveInterval &ld = LIs.getInterval(Rd); 182 const LiveInterval &la = LIs.getInterval(Ra); 183 bool livesOverlap = ld.overlaps(la); 184 185 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, 186 vRaAllowed->size() + 1, 0); 187 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 188 unsigned pRd = (*vRdAllowed)[i]; 189 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 190 unsigned pRa = (*vRaAllowed)[j]; 191 if (livesOverlap && TRI->regsOverlap(pRd, pRa)) 192 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); 193 else 194 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0; 195 } 196 } 197 G.addEdge(node1, node2, std::move(costs)); 198 return true; 199 } 200 201 if (G.getEdgeNode1Id(edge) == node2) { 202 std::swap(node1, node2); 203 std::swap(vRdAllowed, vRaAllowed); 204 } 205 206 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) 207 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge)); 208 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 209 unsigned pRd = (*vRdAllowed)[i]; 210 211 // Get the maximum cost (excluding unallocatable reg) for same parity 212 // registers 213 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 214 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 215 unsigned pRa = (*vRaAllowed)[j]; 216 if (haveSameParity(pRd, pRa)) 217 if (costs[i + 1][j + 1] != 218 std::numeric_limits<PBQP::PBQPNum>::infinity() && 219 costs[i + 1][j + 1] > sameParityMax) 220 sameParityMax = costs[i + 1][j + 1]; 221 } 222 223 // Ensure all registers with a different parity have a higher cost 224 // than sameParityMax 225 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 226 unsigned pRa = (*vRaAllowed)[j]; 227 if (!haveSameParity(pRd, pRa)) 228 if (sameParityMax > costs[i + 1][j + 1]) 229 costs[i + 1][j + 1] = sameParityMax + 1.0; 230 } 231 } 232 G.updateEdgeCosts(edge, std::move(costs)); 233 234 return true; 235 } 236 237 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, 238 unsigned Ra) { 239 LiveIntervals &LIs = G.getMetadata().LIS; 240 241 // Do some Chain management 242 if (Chains.count(Ra)) { 243 if (Rd != Ra) { 244 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI) 245 << " to " << printReg(Rd, TRI) << '\n';); 246 Chains.remove(Ra); 247 Chains.insert(Rd); 248 } 249 } else { 250 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI) 251 << '\n';); 252 Chains.insert(Rd); 253 } 254 255 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 256 257 const LiveInterval &ld = LIs.getInterval(Rd); 258 for (auto r : Chains) { 259 // Skip self 260 if (r == Rd) 261 continue; 262 263 const LiveInterval &lr = LIs.getInterval(r); 264 if (ld.overlaps(lr)) { 265 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 266 &G.getNodeMetadata(node1).getAllowedRegs(); 267 268 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r); 269 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = 270 &G.getNodeMetadata(node2).getAllowedRegs(); 271 272 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 273 assert(edge != G.invalidEdgeId() && 274 "PBQP error ! The edge should exist !"); 275 276 LLVM_DEBUG(dbgs() << "Refining constraint !\n";); 277 278 if (G.getEdgeNode1Id(edge) == node2) { 279 std::swap(node1, node2); 280 std::swap(vRdAllowed, vRrAllowed); 281 } 282 283 // Enforce that cost is higher with all other Chains of the same parity 284 PBQP::Matrix costs(G.getEdgeCosts(edge)); 285 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 286 unsigned pRd = (*vRdAllowed)[i]; 287 288 // Get the maximum cost (excluding unallocatable reg) for all other 289 // parity registers 290 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 291 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 292 unsigned pRa = (*vRrAllowed)[j]; 293 if (!haveSameParity(pRd, pRa)) 294 if (costs[i + 1][j + 1] != 295 std::numeric_limits<PBQP::PBQPNum>::infinity() && 296 costs[i + 1][j + 1] > sameParityMax) 297 sameParityMax = costs[i + 1][j + 1]; 298 } 299 300 // Ensure all registers with same parity have a higher cost 301 // than sameParityMax 302 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 303 unsigned pRa = (*vRrAllowed)[j]; 304 if (haveSameParity(pRd, pRa)) 305 if (sameParityMax > costs[i + 1][j + 1]) 306 costs[i + 1][j + 1] = sameParityMax + 1.0; 307 } 308 } 309 G.updateEdgeCosts(edge, std::move(costs)); 310 } 311 } 312 } 313 314 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, 315 const MachineInstr &MI) { 316 const LiveInterval &LI = LIs.getInterval(reg); 317 SlotIndex SI = LIs.getInstructionIndex(MI); 318 return LI.expiredAt(SI); 319 } 320 321 void A57ChainingConstraint::apply(PBQPRAGraph &G) { 322 const MachineFunction &MF = G.getMetadata().MF; 323 LiveIntervals &LIs = G.getMetadata().LIS; 324 325 TRI = MF.getSubtarget().getRegisterInfo(); 326 LLVM_DEBUG(MF.dump()); 327 328 for (const auto &MBB: MF) { 329 Chains.clear(); // FIXME: really needed ? Could not work at MF level ? 330 331 for (const auto &MI: MBB) { 332 333 // Forget Chains which have expired 334 for (auto r : Chains) { 335 SmallVector<unsigned, 8> toDel; 336 if(regJustKilledBefore(LIs, r, MI)) { 337 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at "; 338 MI.print(dbgs());); 339 toDel.push_back(r); 340 } 341 342 while (!toDel.empty()) { 343 Chains.remove(toDel.back()); 344 toDel.pop_back(); 345 } 346 } 347 348 switch (MI.getOpcode()) { 349 case AArch64::FMSUBSrrr: 350 case AArch64::FMADDSrrr: 351 case AArch64::FNMSUBSrrr: 352 case AArch64::FNMADDSrrr: 353 case AArch64::FMSUBDrrr: 354 case AArch64::FMADDDrrr: 355 case AArch64::FNMSUBDrrr: 356 case AArch64::FNMADDDrrr: { 357 Register Rd = MI.getOperand(0).getReg(); 358 Register Ra = MI.getOperand(3).getReg(); 359 360 if (addIntraChainConstraint(G, Rd, Ra)) 361 addInterChainConstraint(G, Rd, Ra); 362 break; 363 } 364 365 case AArch64::FMLAv2f32: 366 case AArch64::FMLSv2f32: { 367 Register Rd = MI.getOperand(0).getReg(); 368 addInterChainConstraint(G, Rd, Rd); 369 break; 370 } 371 372 default: 373 break; 374 } 375 } 376 } 377 } 378