1 //===- AArch64FalkorHWPFFix.cpp - Avoid HW prefetcher pitfalls on Falkor --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
/// \file For Falkor, we want to avoid HW prefetcher instruction tag collisions
/// that may inhibit the HW prefetching. This is done in two steps. Before
/// ISel, we mark strided loads (i.e. those that will likely benefit from
/// prefetching) with metadata. Then, after opcodes have been finalized, we
/// insert MOVs and re-write loads to prevent unintentional tag collisions.
//===----------------------------------------------------------------------===//
14
15 #include "AArch64.h"
16 #include "AArch64InstrInfo.h"
17 #include "AArch64Subtarget.h"
18 #include "AArch64TargetMachine.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
26 #include "llvm/CodeGen/LiveRegUnits.h"
27 #include "llvm/CodeGen/MachineBasicBlock.h"
28 #include "llvm/CodeGen/MachineFunction.h"
29 #include "llvm/CodeGen/MachineFunctionPass.h"
30 #include "llvm/CodeGen/MachineInstr.h"
31 #include "llvm/CodeGen/MachineInstrBuilder.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineOperand.h"
34 #include "llvm/CodeGen/MachineRegisterInfo.h"
35 #include "llvm/CodeGen/TargetPassConfig.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 #include "llvm/IR/DebugLoc.h"
38 #include "llvm/IR/Dominators.h"
39 #include "llvm/IR/Function.h"
40 #include "llvm/IR/Instruction.h"
41 #include "llvm/IR/Instructions.h"
42 #include "llvm/IR/Metadata.h"
43 #include "llvm/InitializePasses.h"
44 #include "llvm/Pass.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/DebugCounter.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include <iterator>
50 #include <utility>
51
52 using namespace llvm;
53
54 #define DEBUG_TYPE "aarch64-falkor-hwpf-fix"
55
56 STATISTIC(NumStridedLoadsMarked, "Number of strided loads marked");
57 STATISTIC(NumCollisionsAvoided,
58 "Number of HW prefetch tag collisions avoided");
59 STATISTIC(NumCollisionsNotAvoided,
60 "Number of HW prefetch tag collisions not avoided due to lack of registers");
61 DEBUG_COUNTER(FixCounter, "falkor-hwpf",
62 "Controls which tag collisions are avoided");
63
64 namespace {
65
66 class FalkorMarkStridedAccesses {
67 public:
FalkorMarkStridedAccesses(LoopInfo & LI,ScalarEvolution & SE)68 FalkorMarkStridedAccesses(LoopInfo &LI, ScalarEvolution &SE)
69 : LI(LI), SE(SE) {}
70
71 bool run();
72
73 private:
74 bool runOnLoop(Loop &L);
75
76 LoopInfo &LI;
77 ScalarEvolution &SE;
78 };
79
80 class FalkorMarkStridedAccessesLegacy : public FunctionPass {
81 public:
82 static char ID; // Pass ID, replacement for typeid
83
FalkorMarkStridedAccessesLegacy()84 FalkorMarkStridedAccessesLegacy() : FunctionPass(ID) {
85 initializeFalkorMarkStridedAccessesLegacyPass(
86 *PassRegistry::getPassRegistry());
87 }
88
getAnalysisUsage(AnalysisUsage & AU) const89 void getAnalysisUsage(AnalysisUsage &AU) const override {
90 AU.addRequired<TargetPassConfig>();
91 AU.addPreserved<DominatorTreeWrapperPass>();
92 AU.addRequired<LoopInfoWrapperPass>();
93 AU.addPreserved<LoopInfoWrapperPass>();
94 AU.addRequired<ScalarEvolutionWrapperPass>();
95 AU.addPreserved<ScalarEvolutionWrapperPass>();
96 }
97
98 bool runOnFunction(Function &F) override;
99 };
100
101 } // end anonymous namespace
102
103 char FalkorMarkStridedAccessesLegacy::ID = 0;
104
105 INITIALIZE_PASS_BEGIN(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
106 "Falkor HW Prefetch Fix", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)107 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
108 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
109 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
110 INITIALIZE_PASS_END(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
111 "Falkor HW Prefetch Fix", false, false)
112
113 FunctionPass *llvm::createFalkorMarkStridedAccessesPass() {
114 return new FalkorMarkStridedAccessesLegacy();
115 }
116
runOnFunction(Function & F)117 bool FalkorMarkStridedAccessesLegacy::runOnFunction(Function &F) {
118 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
119 const AArch64Subtarget *ST =
120 TPC.getTM<AArch64TargetMachine>().getSubtargetImpl(F);
121 if (ST->getProcFamily() != AArch64Subtarget::Falkor)
122 return false;
123
124 if (skipFunction(F))
125 return false;
126
127 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
128 ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
129
130 FalkorMarkStridedAccesses LDP(LI, SE);
131 return LDP.run();
132 }
133
run()134 bool FalkorMarkStridedAccesses::run() {
135 bool MadeChange = false;
136
137 for (Loop *L : LI)
138 for (Loop *LIt : depth_first(L))
139 MadeChange |= runOnLoop(*LIt);
140
141 return MadeChange;
142 }
143
runOnLoop(Loop & L)144 bool FalkorMarkStridedAccesses::runOnLoop(Loop &L) {
145 // Only mark strided loads in the inner-most loop
146 if (!L.isInnermost())
147 return false;
148
149 bool MadeChange = false;
150
151 for (BasicBlock *BB : L.blocks()) {
152 for (Instruction &I : *BB) {
153 LoadInst *LoadI = dyn_cast<LoadInst>(&I);
154 if (!LoadI)
155 continue;
156
157 Value *PtrValue = LoadI->getPointerOperand();
158 if (L.isLoopInvariant(PtrValue))
159 continue;
160
161 const SCEV *LSCEV = SE.getSCEV(PtrValue);
162 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
163 if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
164 continue;
165
166 LoadI->setMetadata(FALKOR_STRIDED_ACCESS_MD,
167 MDNode::get(LoadI->getContext(), {}));
168 ++NumStridedLoadsMarked;
169 LLVM_DEBUG(dbgs() << "Load: " << I << " marked as strided\n");
170 MadeChange = true;
171 }
172 }
173
174 return MadeChange;
175 }
176
177 namespace {
178
179 class FalkorHWPFFix : public MachineFunctionPass {
180 public:
181 static char ID;
182
FalkorHWPFFix()183 FalkorHWPFFix() : MachineFunctionPass(ID) {
184 initializeFalkorHWPFFixPass(*PassRegistry::getPassRegistry());
185 }
186
187 bool runOnMachineFunction(MachineFunction &Fn) override;
188
getAnalysisUsage(AnalysisUsage & AU) const189 void getAnalysisUsage(AnalysisUsage &AU) const override {
190 AU.setPreservesCFG();
191 AU.addRequired<MachineLoopInfoWrapperPass>();
192 MachineFunctionPass::getAnalysisUsage(AU);
193 }
194
getRequiredProperties() const195 MachineFunctionProperties getRequiredProperties() const override {
196 return MachineFunctionProperties().set(
197 MachineFunctionProperties::Property::NoVRegs);
198 }
199
200 private:
201 void runOnLoop(MachineLoop &L, MachineFunction &Fn);
202
203 const AArch64InstrInfo *TII;
204 const TargetRegisterInfo *TRI;
205 DenseMap<unsigned, SmallVector<MachineInstr *, 4>> TagMap;
206 bool Modified;
207 };
208
209 /// Bits from load opcodes used to compute HW prefetcher instruction tags.
210 struct LoadInfo {
211 LoadInfo() = default;
212
213 Register DestReg;
214 Register BaseReg;
215 int BaseRegIdx = -1;
216 const MachineOperand *OffsetOpnd = nullptr;
217 bool IsPrePost = false;
218 };
219
220 } // end anonymous namespace
221
222 char FalkorHWPFFix::ID = 0;
223
224 INITIALIZE_PASS_BEGIN(FalkorHWPFFix, "aarch64-falkor-hwpf-fix-late",
225 "Falkor HW Prefetch Fix Late Phase", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)226 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
227 INITIALIZE_PASS_END(FalkorHWPFFix, "aarch64-falkor-hwpf-fix-late",
228 "Falkor HW Prefetch Fix Late Phase", false, false)
229
/// Pack the tag fields into one value.
/// Bits [3:0] hold the low 4 bits of \p Dest, bits [7:4] the low 4 bits of
/// \p Base, and bits [13:8] the low 6 bits of \p Offset.
static unsigned makeTag(unsigned Dest, unsigned Base, unsigned Offset) {
  const unsigned DestField = Dest & 0xfu;
  const unsigned BaseField = (Base & 0xfu) << 4;
  const unsigned OffsetField = (Offset & 0x3fu) << 8;
  return OffsetField | BaseField | DestField;
}
233
getLoadInfo(const MachineInstr & MI)234 static std::optional<LoadInfo> getLoadInfo(const MachineInstr &MI) {
235 int DestRegIdx;
236 int BaseRegIdx;
237 int OffsetIdx;
238 bool IsPrePost;
239
240 switch (MI.getOpcode()) {
241 default:
242 return std::nullopt;
243
244 case AArch64::LD1i64:
245 case AArch64::LD2i64:
246 DestRegIdx = 0;
247 BaseRegIdx = 3;
248 OffsetIdx = -1;
249 IsPrePost = false;
250 break;
251
252 case AArch64::LD1i8:
253 case AArch64::LD1i16:
254 case AArch64::LD1i32:
255 case AArch64::LD2i8:
256 case AArch64::LD2i16:
257 case AArch64::LD2i32:
258 case AArch64::LD3i8:
259 case AArch64::LD3i16:
260 case AArch64::LD3i32:
261 case AArch64::LD3i64:
262 case AArch64::LD4i8:
263 case AArch64::LD4i16:
264 case AArch64::LD4i32:
265 case AArch64::LD4i64:
266 DestRegIdx = -1;
267 BaseRegIdx = 3;
268 OffsetIdx = -1;
269 IsPrePost = false;
270 break;
271
272 case AArch64::LD1Onev1d:
273 case AArch64::LD1Onev2s:
274 case AArch64::LD1Onev4h:
275 case AArch64::LD1Onev8b:
276 case AArch64::LD1Onev2d:
277 case AArch64::LD1Onev4s:
278 case AArch64::LD1Onev8h:
279 case AArch64::LD1Onev16b:
280 case AArch64::LD1Rv1d:
281 case AArch64::LD1Rv2s:
282 case AArch64::LD1Rv4h:
283 case AArch64::LD1Rv8b:
284 case AArch64::LD1Rv2d:
285 case AArch64::LD1Rv4s:
286 case AArch64::LD1Rv8h:
287 case AArch64::LD1Rv16b:
288 DestRegIdx = 0;
289 BaseRegIdx = 1;
290 OffsetIdx = -1;
291 IsPrePost = false;
292 break;
293
294 case AArch64::LD1Twov1d:
295 case AArch64::LD1Twov2s:
296 case AArch64::LD1Twov4h:
297 case AArch64::LD1Twov8b:
298 case AArch64::LD1Twov2d:
299 case AArch64::LD1Twov4s:
300 case AArch64::LD1Twov8h:
301 case AArch64::LD1Twov16b:
302 case AArch64::LD1Threev1d:
303 case AArch64::LD1Threev2s:
304 case AArch64::LD1Threev4h:
305 case AArch64::LD1Threev8b:
306 case AArch64::LD1Threev2d:
307 case AArch64::LD1Threev4s:
308 case AArch64::LD1Threev8h:
309 case AArch64::LD1Threev16b:
310 case AArch64::LD1Fourv1d:
311 case AArch64::LD1Fourv2s:
312 case AArch64::LD1Fourv4h:
313 case AArch64::LD1Fourv8b:
314 case AArch64::LD1Fourv2d:
315 case AArch64::LD1Fourv4s:
316 case AArch64::LD1Fourv8h:
317 case AArch64::LD1Fourv16b:
318 case AArch64::LD2Twov2s:
319 case AArch64::LD2Twov4s:
320 case AArch64::LD2Twov8b:
321 case AArch64::LD2Twov2d:
322 case AArch64::LD2Twov4h:
323 case AArch64::LD2Twov8h:
324 case AArch64::LD2Twov16b:
325 case AArch64::LD2Rv1d:
326 case AArch64::LD2Rv2s:
327 case AArch64::LD2Rv4s:
328 case AArch64::LD2Rv8b:
329 case AArch64::LD2Rv2d:
330 case AArch64::LD2Rv4h:
331 case AArch64::LD2Rv8h:
332 case AArch64::LD2Rv16b:
333 case AArch64::LD3Threev2s:
334 case AArch64::LD3Threev4h:
335 case AArch64::LD3Threev8b:
336 case AArch64::LD3Threev2d:
337 case AArch64::LD3Threev4s:
338 case AArch64::LD3Threev8h:
339 case AArch64::LD3Threev16b:
340 case AArch64::LD3Rv1d:
341 case AArch64::LD3Rv2s:
342 case AArch64::LD3Rv4h:
343 case AArch64::LD3Rv8b:
344 case AArch64::LD3Rv2d:
345 case AArch64::LD3Rv4s:
346 case AArch64::LD3Rv8h:
347 case AArch64::LD3Rv16b:
348 case AArch64::LD4Fourv2s:
349 case AArch64::LD4Fourv4h:
350 case AArch64::LD4Fourv8b:
351 case AArch64::LD4Fourv2d:
352 case AArch64::LD4Fourv4s:
353 case AArch64::LD4Fourv8h:
354 case AArch64::LD4Fourv16b:
355 case AArch64::LD4Rv1d:
356 case AArch64::LD4Rv2s:
357 case AArch64::LD4Rv4h:
358 case AArch64::LD4Rv8b:
359 case AArch64::LD4Rv2d:
360 case AArch64::LD4Rv4s:
361 case AArch64::LD4Rv8h:
362 case AArch64::LD4Rv16b:
363 DestRegIdx = -1;
364 BaseRegIdx = 1;
365 OffsetIdx = -1;
366 IsPrePost = false;
367 break;
368
369 case AArch64::LD1i64_POST:
370 case AArch64::LD2i64_POST:
371 DestRegIdx = 1;
372 BaseRegIdx = 4;
373 OffsetIdx = 5;
374 IsPrePost = true;
375 break;
376
377 case AArch64::LD1i8_POST:
378 case AArch64::LD1i16_POST:
379 case AArch64::LD1i32_POST:
380 case AArch64::LD2i8_POST:
381 case AArch64::LD2i16_POST:
382 case AArch64::LD2i32_POST:
383 case AArch64::LD3i8_POST:
384 case AArch64::LD3i16_POST:
385 case AArch64::LD3i32_POST:
386 case AArch64::LD3i64_POST:
387 case AArch64::LD4i8_POST:
388 case AArch64::LD4i16_POST:
389 case AArch64::LD4i32_POST:
390 case AArch64::LD4i64_POST:
391 DestRegIdx = -1;
392 BaseRegIdx = 4;
393 OffsetIdx = 5;
394 IsPrePost = true;
395 break;
396
397 case AArch64::LD1Onev1d_POST:
398 case AArch64::LD1Onev2s_POST:
399 case AArch64::LD1Onev4h_POST:
400 case AArch64::LD1Onev8b_POST:
401 case AArch64::LD1Onev2d_POST:
402 case AArch64::LD1Onev4s_POST:
403 case AArch64::LD1Onev8h_POST:
404 case AArch64::LD1Onev16b_POST:
405 case AArch64::LD1Rv1d_POST:
406 case AArch64::LD1Rv2s_POST:
407 case AArch64::LD1Rv4h_POST:
408 case AArch64::LD1Rv8b_POST:
409 case AArch64::LD1Rv2d_POST:
410 case AArch64::LD1Rv4s_POST:
411 case AArch64::LD1Rv8h_POST:
412 case AArch64::LD1Rv16b_POST:
413 DestRegIdx = 1;
414 BaseRegIdx = 2;
415 OffsetIdx = 3;
416 IsPrePost = true;
417 break;
418
419 case AArch64::LD1Twov1d_POST:
420 case AArch64::LD1Twov2s_POST:
421 case AArch64::LD1Twov4h_POST:
422 case AArch64::LD1Twov8b_POST:
423 case AArch64::LD1Twov2d_POST:
424 case AArch64::LD1Twov4s_POST:
425 case AArch64::LD1Twov8h_POST:
426 case AArch64::LD1Twov16b_POST:
427 case AArch64::LD1Threev1d_POST:
428 case AArch64::LD1Threev2s_POST:
429 case AArch64::LD1Threev4h_POST:
430 case AArch64::LD1Threev8b_POST:
431 case AArch64::LD1Threev2d_POST:
432 case AArch64::LD1Threev4s_POST:
433 case AArch64::LD1Threev8h_POST:
434 case AArch64::LD1Threev16b_POST:
435 case AArch64::LD1Fourv1d_POST:
436 case AArch64::LD1Fourv2s_POST:
437 case AArch64::LD1Fourv4h_POST:
438 case AArch64::LD1Fourv8b_POST:
439 case AArch64::LD1Fourv2d_POST:
440 case AArch64::LD1Fourv4s_POST:
441 case AArch64::LD1Fourv8h_POST:
442 case AArch64::LD1Fourv16b_POST:
443 case AArch64::LD2Twov2s_POST:
444 case AArch64::LD2Twov4s_POST:
445 case AArch64::LD2Twov8b_POST:
446 case AArch64::LD2Twov2d_POST:
447 case AArch64::LD2Twov4h_POST:
448 case AArch64::LD2Twov8h_POST:
449 case AArch64::LD2Twov16b_POST:
450 case AArch64::LD2Rv1d_POST:
451 case AArch64::LD2Rv2s_POST:
452 case AArch64::LD2Rv4s_POST:
453 case AArch64::LD2Rv8b_POST:
454 case AArch64::LD2Rv2d_POST:
455 case AArch64::LD2Rv4h_POST:
456 case AArch64::LD2Rv8h_POST:
457 case AArch64::LD2Rv16b_POST:
458 case AArch64::LD3Threev2s_POST:
459 case AArch64::LD3Threev4h_POST:
460 case AArch64::LD3Threev8b_POST:
461 case AArch64::LD3Threev2d_POST:
462 case AArch64::LD3Threev4s_POST:
463 case AArch64::LD3Threev8h_POST:
464 case AArch64::LD3Threev16b_POST:
465 case AArch64::LD3Rv1d_POST:
466 case AArch64::LD3Rv2s_POST:
467 case AArch64::LD3Rv4h_POST:
468 case AArch64::LD3Rv8b_POST:
469 case AArch64::LD3Rv2d_POST:
470 case AArch64::LD3Rv4s_POST:
471 case AArch64::LD3Rv8h_POST:
472 case AArch64::LD3Rv16b_POST:
473 case AArch64::LD4Fourv2s_POST:
474 case AArch64::LD4Fourv4h_POST:
475 case AArch64::LD4Fourv8b_POST:
476 case AArch64::LD4Fourv2d_POST:
477 case AArch64::LD4Fourv4s_POST:
478 case AArch64::LD4Fourv8h_POST:
479 case AArch64::LD4Fourv16b_POST:
480 case AArch64::LD4Rv1d_POST:
481 case AArch64::LD4Rv2s_POST:
482 case AArch64::LD4Rv4h_POST:
483 case AArch64::LD4Rv8b_POST:
484 case AArch64::LD4Rv2d_POST:
485 case AArch64::LD4Rv4s_POST:
486 case AArch64::LD4Rv8h_POST:
487 case AArch64::LD4Rv16b_POST:
488 DestRegIdx = -1;
489 BaseRegIdx = 2;
490 OffsetIdx = 3;
491 IsPrePost = true;
492 break;
493
494 case AArch64::LDRBBroW:
495 case AArch64::LDRBBroX:
496 case AArch64::LDRBBui:
497 case AArch64::LDRBroW:
498 case AArch64::LDRBroX:
499 case AArch64::LDRBui:
500 case AArch64::LDRDl:
501 case AArch64::LDRDroW:
502 case AArch64::LDRDroX:
503 case AArch64::LDRDui:
504 case AArch64::LDRHHroW:
505 case AArch64::LDRHHroX:
506 case AArch64::LDRHHui:
507 case AArch64::LDRHroW:
508 case AArch64::LDRHroX:
509 case AArch64::LDRHui:
510 case AArch64::LDRQl:
511 case AArch64::LDRQroW:
512 case AArch64::LDRQroX:
513 case AArch64::LDRQui:
514 case AArch64::LDRSBWroW:
515 case AArch64::LDRSBWroX:
516 case AArch64::LDRSBWui:
517 case AArch64::LDRSBXroW:
518 case AArch64::LDRSBXroX:
519 case AArch64::LDRSBXui:
520 case AArch64::LDRSHWroW:
521 case AArch64::LDRSHWroX:
522 case AArch64::LDRSHWui:
523 case AArch64::LDRSHXroW:
524 case AArch64::LDRSHXroX:
525 case AArch64::LDRSHXui:
526 case AArch64::LDRSWl:
527 case AArch64::LDRSWroW:
528 case AArch64::LDRSWroX:
529 case AArch64::LDRSWui:
530 case AArch64::LDRSl:
531 case AArch64::LDRSroW:
532 case AArch64::LDRSroX:
533 case AArch64::LDRSui:
534 case AArch64::LDRWl:
535 case AArch64::LDRWroW:
536 case AArch64::LDRWroX:
537 case AArch64::LDRWui:
538 case AArch64::LDRXl:
539 case AArch64::LDRXroW:
540 case AArch64::LDRXroX:
541 case AArch64::LDRXui:
542 case AArch64::LDURBBi:
543 case AArch64::LDURBi:
544 case AArch64::LDURDi:
545 case AArch64::LDURHHi:
546 case AArch64::LDURHi:
547 case AArch64::LDURQi:
548 case AArch64::LDURSBWi:
549 case AArch64::LDURSBXi:
550 case AArch64::LDURSHWi:
551 case AArch64::LDURSHXi:
552 case AArch64::LDURSWi:
553 case AArch64::LDURSi:
554 case AArch64::LDURWi:
555 case AArch64::LDURXi:
556 DestRegIdx = 0;
557 BaseRegIdx = 1;
558 OffsetIdx = 2;
559 IsPrePost = false;
560 break;
561
562 case AArch64::LDRBBpost:
563 case AArch64::LDRBBpre:
564 case AArch64::LDRBpost:
565 case AArch64::LDRBpre:
566 case AArch64::LDRDpost:
567 case AArch64::LDRDpre:
568 case AArch64::LDRHHpost:
569 case AArch64::LDRHHpre:
570 case AArch64::LDRHpost:
571 case AArch64::LDRHpre:
572 case AArch64::LDRQpost:
573 case AArch64::LDRQpre:
574 case AArch64::LDRSBWpost:
575 case AArch64::LDRSBWpre:
576 case AArch64::LDRSBXpost:
577 case AArch64::LDRSBXpre:
578 case AArch64::LDRSHWpost:
579 case AArch64::LDRSHWpre:
580 case AArch64::LDRSHXpost:
581 case AArch64::LDRSHXpre:
582 case AArch64::LDRSWpost:
583 case AArch64::LDRSWpre:
584 case AArch64::LDRSpost:
585 case AArch64::LDRSpre:
586 case AArch64::LDRWpost:
587 case AArch64::LDRWpre:
588 case AArch64::LDRXpost:
589 case AArch64::LDRXpre:
590 DestRegIdx = 1;
591 BaseRegIdx = 2;
592 OffsetIdx = 3;
593 IsPrePost = true;
594 break;
595
596 case AArch64::LDNPDi:
597 case AArch64::LDNPQi:
598 case AArch64::LDNPSi:
599 case AArch64::LDPQi:
600 case AArch64::LDPDi:
601 case AArch64::LDPSi:
602 DestRegIdx = -1;
603 BaseRegIdx = 2;
604 OffsetIdx = 3;
605 IsPrePost = false;
606 break;
607
608 case AArch64::LDPSWi:
609 case AArch64::LDPWi:
610 case AArch64::LDPXi:
611 DestRegIdx = 0;
612 BaseRegIdx = 2;
613 OffsetIdx = 3;
614 IsPrePost = false;
615 break;
616
617 case AArch64::LDPQpost:
618 case AArch64::LDPQpre:
619 case AArch64::LDPDpost:
620 case AArch64::LDPDpre:
621 case AArch64::LDPSpost:
622 case AArch64::LDPSpre:
623 DestRegIdx = -1;
624 BaseRegIdx = 3;
625 OffsetIdx = 4;
626 IsPrePost = true;
627 break;
628
629 case AArch64::LDPSWpost:
630 case AArch64::LDPSWpre:
631 case AArch64::LDPWpost:
632 case AArch64::LDPWpre:
633 case AArch64::LDPXpost:
634 case AArch64::LDPXpre:
635 DestRegIdx = 1;
636 BaseRegIdx = 3;
637 OffsetIdx = 4;
638 IsPrePost = true;
639 break;
640 }
641
642 // Loads from the stack pointer don't get prefetched.
643 Register BaseReg = MI.getOperand(BaseRegIdx).getReg();
644 if (BaseReg == AArch64::SP || BaseReg == AArch64::WSP)
645 return std::nullopt;
646
647 LoadInfo LI;
648 LI.DestReg = DestRegIdx == -1 ? Register() : MI.getOperand(DestRegIdx).getReg();
649 LI.BaseReg = BaseReg;
650 LI.BaseRegIdx = BaseRegIdx;
651 LI.OffsetOpnd = OffsetIdx == -1 ? nullptr : &MI.getOperand(OffsetIdx);
652 LI.IsPrePost = IsPrePost;
653 return LI;
654 }
655
getTag(const TargetRegisterInfo * TRI,const MachineInstr & MI,const LoadInfo & LI)656 static std::optional<unsigned> getTag(const TargetRegisterInfo *TRI,
657 const MachineInstr &MI,
658 const LoadInfo &LI) {
659 unsigned Dest = LI.DestReg ? TRI->getEncodingValue(LI.DestReg) : 0;
660 unsigned Base = TRI->getEncodingValue(LI.BaseReg);
661 unsigned Off;
662 if (LI.OffsetOpnd == nullptr)
663 Off = 0;
664 else if (LI.OffsetOpnd->isGlobal() || LI.OffsetOpnd->isSymbol() ||
665 LI.OffsetOpnd->isCPI())
666 return std::nullopt;
667 else if (LI.OffsetOpnd->isReg())
668 Off = (1 << 5) | TRI->getEncodingValue(LI.OffsetOpnd->getReg());
669 else
670 Off = LI.OffsetOpnd->getImm() >> 2;
671
672 return makeTag(Dest, Base, Off);
673 }
674
runOnLoop(MachineLoop & L,MachineFunction & Fn)675 void FalkorHWPFFix::runOnLoop(MachineLoop &L, MachineFunction &Fn) {
676 // Build the initial tag map for the whole loop.
677 TagMap.clear();
678 for (MachineBasicBlock *MBB : L.getBlocks())
679 for (MachineInstr &MI : *MBB) {
680 std::optional<LoadInfo> LInfo = getLoadInfo(MI);
681 if (!LInfo)
682 continue;
683 std::optional<unsigned> Tag = getTag(TRI, MI, *LInfo);
684 if (!Tag)
685 continue;
686 TagMap[*Tag].push_back(&MI);
687 }
688
689 bool AnyCollisions = false;
690 for (auto &P : TagMap) {
691 auto Size = P.second.size();
692 if (Size > 1) {
693 for (auto *MI : P.second) {
694 if (TII->isStridedAccess(*MI)) {
695 AnyCollisions = true;
696 break;
697 }
698 }
699 }
700 if (AnyCollisions)
701 break;
702 }
703 // Nothing to fix.
704 if (!AnyCollisions)
705 return;
706
707 MachineRegisterInfo &MRI = Fn.getRegInfo();
708
709 // Go through all the basic blocks in the current loop and fix any streaming
710 // loads to avoid collisions with any other loads.
711 LiveRegUnits LR(*TRI);
712 for (MachineBasicBlock *MBB : L.getBlocks()) {
713 LR.clear();
714 LR.addLiveOuts(*MBB);
715 for (auto I = MBB->rbegin(); I != MBB->rend(); LR.stepBackward(*I), ++I) {
716 MachineInstr &MI = *I;
717 if (!TII->isStridedAccess(MI))
718 continue;
719
720 std::optional<LoadInfo> OptLdI = getLoadInfo(MI);
721 if (!OptLdI)
722 continue;
723 LoadInfo LdI = *OptLdI;
724 std::optional<unsigned> OptOldTag = getTag(TRI, MI, LdI);
725 if (!OptOldTag)
726 continue;
727 auto &OldCollisions = TagMap[*OptOldTag];
728 if (OldCollisions.size() <= 1)
729 continue;
730
731 bool Fixed = false;
732 LLVM_DEBUG(dbgs() << "Attempting to fix tag collision: " << MI);
733
734 if (!DebugCounter::shouldExecute(FixCounter)) {
735 LLVM_DEBUG(dbgs() << "Skipping fix due to debug counter:\n " << MI);
736 continue;
737 }
738
739 // Add the non-base registers of MI as live so we don't use them as
740 // scratch registers.
741 for (unsigned OpI = 0, OpE = MI.getNumOperands(); OpI < OpE; ++OpI) {
742 if (OpI == static_cast<unsigned>(LdI.BaseRegIdx))
743 continue;
744 MachineOperand &MO = MI.getOperand(OpI);
745 if (MO.isReg() && MO.readsReg())
746 LR.addReg(MO.getReg());
747 }
748
749 for (unsigned ScratchReg : AArch64::GPR64RegClass) {
750 if (!LR.available(ScratchReg) || MRI.isReserved(ScratchReg))
751 continue;
752
753 LoadInfo NewLdI(LdI);
754 NewLdI.BaseReg = ScratchReg;
755 unsigned NewTag = *getTag(TRI, MI, NewLdI);
756 // Scratch reg tag would collide too, so don't use it.
757 if (TagMap.count(NewTag))
758 continue;
759
760 LLVM_DEBUG(dbgs() << "Changing base reg to: "
761 << printReg(ScratchReg, TRI) << '\n');
762
763 // Rewrite:
764 // Xd = LOAD Xb, off
765 // to:
766 // Xc = MOV Xb
767 // Xd = LOAD Xc, off
768 DebugLoc DL = MI.getDebugLoc();
769 BuildMI(*MBB, &MI, DL, TII->get(AArch64::ORRXrs), ScratchReg)
770 .addReg(AArch64::XZR)
771 .addReg(LdI.BaseReg)
772 .addImm(0);
773 MachineOperand &BaseOpnd = MI.getOperand(LdI.BaseRegIdx);
774 BaseOpnd.setReg(ScratchReg);
775
776 // If the load does a pre/post increment, then insert a MOV after as
777 // well to update the real base register.
778 if (LdI.IsPrePost) {
779 LLVM_DEBUG(dbgs() << "Doing post MOV of incremented reg: "
780 << printReg(ScratchReg, TRI) << '\n');
781 MI.getOperand(0).setReg(
782 ScratchReg); // Change tied operand pre/post update dest.
783 BuildMI(*MBB, std::next(MachineBasicBlock::iterator(MI)), DL,
784 TII->get(AArch64::ORRXrs), LdI.BaseReg)
785 .addReg(AArch64::XZR)
786 .addReg(ScratchReg)
787 .addImm(0);
788 }
789
790 for (int I = 0, E = OldCollisions.size(); I != E; ++I)
791 if (OldCollisions[I] == &MI) {
792 std::swap(OldCollisions[I], OldCollisions[E - 1]);
793 OldCollisions.pop_back();
794 break;
795 }
796
797 // Update TagMap to reflect instruction changes to reduce the number
798 // of later MOVs to be inserted. This needs to be done after
799 // OldCollisions is updated since it may be relocated by this
800 // insertion.
801 TagMap[NewTag].push_back(&MI);
802 ++NumCollisionsAvoided;
803 Fixed = true;
804 Modified = true;
805 break;
806 }
807 if (!Fixed)
808 ++NumCollisionsNotAvoided;
809 }
810 }
811 }
812
runOnMachineFunction(MachineFunction & Fn)813 bool FalkorHWPFFix::runOnMachineFunction(MachineFunction &Fn) {
814 auto &ST = Fn.getSubtarget<AArch64Subtarget>();
815 if (ST.getProcFamily() != AArch64Subtarget::Falkor)
816 return false;
817
818 if (skipFunction(Fn.getFunction()))
819 return false;
820
821 TII = static_cast<const AArch64InstrInfo *>(ST.getInstrInfo());
822 TRI = ST.getRegisterInfo();
823
824 MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
825
826 Modified = false;
827
828 for (MachineLoop *I : LI)
829 for (MachineLoop *L : depth_first(I))
830 // Only process inner-loops
831 if (L->isInnermost())
832 runOnLoop(*L, Fn);
833
834 return Modified;
835 }
836
createFalkorHWPFFixPass()837 FunctionPass *llvm::createFalkorHWPFFixPass() { return new FalkorHWPFFix(); }
838