1 //===- AArch64FalkorHWPFFix.cpp - Avoid HW prefetcher pitfalls on Falkor --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file For Falkor, we want to avoid HW prefetcher instruction tag collisions 9 /// that may inhibit the HW prefetching. This is done in two steps. Before 10 /// ISel, we mark strided loads (i.e. those that will likely benefit from 11 /// prefetching) with metadata. Then, after opcodes have been finalized, we 12 /// insert MOVs and re-write loads to prevent unintentional tag collisions. 13 // ===---------------------------------------------------------------------===// 14 15 #include "AArch64.h" 16 #include "AArch64InstrInfo.h" 17 #include "AArch64Subtarget.h" 18 #include "AArch64TargetMachine.h" 19 #include "llvm/ADT/DenseMap.h" 20 #include "llvm/ADT/DepthFirstIterator.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/Statistic.h" 23 #include "llvm/Analysis/LoopInfo.h" 24 #include "llvm/Analysis/ScalarEvolution.h" 25 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 26 #include "llvm/CodeGen/LiveRegUnits.h" 27 #include "llvm/CodeGen/MachineBasicBlock.h" 28 #include "llvm/CodeGen/MachineFunction.h" 29 #include "llvm/CodeGen/MachineFunctionPass.h" 30 #include "llvm/CodeGen/MachineInstr.h" 31 #include "llvm/CodeGen/MachineInstrBuilder.h" 32 #include "llvm/CodeGen/MachineLoopInfo.h" 33 #include "llvm/CodeGen/MachineOperand.h" 34 #include "llvm/CodeGen/MachineRegisterInfo.h" 35 #include "llvm/CodeGen/TargetPassConfig.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/IR/DebugLoc.h" 38 #include "llvm/IR/Dominators.h" 39 #include "llvm/IR/Function.h" 40 #include "llvm/IR/Instruction.h" 41 #include "llvm/IR/Instructions.h" 42 #include "llvm/IR/Metadata.h" 43 
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/raw_ostream.h"
#include <iterator>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "aarch64-falkor-hwpf-fix"

STATISTIC(NumStridedLoadsMarked, "Number of strided loads marked");
STATISTIC(NumCollisionsAvoided,
          "Number of HW prefetch tag collisions avoided");
STATISTIC(NumCollisionsNotAvoided,
          "Number of HW prefetch tag collisions not avoided due to lack of registers");
DEBUG_COUNTER(FixCounter, "falkor-hwpf",
              "Controls which tag collisions are avoided");

namespace {

/// IR-level worker that attaches FALKOR_STRIDED_ACCESS_MD metadata to loads
/// in innermost loops whose address is an affine add-recurrence.
class FalkorMarkStridedAccesses {
public:
  FalkorMarkStridedAccesses(LoopInfo &LI, ScalarEvolution &SE)
      : LI(LI), SE(SE) {}

  /// Visit every loop known to LoopInfo. Returns true if any load was marked.
  bool run();

private:
  /// Mark the qualifying loads of a single loop. Returns true on change.
  bool runOnLoop(Loop &L);

  LoopInfo &LI;
  ScalarEvolution &SE;
};

/// Legacy pass-manager wrapper around FalkorMarkStridedAccesses.
class FalkorMarkStridedAccessesLegacy : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid

  FalkorMarkStridedAccessesLegacy() : FunctionPass(ID) {
    initializeFalkorMarkStridedAccessesLegacyPass(
        *PassRegistry::getPassRegistry());
  }

