1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file contains the AArch64 / Cortex-A57 specific register allocation 9 // constraints for use by the PBQP register allocator. 10 // 11 // It is essentially a transcription of what is contained in 12 // AArch64A57FPLoadBalancing, which tries to use a balanced 13 // mix of odd and even D-registers when performing a critical sequence of 14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. 15 //===----------------------------------------------------------------------===// 16 17 #define DEBUG_TYPE "aarch64-pbqp" 18 19 #include "AArch64PBQPRegAlloc.h" 20 #include "AArch64.h" 21 #include "AArch64RegisterInfo.h" 22 #include "llvm/CodeGen/LiveIntervals.h" 23 #include "llvm/CodeGen/MachineBasicBlock.h" 24 #include "llvm/CodeGen/MachineFunction.h" 25 #include "llvm/CodeGen/MachineRegisterInfo.h" 26 #include "llvm/CodeGen/RegAllocPBQP.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/ErrorHandling.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 using namespace llvm; 32 33 namespace { 34 35 #ifndef NDEBUG 36 bool isFPReg(unsigned reg) { 37 return AArch64::FPR32RegClass.contains(reg) || 38 AArch64::FPR64RegClass.contains(reg) || 39 AArch64::FPR128RegClass.contains(reg); 40 } 41 #endif 42 43 bool isOdd(unsigned reg) { 44 switch (reg) { 45 default: 46 llvm_unreachable("Register is not from the expected class !"); 47 case AArch64::S1: 48 case AArch64::S3: 49 case AArch64::S5: 50 case AArch64::S7: 51 case AArch64::S9: 52 case AArch64::S11: 53 case AArch64::S13: 54 case AArch64::S15: 55 case AArch64::S17: 56 case AArch64::S19: 57 case AArch64::S21: 58 case AArch64::S23: 59 case AArch64::S25: 60 case AArch64::S27: 61 case AArch64::S29: 62 case AArch64::S31: 63 case AArch64::D1: 64 case AArch64::D3: 65 case AArch64::D5: 66 case AArch64::D7: 67 case AArch64::D9: 68 case AArch64::D11: 69 case AArch64::D13: 70 case AArch64::D15: 71 case AArch64::D17: 72 case AArch64::D19: 73 case AArch64::D21: 74 case AArch64::D23: 75 case AArch64::D25: 76 case AArch64::D27: 77 case AArch64::D29: 78 case AArch64::D31: 79 case AArch64::Q1: 80 case AArch64::Q3: 81 case AArch64::Q5: 82 case AArch64::Q7: 83 case AArch64::Q9: 84 case AArch64::Q11: 85 case AArch64::Q13: 86 case AArch64::Q15: 87 case AArch64::Q17: 88 case AArch64::Q19: 89 case AArch64::Q21: 90 case AArch64::Q23: 91 case AArch64::Q25: 92 case AArch64::Q27: 93 case AArch64::Q29: 94 case AArch64::Q31: 95 return true; 96 case AArch64::S0: 97 case AArch64::S2: 98 case AArch64::S4: 99 case AArch64::S6: 100 case AArch64::S8: 101 case AArch64::S10: 102 case AArch64::S12: 103 case AArch64::S14: 104 case AArch64::S16: 105 case AArch64::S18: 106 case AArch64::S20: 107 case AArch64::S22: 108 case AArch64::S24: 109 case AArch64::S26: 110 case AArch64::S28: 111 case AArch64::S30: 112 case AArch64::D0: 113 case AArch64::D2: 114 case AArch64::D4: 115 case AArch64::D6: 116 case AArch64::D8: 117 case AArch64::D10: 118 case AArch64::D12: 119 case AArch64::D14: 120 case AArch64::D16: 121 case AArch64::D18: 122 case AArch64::D20: 123 case AArch64::D22: 124 case AArch64::D24: 125 case AArch64::D26: 126 case AArch64::D28: 127 case AArch64::D30: 128 case AArch64::Q0: 129 case AArch64::Q2: 130 case AArch64::Q4: 131 case AArch64::Q6: 132 case AArch64::Q8: 133 case AArch64::Q10: 134 case AArch64::Q12: 135 case AArch64::Q14: 136 case AArch64::Q16: 137 case AArch64::Q18: 138 case AArch64::Q20: 139 case AArch64::Q22: 140 case AArch64::Q24: 141 case AArch64::Q26: 142 case AArch64::Q28: 143 case AArch64::Q30: 144 return false; 145 146 } 147 } 148 149 bool haveSameParity(unsigned reg1, unsigned reg2) { 150 assert(isFPReg(reg1) && "Expecting an FP register for reg1"); 151 assert(isFPReg(reg2) && "Expecting an FP register for reg2"); 152 153 return isOdd(reg1) == isOdd(reg2); 154 } 155 156 } 157 158 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, 159 unsigned Ra) { 160 if (Rd == Ra) 161 return false; 162 163 LiveIntervals &LIs = G.getMetadata().LIS; 164 165 if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) { 166 LLVM_DEBUG(dbgs() << "Rd is a physical reg:" 167 << Register::isPhysicalRegister(Rd) << '\n'); 168 LLVM_DEBUG(dbgs() << "Ra is a physical reg:" 169 << Register::isPhysicalRegister(Ra) << '\n'); 170 return false; 171 } 172 173 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 174 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra); 175 176 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 177 &G.getNodeMetadata(node1).getAllowedRegs(); 178 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = 179 &G.getNodeMetadata(node2).getAllowedRegs(); 180 181 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 182 183 // The edge does not exist. Create one with the appropriate interference 184 // costs. 185 if (edge == G.invalidEdgeId()) { 186 const LiveInterval &ld = LIs.getInterval(Rd); 187 const LiveInterval &la = LIs.getInterval(Ra); 188 bool livesOverlap = ld.overlaps(la); 189 190 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, 191 vRaAllowed->size() + 1, 0); 192 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 193 unsigned pRd = (*vRdAllowed)[i]; 194 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 195 unsigned pRa = (*vRaAllowed)[j]; 196 if (livesOverlap && TRI->regsOverlap(pRd, pRa)) 197 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); 198 else 199 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0; 200 } 201 } 202 G.addEdge(node1, node2, std::move(costs)); 203 return true; 204 } 205 206 if (G.getEdgeNode1Id(edge) == node2) { 207 std::swap(node1, node2); 208 std::swap(vRdAllowed, vRaAllowed); 209 } 210 211 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) 212 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge)); 213 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 214 unsigned pRd = (*vRdAllowed)[i]; 215 216 // Get the maximum cost (excluding unallocatable reg) for same parity 217 // registers 218 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 219 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 220 unsigned pRa = (*vRaAllowed)[j]; 221 if (haveSameParity(pRd, pRa)) 222 if (costs[i + 1][j + 1] != 223 std::numeric_limits<PBQP::PBQPNum>::infinity() && 224 costs[i + 1][j + 1] > sameParityMax) 225 sameParityMax = costs[i + 1][j + 1]; 226 } 227 228 // Ensure all registers with a different parity have a higher cost 229 // than sameParityMax 230 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 231 unsigned pRa = (*vRaAllowed)[j]; 232 if (!haveSameParity(pRd, pRa)) 233 if (sameParityMax > costs[i + 1][j + 1]) 234 costs[i + 1][j + 1] = sameParityMax + 1.0; 235 } 236 } 237 G.updateEdgeCosts(edge, std::move(costs)); 238 239 return true; 240 } 241 242 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, 243 unsigned Ra) { 244 LiveIntervals &LIs = G.getMetadata().LIS; 245 246 // Do some Chain management 247 if (Chains.count(Ra)) { 248 if (Rd != Ra) { 249 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI) 250 << " to " << printReg(Rd, TRI) << '\n';); 251 Chains.remove(Ra); 252 Chains.insert(Rd); 253 } 254 } else { 255 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI) 256 << '\n';); 257 Chains.insert(Rd); 258 } 259 260 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 261 262 const LiveInterval &ld = LIs.getInterval(Rd); 263 for (auto r : Chains) { 264 // Skip self 265 if (r == Rd) 266 continue; 267 268 const LiveInterval &lr = LIs.getInterval(r); 269 if (ld.overlaps(lr)) { 270 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 271 &G.getNodeMetadata(node1).getAllowedRegs(); 272 273 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r); 274 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = 275 &G.getNodeMetadata(node2).getAllowedRegs(); 276 277 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 278 assert(edge != G.invalidEdgeId() && 279 "PBQP error ! The edge should exist !"); 280 281 LLVM_DEBUG(dbgs() << "Refining constraint !\n";); 282 283 if (G.getEdgeNode1Id(edge) == node2) { 284 std::swap(node1, node2); 285 std::swap(vRdAllowed, vRrAllowed); 286 } 287 288 // Enforce that cost is higher with all other Chains of the same parity 289 PBQP::Matrix costs(G.getEdgeCosts(edge)); 290 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 291 unsigned pRd = (*vRdAllowed)[i]; 292 293 // Get the maximum cost (excluding unallocatable reg) for all other 294 // parity registers 295 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 296 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 297 unsigned pRa = (*vRrAllowed)[j]; 298 if (!haveSameParity(pRd, pRa)) 299 if (costs[i + 1][j + 1] != 300 std::numeric_limits<PBQP::PBQPNum>::infinity() && 301 costs[i + 1][j + 1] > sameParityMax) 302 sameParityMax = costs[i + 1][j + 1]; 303 } 304 305 // Ensure all registers with same parity have a higher cost 306 // than sameParityMax 307 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 308 unsigned pRa = (*vRrAllowed)[j]; 309 if (haveSameParity(pRd, pRa)) 310 if (sameParityMax > costs[i + 1][j + 1]) 311 costs[i + 1][j + 1] = sameParityMax + 1.0; 312 } 313 } 314 G.updateEdgeCosts(edge, std::move(costs)); 315 } 316 } 317 } 318 319 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, 320 const MachineInstr &MI) { 321 const LiveInterval &LI = LIs.getInterval(reg); 322 SlotIndex SI = LIs.getInstructionIndex(MI); 323 return LI.expiredAt(SI); 324 } 325 326 void A57ChainingConstraint::apply(PBQPRAGraph &G) { 327 const MachineFunction &MF = G.getMetadata().MF; 328 LiveIntervals &LIs = G.getMetadata().LIS; 329 330 TRI = MF.getSubtarget().getRegisterInfo(); 331 LLVM_DEBUG(MF.dump()); 332 333 for (const auto &MBB: MF) { 334 Chains.clear(); // FIXME: really needed ? Could not work at MF level ? 335 336 for (const auto &MI: MBB) { 337 338 // Forget Chains which have expired 339 for (auto r : Chains) { 340 SmallVector<unsigned, 8> toDel; 341 if(regJustKilledBefore(LIs, r, MI)) { 342 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at "; 343 MI.print(dbgs());); 344 toDel.push_back(r); 345 } 346 347 while (!toDel.empty()) { 348 Chains.remove(toDel.back()); 349 toDel.pop_back(); 350 } 351 } 352 353 switch (MI.getOpcode()) { 354 case AArch64::FMSUBSrrr: 355 case AArch64::FMADDSrrr: 356 case AArch64::FNMSUBSrrr: 357 case AArch64::FNMADDSrrr: 358 case AArch64::FMSUBDrrr: 359 case AArch64::FMADDDrrr: 360 case AArch64::FNMSUBDrrr: 361 case AArch64::FNMADDDrrr: { 362 Register Rd = MI.getOperand(0).getReg(); 363 Register Ra = MI.getOperand(3).getReg(); 364 365 if (addIntraChainConstraint(G, Rd, Ra)) 366 addInterChainConstraint(G, Rd, Ra); 367 break; 368 } 369 370 case AArch64::FMLAv2f32: 371 case AArch64::FMLSv2f32: { 372 Register Rd = MI.getOperand(0).getReg(); 373 addInterChainConstraint(G, Rd, Rd); 374 break; 375 } 376 377 default: 378 break; 379 } 380 } 381 } 382 } 383