  /// Requires the target configuration, loop info and scalar evolution;
  /// declares the dominator tree, loop info and scalar evolution preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char FalkorMarkStridedAccessesLegacy::ID = 0;

INITIALIZE_PASS_BEGIN(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
                      "Falkor HW Prefetch Fix", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
                    "Falkor HW Prefetch Fix", false, false)

/// Factory for the legacy strided-access marking pass.
FunctionPass *llvm::createFalkorMarkStridedAccessesPass() {
  return new FalkorMarkStridedAccessesLegacy();
}

/// Entry point of the legacy pass. Does nothing unless the function's
/// subtarget reports the Falkor processor family and the function is not
/// skipped; otherwise delegates to FalkorMarkStridedAccesses::run().
bool FalkorMarkStridedAccessesLegacy::runOnFunction(Function &F) {
  const auto &TM =
      getAnalysis<TargetPassConfig>().getTM<AArch64TargetMachine>();
  const AArch64Subtarget *Subtarget = TM.getSubtargetImpl(F);
  if (Subtarget->getProcFamily() != AArch64Subtarget::Falkor)
    return false;

  // The skip check stays after the subtarget check, as in the original order.
  if (skipFunction(F))
    return false;

  LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  ScalarEvolution &Evolution =
      getAnalysis<ScalarEvolutionWrapperPass>().getSE();

  FalkorMarkStridedAccesses Marker(Loops, Evolution);
  return Marker.run();
}

/// Walk each top-level loop and everything nested inside it, depth first.
bool FalkorMarkStridedAccesses::run() {
  bool Changed = false;

  for (Loop *TopLevel : LI)
    for (Loop *Nested : depth_first(TopLevel))
      // runOnLoop is always invoked; only the result is accumulated.
      if (runOnLoop(*Nested))
        Changed = true;

  return Changed;
}

/// Tag every load in \p L whose pointer operand varies in the loop and whose
/// SCEV is an affine add-recurrence. Loops that are not innermost are left
/// untouched.
bool FalkorMarkStridedAccesses::runOnLoop(Loop &L) {
  // Only mark strided loads in the inner-most loop
  if (!L.isInnermost())
    return false;

  bool Changed = false;

  for (BasicBlock *Block : L.blocks()) {
    for (Instruction &Inst : *Block) {
      auto *Load = dyn_cast<LoadInst>(&Inst);
      if (!Load)
        continue;

      // A loop-invariant address is not a strided access.
      Value *Address = Load->getPointerOperand();
      if (L.isLoopInvariant(Address))
        continue;

      const auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Address));
      if (AddRec && AddRec->isAffine()) {
        Load->setMetadata(FALKOR_STRIDED_ACCESS_MD,
                          MDNode::get(Load->getContext(), {}));
        ++NumStridedLoadsMarked;
        LLVM_DEBUG(dbgs() << "Load: " << Inst << " marked as strided\n");
        Changed = true;
      }
    }
  }

  return Changed;
}

namespace {
179 class FalkorHWPFFix : public MachineFunctionPass { 180 public: 181 static char ID; 182 183 FalkorHWPFFix() : MachineFunctionPass(ID) { 184 initializeFalkorHWPFFixPass(*PassRegistry::getPassRegistry()); 185 } 186 187 bool runOnMachineFunction(MachineFunction &Fn) override; 188 189 void getAnalysisUsage(AnalysisUsage &AU) const override { 190 AU.setPreservesCFG(); 191 AU.addRequired<MachineLoopInfo>(); 192 MachineFunctionPass::getAnalysisUsage(AU); 193 } 194 195 MachineFunctionProperties getRequiredProperties() const override { 196 return MachineFunctionProperties().set( 197 MachineFunctionProperties::Property::NoVRegs); 198 } 199 200 private: 201 void runOnLoop(MachineLoop &L, MachineFunction &Fn); 202 203 const AArch64InstrInfo *TII; 204 const TargetRegisterInfo *TRI; 205 DenseMap<unsigned, SmallVector<MachineInstr *, 4>> TagMap; 206 bool Modified; 207 }; 208 209 /// Bits from load opcodes used to compute HW prefetcher instruction tags. 210 struct LoadInfo { 211 LoadInfo() = default; 212 213 Register DestReg; 214 Register BaseReg; 215 int BaseRegIdx = -1; 216 const MachineOperand *OffsetOpnd = nullptr; 217 bool IsPrePost = false; 218 }; 219 220 } // end anonymous namespace 221 222 char FalkorHWPFFix::ID = 0; 223 224 INITIALIZE_PASS_BEGIN(FalkorHWPFFix, "aarch64-falkor-hwpf-fix-late", 225 "Falkor HW Prefetch Fix Late Phase", false, false) 226 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 227 INITIALIZE_PASS_END(FalkorHWPFFix, "aarch64-falkor-hwpf-fix-late", 228 "Falkor HW Prefetch Fix Late Phase", false, false) 229 230 static unsigned makeTag(unsigned Dest, unsigned Base, unsigned Offset) { 231 return (Dest & 0xf) | ((Base & 0xf) << 4) | ((Offset & 0x3f) << 8); 232 } 233 234 static std::optional<LoadInfo> getLoadInfo(const MachineInstr &MI) { 235 int DestRegIdx; 236 int BaseRegIdx; 237 int OffsetIdx; 238 bool IsPrePost; 239 240 switch (MI.getOpcode()) { 241 default: 242 return std::nullopt; 243 244 case AArch64::LD1i64: 245 case AArch64::LD2i64: 246 DestRegIdx = 
0; 247 BaseRegIdx = 3; 248 OffsetIdx = -1; 249 IsPrePost = false; 250 break; 251 252 case AArch64::LD1i8: 253 case AArch64::LD1i16: 254 case AArch64::LD1i32: 255 case AArch64::LD2i8: 256 case AArch64::LD2i16: 257 case AArch64::LD2i32: 258 case AArch64::LD3i8: 259 case AArch64::LD3i16: 260 case AArch64::LD3i32: 261 case AArch64::LD3i64: 262 case AArch64::LD4i8: 263 case AArch64::LD4i16: 264 case AArch64::LD4i32: 265 case AArch64::LD4i64: 266 DestRegIdx = -1; 267 BaseRegIdx = 3; 268 OffsetIdx = -1; 269 IsPrePost = false; 270 break; 271 272 case AArch64::LD1Onev1d: 273 case AArch64::LD1Onev2s: 274 case AArch64::LD1Onev4h: 275 case AArch64::LD1Onev8b: 276 case AArch64::LD1Onev2d: 277 case AArch64::LD1Onev4s: 278 case AArch64::LD1Onev8h: 279 case AArch64::LD1Onev16b: 280 case AArch64::LD1Rv1d: 281 case AArch64::LD1Rv2s: 282 case AArch64::LD1Rv4h: 283 case AArch64::LD1Rv8b: 284 case AArch64::LD1Rv2d: 285 case AArch64::LD1Rv4s: 286 case AArch64::LD1Rv8h: 287 case AArch64::LD1Rv16b: 288 DestRegIdx = 0; 289 BaseRegIdx = 1; 290 OffsetIdx = -1; 291 IsPrePost = false; 292 break; 293 294 case AArch64::LD1Twov1d: 295 case AArch64::LD1Twov2s: 296 case AArch64::LD1Twov4h: 297 case AArch64::LD1Twov8b: 298 case AArch64::LD1Twov2d: 299 case AArch64::LD1Twov4s: 300 case AArch64::LD1Twov8h: 301 case AArch64::LD1Twov16b: 302 case AArch64::LD1Threev1d: 303 case AArch64::LD1Threev2s: 304 case AArch64::LD1Threev4h: 305 case AArch64::LD1Threev8b: 306 case AArch64::LD1Threev2d: 307 case AArch64::LD1Threev4s: 308 case AArch64::LD1Threev8h: 309 case AArch64::LD1Threev16b: 310 case AArch64::LD1Fourv1d: 311 case AArch64::LD1Fourv2s: 312 case AArch64::LD1Fourv4h: 313 case AArch64::LD1Fourv8b: 314 case AArch64::LD1Fourv2d: 315 case AArch64::LD1Fourv4s: 316 case AArch64::LD1Fourv8h: 317 case AArch64::LD1Fourv16b: 318 case AArch64::LD2Twov2s: 319 case AArch64::LD2Twov4s: 320 case AArch64::LD2Twov8b: 321 case AArch64::LD2Twov2d: 322 case AArch64::LD2Twov4h: 323 case AArch64::LD2Twov8h: 324 case 
AArch64::LD2Twov16b: 325 case AArch64::LD2Rv1d: 326 case AArch64::LD2Rv2s: 327 case AArch64::LD2Rv4s: 328 case AArch64::LD2Rv8b: 329 case AArch64::LD2Rv2d: 330 case AArch64::LD2Rv4h: 331 case AArch64::LD2Rv8h: 332 case AArch64::LD2Rv16b: 333 case AArch64::LD3Threev2s: 334 case AArch64::LD3Threev4h: 335 case AArch64::LD3Threev8b: 336 case AArch64::LD3Threev2d: 337 case AArch64::LD3Threev4s: 338 case AArch64::LD3Threev8h: 339 case AArch64::LD3Threev16b: 340 case AArch64::LD3Rv1d: 341 case AArch64::LD3Rv2s: 342 case AArch64::LD3Rv4h: 343 case AArch64::LD3Rv8b: 344 case AArch64::LD3Rv2d: 345 case AArch64::LD3Rv4s: 346 case AArch64::LD3Rv8h: 347 case AArch64::LD3Rv16b: 348 case AArch64::LD4Fourv2s: 349 case AArch64::LD4Fourv4h: 350 case AArch64::LD4Fourv8b: 351 case AArch64::LD4Fourv2d: 352 case AArch64::LD4Fourv4s: 353 case AArch64::LD4Fourv8h: 354 case AArch64::LD4Fourv16b: 355 case AArch64::LD4Rv1d: 356 case AArch64::LD4Rv2s: 357 case AArch64::LD4Rv4h: 358 case AArch64::LD4Rv8b: 359 case AArch64::LD4Rv2d: 360 case AArch64::LD4Rv4s: 361 case AArch64::LD4Rv8h: 362 case AArch64::LD4Rv16b: 363 DestRegIdx = -1; 364 BaseRegIdx = 1; 365 OffsetIdx = -1; 366 IsPrePost = false; 367 break; 368 369 case AArch64::LD1i64_POST: 370 case AArch64::LD2i64_POST: 371 DestRegIdx = 1; 372 BaseRegIdx = 4; 373 OffsetIdx = 5; 374 IsPrePost = true; 375 break; 376 377 case AArch64::LD1i8_POST: 378 case AArch64::LD1i16_POST: 379 case AArch64::LD1i32_POST: 380 case AArch64::LD2i8_POST: 381 case AArch64::LD2i16_POST: 382 case AArch64::LD2i32_POST: 383 case AArch64::LD3i8_POST: 384 case AArch64::LD3i16_POST: 385 case AArch64::LD3i32_POST: 386 case AArch64::LD3i64_POST: 387 case AArch64::LD4i8_POST: 388 case AArch64::LD4i16_POST: 389 case AArch64::LD4i32_POST: 390 case AArch64::LD4i64_POST: 391 DestRegIdx = -1; 392 BaseRegIdx = 4; 393 OffsetIdx = 5; 394 IsPrePost = true; 395 break; 396 397 case AArch64::LD1Onev1d_POST: 398 case AArch64::LD1Onev2s_POST: 399 case AArch64::LD1Onev4h_POST: 400 case 
AArch64::LD1Onev8b_POST: 401 case AArch64::LD1Onev2d_POST: 402 case AArch64::LD1Onev4s_POST: 403 case AArch64::LD1Onev8h_POST: 404 case AArch64::LD1Onev16b_POST: 405 case AArch64::LD1Rv1d_POST: 406 case AArch64::LD1Rv2s_POST: 407 case AArch64::LD1Rv4h_POST: 408 case AArch64::LD1Rv8b_POST: 409 case AArch64::LD1Rv2d_POST: 410 case AArch64::LD1Rv4s_POST: 411 case AArch64::LD1Rv8h_POST: 412 case AArch64::LD1Rv16b_POST: 413 DestRegIdx = 1; 414 BaseRegIdx = 2; 415 OffsetIdx = 3; 416 IsPrePost = true; 417 break; 418 419 case AArch64::LD1Twov1d_POST: 420 case AArch64::LD1Twov2s_POST: 421 case AArch64::LD1Twov4h_POST: 422 case AArch64::LD1Twov8b_POST: 423 case AArch64::LD1Twov2d_POST: 424 case AArch64::LD1Twov4s_POST: 425 case AArch64::LD1Twov8h_POST: 426 case AArch64::LD1Twov16b_POST: 427 case AArch64::LD1Threev1d_POST: 428 case AArch64::LD1Threev2s_POST: 429 case AArch64::LD1Threev4h_POST: 430 case AArch64::LD1Threev8b_POST: 431 case AArch64::LD1Threev2d_POST: 432 case AArch64::LD1Threev4s_POST: 433 case AArch64::LD1Threev8h_POST: 434 case AArch64::LD1Threev16b_POST: 435 case AArch64::LD1Fourv1d_POST: 436 case AArch64::LD1Fourv2s_POST: 437 case AArch64::LD1Fourv4h_POST: 438 case AArch64::LD1Fourv8b_POST: 439 case AArch64::LD1Fourv2d_POST: 440 case AArch64::LD1Fourv4s_POST: 441 case AArch64::LD1Fourv8h_POST: 442 case AArch64::LD1Fourv16b_POST: 443 case AArch64::LD2Twov2s_POST: 444 case AArch64::LD2Twov4s_POST: 445 case AArch64::LD2Twov8b_POST: 446 case AArch64::LD2Twov2d_POST: 447 case AArch64::LD2Twov4h_POST: 448 case AArch64::LD2Twov8h_POST: 449 case AArch64::LD2Twov16b_POST: 450 case AArch64::LD2Rv1d_POST: 451 case AArch64::LD2Rv2s_POST: 452 case AArch64::LD2Rv4s_POST: 453 case AArch64::LD2Rv8b_POST: 454 case AArch64::LD2Rv2d_POST: 455 case AArch64::LD2Rv4h_POST: 456 case AArch64::LD2Rv8h_POST: 457 case AArch64::LD2Rv16b_POST: 458 case AArch64::LD3Threev2s_POST: 459 case AArch64::LD3Threev4h_POST: 460 case AArch64::LD3Threev8b_POST: 461 case AArch64::LD3Threev2d_POST: 
462 case AArch64::LD3Threev4s_POST: 463 case AArch64::LD3Threev8h_POST: 464 case AArch64::LD3Threev16b_POST: 465 case AArch64::LD3Rv1d_POST: 466 case AArch64::LD3Rv2s_POST: 467 case AArch64::LD3Rv4h_POST: 468 case AArch64::LD3Rv8b_POST: 469 case AArch64::LD3Rv2d_POST: 470 case AArch64::LD3Rv4s_POST: 471 case AArch64::LD3Rv8h_POST: 472 case AArch64::LD3Rv16b_POST: 473 case AArch64::LD4Fourv2s_POST: 474 case AArch64::LD4Fourv4h_POST: 475 case AArch64::LD4Fourv8b_POST: 476 case AArch64::LD4Fourv2d_POST: 477 case AArch64::LD4Fourv4s_POST: 478 case AArch64::LD4Fourv8h_POST: 479 case AArch64::LD4Fourv16b_POST: 480 case AArch64::LD4Rv1d_POST: 481 case AArch64::LD4Rv2s_POST: 482 case AArch64::LD4Rv4h_POST: 483 case AArch64::LD4Rv8b_POST: 484 case AArch64::LD4Rv2d_POST: 485 case AArch64::LD4Rv4s_POST: 486 case AArch64::LD4Rv8h_POST: 487 case AArch64::LD4Rv16b_POST: 488 DestRegIdx = -1; 489 BaseRegIdx = 2; 490 OffsetIdx = 3; 491 IsPrePost = true; 492 break; 493 494 case AArch64::LDRBBroW: 495 case AArch64::LDRBBroX: 496 case AArch64::LDRBBui: 497 case AArch64::LDRBroW: 498 case AArch64::LDRBroX: 499 case AArch64::LDRBui: 500 case AArch64::LDRDl: 501 case AArch64::LDRDroW: 502 case AArch64::LDRDroX: 503 case AArch64::LDRDui: 504 case AArch64::LDRHHroW: 505 case AArch64::LDRHHroX: 506 case AArch64::LDRHHui: 507 case AArch64::LDRHroW: 508 case AArch64::LDRHroX: 509 case AArch64::LDRHui: 510 case AArch64::LDRQl: 511 case AArch64::LDRQroW: 512 case AArch64::LDRQroX: 513 case AArch64::LDRQui: 514 case AArch64::LDRSBWroW: 515 case AArch64::LDRSBWroX: 516 case AArch64::LDRSBWui: 517 case AArch64::LDRSBXroW: 518 case AArch64::LDRSBXroX: 519 case AArch64::LDRSBXui: 520 case AArch64::LDRSHWroW: 521 case AArch64::LDRSHWroX: 522 case AArch64::LDRSHWui: 523 case AArch64::LDRSHXroW: 524 case AArch64::LDRSHXroX: 525 case AArch64::LDRSHXui: 526 case AArch64::LDRSWl: 527 case AArch64::LDRSWroW: 528 case AArch64::LDRSWroX: 529 case AArch64::LDRSWui: 530 case AArch64::LDRSl: 531 case 
AArch64::LDRSroW: 532 case AArch64::LDRSroX: 533 case AArch64::LDRSui: 534 case AArch64::LDRWl: 535 case AArch64::LDRWroW: 536 case AArch64::LDRWroX: 537 case AArch64::LDRWui: 538 case AArch64::LDRXl: 539 case AArch64::LDRXroW: 540 case AArch64::LDRXroX: 541 case AArch64::LDRXui: 542 case AArch64::LDURBBi: 543 case AArch64::LDURBi: 544 case AArch64::LDURDi: 545 case AArch64::LDURHHi: 546 case AArch64::LDURHi: 547 case AArch64::LDURQi: 548 case AArch64::LDURSBWi: 549 case AArch64::LDURSBXi: 550 case AArch64::LDURSHWi: 551 case AArch64::LDURSHXi: 552 case AArch64::LDURSWi: 553 case AArch64::LDURSi: 554 case AArch64::LDURWi: 555 case AArch64::LDURXi: 556 DestRegIdx = 0; 557 BaseRegIdx = 1; 558 OffsetIdx = 2; 559 IsPrePost = false; 560 break; 561 562 case AArch64::LDRBBpost: 563 case AArch64::LDRBBpre: 564 case AArch64::LDRBpost: 565 case AArch64::LDRBpre: 566 case AArch64::LDRDpost: 567 case AArch64::LDRDpre: 568 case AArch64::LDRHHpost: 569 case AArch64::LDRHHpre: 570 case AArch64::LDRHpost: 571 case AArch64::LDRHpre: 572 case AArch64::LDRQpost: 573 case AArch64::LDRQpre: 574 case AArch64::LDRSBWpost: 575 case AArch64::LDRSBWpre: 576 case AArch64::LDRSBXpost: 577 case AArch64::LDRSBXpre: 578 case AArch64::LDRSHWpost: 579 case AArch64::LDRSHWpre: 580 case AArch64::LDRSHXpost: 581 case AArch64::LDRSHXpre: 582 case AArch64::LDRSWpost: 583 case AArch64::LDRSWpre: 584 case AArch64::LDRSpost: 585 case AArch64::LDRSpre: 586 case AArch64::LDRWpost: 587 case AArch64::LDRWpre: 588 case AArch64::LDRXpost: 589 case AArch64::LDRXpre: 590 DestRegIdx = 1; 591 BaseRegIdx = 2; 592 OffsetIdx = 3; 593 IsPrePost = true; 594 break; 595 596 case AArch64::LDNPDi: 597 case AArch64::LDNPQi: 598 case AArch64::LDNPSi: 599 case AArch64::LDPQi: 600 case AArch64::LDPDi: 601 case AArch64::LDPSi: 602 DestRegIdx = -1; 603 BaseRegIdx = 2; 604 OffsetIdx = 3; 605 IsPrePost = false; 606 break; 607 608 case AArch64::LDPSWi: 609 case AArch64::LDPWi: 610 case AArch64::LDPXi: 611 DestRegIdx = 0; 612 
BaseRegIdx = 2; 613 OffsetIdx = 3; 614 IsPrePost = false; 615 break; 616 617 case AArch64::LDPQpost: 618 case AArch64::LDPQpre: 619 case AArch64::LDPDpost: 620 case AArch64::LDPDpre: 621 case AArch64::LDPSpost: 622 case AArch64::LDPSpre: 623 DestRegIdx = -1; 624 BaseRegIdx = 3; 625 OffsetIdx = 4; 626 IsPrePost = true; 627 break; 628 629 case AArch64::LDPSWpost: 630 case AArch64::LDPSWpre: 631 case AArch64::LDPWpost: 632 case AArch64::LDPWpre: 633 case AArch64::LDPXpost: 634 case AArch64::LDPXpre: 635 DestRegIdx = 1; 636 BaseRegIdx = 3; 637 OffsetIdx = 4; 638 IsPrePost = true; 639 break; 640 } 641 642 // Loads from the stack pointer don't get prefetched. 643 Register BaseReg = MI.getOperand(BaseRegIdx).getReg(); 644 if (BaseReg == AArch64::SP || BaseReg == AArch64::WSP) 645 return std::nullopt; 646 647 LoadInfo LI; 648 LI.DestReg = DestRegIdx == -1 ? Register() : MI.getOperand(DestRegIdx).getReg(); 649 LI.BaseReg = BaseReg; 650 LI.BaseRegIdx = BaseRegIdx; 651 LI.OffsetOpnd = OffsetIdx == -1 ? nullptr : &MI.getOperand(OffsetIdx); 652 LI.IsPrePost = IsPrePost; 653 return LI; 654 } 655 656 static std::optional<unsigned> getTag(const TargetRegisterInfo *TRI, 657 const MachineInstr &MI, 658 const LoadInfo &LI) { 659 unsigned Dest = LI.DestReg ? TRI->getEncodingValue(LI.DestReg) : 0; 660 unsigned Base = TRI->getEncodingValue(LI.BaseReg); 661 unsigned Off; 662 if (LI.OffsetOpnd == nullptr) 663 Off = 0; 664 else if (LI.OffsetOpnd->isGlobal() || LI.OffsetOpnd->isSymbol() || 665 LI.OffsetOpnd->isCPI()) 666 return std::nullopt; 667 else if (LI.OffsetOpnd->isReg()) 668 Off = (1 << 5) | TRI->getEncodingValue(LI.OffsetOpnd->getReg()); 669 else 670 Off = LI.OffsetOpnd->getImm() >> 2; 671 672 return makeTag(Dest, Base, Off); 673 } 674 675 void FalkorHWPFFix::runOnLoop(MachineLoop &L, MachineFunction &Fn) { 676 // Build the initial tag map for the whole loop. 
677 TagMap.clear(); 678 for (MachineBasicBlock *MBB : L.getBlocks()) 679 for (MachineInstr &MI : *MBB) { 680 std::optional<LoadInfo> LInfo = getLoadInfo(MI); 681 if (!LInfo) 682 continue; 683 std::optional<unsigned> Tag = getTag(TRI, MI, *LInfo); 684 if (!Tag) 685 continue; 686 TagMap[*Tag].push_back(&MI); 687 } 688 689 bool AnyCollisions = false; 690 for (auto &P : TagMap) { 691 auto Size = P.second.size(); 692 if (Size > 1) { 693 for (auto *MI : P.second) { 694 if (TII->isStridedAccess(*MI)) { 695 AnyCollisions = true; 696 break; 697 } 698 } 699 } 700 if (AnyCollisions) 701 break; 702 } 703 // Nothing to fix. 704 if (!AnyCollisions) 705 return; 706 707 MachineRegisterInfo &MRI = Fn.getRegInfo(); 708 709 // Go through all the basic blocks in the current loop and fix any streaming 710 // loads to avoid collisions with any other loads. 711 LiveRegUnits LR(*TRI); 712 for (MachineBasicBlock *MBB : L.getBlocks()) { 713 LR.clear(); 714 LR.addLiveOuts(*MBB); 715 for (auto I = MBB->rbegin(); I != MBB->rend(); LR.stepBackward(*I), ++I) { 716 MachineInstr &MI = *I; 717 if (!TII->isStridedAccess(MI)) 718 continue; 719 720 std::optional<LoadInfo> OptLdI = getLoadInfo(MI); 721 if (!OptLdI) 722 continue; 723 LoadInfo LdI = *OptLdI; 724 std::optional<unsigned> OptOldTag = getTag(TRI, MI, LdI); 725 if (!OptOldTag) 726 continue; 727 auto &OldCollisions = TagMap[*OptOldTag]; 728 if (OldCollisions.size() <= 1) 729 continue; 730 731 bool Fixed = false; 732 LLVM_DEBUG(dbgs() << "Attempting to fix tag collision: " << MI); 733 734 if (!DebugCounter::shouldExecute(FixCounter)) { 735 LLVM_DEBUG(dbgs() << "Skipping fix due to debug counter:\n " << MI); 736 continue; 737 } 738 739 // Add the non-base registers of MI as live so we don't use them as 740 // scratch registers. 
741 for (unsigned OpI = 0, OpE = MI.getNumOperands(); OpI < OpE; ++OpI) { 742 if (OpI == static_cast<unsigned>(LdI.BaseRegIdx)) 743 continue; 744 MachineOperand &MO = MI.getOperand(OpI); 745 if (MO.isReg() && MO.readsReg()) 746 LR.addReg(MO.getReg()); 747 } 748 749 for (unsigned ScratchReg : AArch64::GPR64RegClass) { 750 if (!LR.available(ScratchReg) || MRI.isReserved(ScratchReg)) 751 continue; 752 753 LoadInfo NewLdI(LdI); 754 NewLdI.BaseReg = ScratchReg; 755 unsigned NewTag = *getTag(TRI, MI, NewLdI); 756 // Scratch reg tag would collide too, so don't use it. 757 if (TagMap.count(NewTag)) 758 continue; 759 760 LLVM_DEBUG(dbgs() << "Changing base reg to: " 761 << printReg(ScratchReg, TRI) << '\n'); 762 763 // Rewrite: 764 // Xd = LOAD Xb, off 765 // to: 766 // Xc = MOV Xb 767 // Xd = LOAD Xc, off 768 DebugLoc DL = MI.getDebugLoc(); 769 BuildMI(*MBB, &MI, DL, TII->get(AArch64::ORRXrs), ScratchReg) 770 .addReg(AArch64::XZR) 771 .addReg(LdI.BaseReg) 772 .addImm(0); 773 MachineOperand &BaseOpnd = MI.getOperand(LdI.BaseRegIdx); 774 BaseOpnd.setReg(ScratchReg); 775 776 // If the load does a pre/post increment, then insert a MOV after as 777 // well to update the real base register. 778 if (LdI.IsPrePost) { 779 LLVM_DEBUG(dbgs() << "Doing post MOV of incremented reg: " 780 << printReg(ScratchReg, TRI) << '\n'); 781 MI.getOperand(0).setReg( 782 ScratchReg); // Change tied operand pre/post update dest. 783 BuildMI(*MBB, std::next(MachineBasicBlock::iterator(MI)), DL, 784 TII->get(AArch64::ORRXrs), LdI.BaseReg) 785 .addReg(AArch64::XZR) 786 .addReg(ScratchReg) 787 .addImm(0); 788 } 789 790 for (int I = 0, E = OldCollisions.size(); I != E; ++I) 791 if (OldCollisions[I] == &MI) { 792 std::swap(OldCollisions[I], OldCollisions[E - 1]); 793 OldCollisions.pop_back(); 794 break; 795 } 796 797 // Update TagMap to reflect instruction changes to reduce the number 798 // of later MOVs to be inserted. 
This needs to be done after 799 // OldCollisions is updated since it may be relocated by this 800 // insertion. 801 TagMap[NewTag].push_back(&MI); 802 ++NumCollisionsAvoided; 803 Fixed = true; 804 Modified = true; 805 break; 806 } 807 if (!Fixed) 808 ++NumCollisionsNotAvoided; 809 } 810 } 811 } 812 813 bool FalkorHWPFFix::runOnMachineFunction(MachineFunction &Fn) { 814 auto &ST = Fn.getSubtarget<AArch64Subtarget>(); 815 if (ST.getProcFamily() != AArch64Subtarget::Falkor) 816 return false; 817 818 if (skipFunction(Fn.getFunction())) 819 return false; 820 821 TII = static_cast<const AArch64InstrInfo *>(ST.getInstrInfo()); 822 TRI = ST.getRegisterInfo(); 823 824 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>(); 825 826 Modified = false; 827 828 for (MachineLoop *I : LI) 829 for (MachineLoop *L : depth_first(I)) 830 // Only process inner-loops 831 if (L->isInnermost()) 832 runOnLoop(*L, Fn); 833 834 return Modified; 835 } 836 837 FunctionPass *llvm::createFalkorHWPFFixPass() { return new FalkorHWPFFix(); } 838