xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp (revision bfce69ae9a0731f87574fcbe96f14e5f8e77f036)
1  //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file implements a Loop Data Prefetching Pass.
10  //
11  //===----------------------------------------------------------------------===//
12  
13  #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
14  #include "llvm/InitializePasses.h"
15  
16  #define DEBUG_TYPE "loop-data-prefetch"
17  #include "llvm/ADT/DepthFirstIterator.h"
18  #include "llvm/ADT/Statistic.h"
19  #include "llvm/Analysis/AssumptionCache.h"
20  #include "llvm/Analysis/CodeMetrics.h"
21  #include "llvm/Analysis/LoopInfo.h"
22  #include "llvm/Analysis/OptimizationRemarkEmitter.h"
23  #include "llvm/Analysis/ScalarEvolution.h"
24  #include "llvm/Analysis/ScalarEvolutionExpressions.h"
25  #include "llvm/Analysis/TargetTransformInfo.h"
26  #include "llvm/IR/CFG.h"
27  #include "llvm/IR/Dominators.h"
28  #include "llvm/IR/Function.h"
29  #include "llvm/IR/Module.h"
30  #include "llvm/Support/CommandLine.h"
31  #include "llvm/Support/Debug.h"
32  #include "llvm/Transforms/Scalar.h"
33  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34  #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
35  #include "llvm/Transforms/Utils/ValueMapper.h"
36  using namespace llvm;
37  
38  // By default, we limit this to creating 16 PHIs (which is a little over half
39  // of the allocatable register set).
40  static cl::opt<bool>
41  PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
42                 cl::desc("Prefetch write addresses"));
43  
44  static cl::opt<unsigned>
45      PrefetchDistance("prefetch-distance",
46                       cl::desc("Number of instructions to prefetch ahead"),
47                       cl::Hidden);
48  
49  static cl::opt<unsigned>
50      MinPrefetchStride("min-prefetch-stride",
51                        cl::desc("Min stride to add prefetches"), cl::Hidden);
52  
53  static cl::opt<unsigned> MaxPrefetchIterationsAhead(
54      "max-prefetch-iters-ahead",
55      cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);
56  
57  STATISTIC(NumPrefetches, "Number of prefetches inserted");
58  
59  namespace {
60  
61  /// Loop prefetch implementation class.
62  class LoopDataPrefetch {
63  public:
64    LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
65                     ScalarEvolution *SE, const TargetTransformInfo *TTI,
66                     OptimizationRemarkEmitter *ORE)
67        : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
68  
69    bool run();
70  
71  private:
72    bool runOnLoop(Loop *L);
73  
74    /// Check if the stride of the accesses is large enough to
75    /// warrant a prefetch.
76    bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
77  
78    unsigned getMinPrefetchStride(unsigned NumMemAccesses,
79                                  unsigned NumStridedMemAccesses,
80                                  unsigned NumPrefetches,
81                                  bool HasCall) {
82      if (MinPrefetchStride.getNumOccurrences() > 0)
83        return MinPrefetchStride;
84      return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
85                                       NumPrefetches, HasCall);
86    }
87  
88    unsigned getPrefetchDistance() {
89      if (PrefetchDistance.getNumOccurrences() > 0)
90        return PrefetchDistance;
91      return TTI->getPrefetchDistance();
92    }
93  
94    unsigned getMaxPrefetchIterationsAhead() {
95      if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
96        return MaxPrefetchIterationsAhead;
97      return TTI->getMaxPrefetchIterationsAhead();
98    }
99  
100    bool doPrefetchWrites() {
101      if (PrefetchWrites.getNumOccurrences() > 0)
102        return PrefetchWrites;
103      return TTI->enableWritePrefetching();
104    }
105  
106    AssumptionCache *AC;
107    DominatorTree *DT;
108    LoopInfo *LI;
109    ScalarEvolution *SE;
110    const TargetTransformInfo *TTI;
111    OptimizationRemarkEmitter *ORE;
112  };
113  
114  /// Legacy class for inserting loop data prefetches.
115  class LoopDataPrefetchLegacyPass : public FunctionPass {
116  public:
117    static char ID; // Pass ID, replacement for typeid
118    LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
119      initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
120    }
121  
122    void getAnalysisUsage(AnalysisUsage &AU) const override {
123      AU.addRequired<AssumptionCacheTracker>();
124      AU.addRequired<DominatorTreeWrapperPass>();
125      AU.addPreserved<DominatorTreeWrapperPass>();
126      AU.addRequired<LoopInfoWrapperPass>();
127      AU.addPreserved<LoopInfoWrapperPass>();
128      AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
129      AU.addRequired<ScalarEvolutionWrapperPass>();
130      AU.addPreserved<ScalarEvolutionWrapperPass>();
131      AU.addRequired<TargetTransformInfoWrapperPass>();
132    }
133  
134    bool runOnFunction(Function &F) override;
135    };
136  }
137  
138  char LoopDataPrefetchLegacyPass::ID = 0;
139  INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
140                        "Loop Data Prefetch", false, false)
141  INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
142  INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
143  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
144  INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
145  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
146  INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
147                      "Loop Data Prefetch", false, false)
148  
149  FunctionPass *llvm::createLoopDataPrefetchPass() {
150    return new LoopDataPrefetchLegacyPass();
151  }
152  
153  bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
154                                             unsigned TargetMinStride) {
155    // No need to check if any stride goes.
156    if (TargetMinStride <= 1)
157      return true;
158  
159    const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
160    // If MinStride is set, don't prefetch unless we can ensure that stride is
161    // larger.
162    if (!ConstStride)
163      return false;
164  
165    unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
166    return TargetMinStride <= AbsStride;
167  }
168  
169  PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
170                                              FunctionAnalysisManager &AM) {
171    DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
172    LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
173    ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
174    AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
175    OptimizationRemarkEmitter *ORE =
176        &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
177    const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
178  
179    LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
180    bool Changed = LDP.run();
181  
182    if (Changed) {
183      PreservedAnalyses PA;
184      PA.preserve<DominatorTreeAnalysis>();
185      PA.preserve<LoopAnalysis>();
186      return PA;
187    }
188  
189    return PreservedAnalyses::all();
190  }
191  
192  bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
193    if (skipFunction(F))
194      return false;
195  
196    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
197    LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
198    ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
199    AssumptionCache *AC =
200        &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
201    OptimizationRemarkEmitter *ORE =
202        &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
203    const TargetTransformInfo *TTI =
204        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
205  
206    LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
207    return LDP.run();
208  }
209  
210  bool LoopDataPrefetch::run() {
211    // If PrefetchDistance is not set, don't run the pass.  This gives an
212    // opportunity for targets to run this pass for selected subtargets only
213    // (whose TTI sets PrefetchDistance).
214    if (getPrefetchDistance() == 0)
215      return false;
216    assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
217  
218    bool MadeChange = false;
219  
220    for (Loop *I : *LI)
221      for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
222        MadeChange |= runOnLoop(*L);
223  
224    return MadeChange;
225  }
226  
227  /// A record for a potential prefetch made during the initial scan of the
228  /// loop. This is used to let a single prefetch target multiple memory accesses.
229  struct Prefetch {
230    /// The address formula for this prefetch as returned by ScalarEvolution.
231    const SCEVAddRecExpr *LSCEVAddRec;
232    /// The point of insertion for the prefetch instruction.
233    Instruction *InsertPt;
234    /// True if targeting a write memory access.
235    bool Writes;
236    /// The (first seen) prefetched instruction.
237    Instruction *MemI;
238  
239    /// Constructor to create a new Prefetch for \p I.
240    Prefetch(const SCEVAddRecExpr *L, Instruction *I)
241        : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
242      addInstruction(I);
243    };
244  
245    /// Add the instruction \param I to this prefetch. If it's not the first
246    /// one, 'InsertPt' and 'Writes' will be updated as required.
247    /// \param PtrDiff the known constant address difference to the first added
248    /// instruction.
249    void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
250                        int64_t PtrDiff = 0) {
251      if (!InsertPt) {
252        MemI = I;
253        InsertPt = I;
254        Writes = isa<StoreInst>(I);
255      } else {
256        BasicBlock *PrefBB = InsertPt->getParent();
257        BasicBlock *InsBB = I->getParent();
258        if (PrefBB != InsBB) {
259          BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
260          if (DomBB != PrefBB)
261            InsertPt = DomBB->getTerminator();
262        }
263  
264        if (isa<StoreInst>(I) && PtrDiff == 0)
265          Writes = true;
266      }
267    }
268  };
269  
270  bool LoopDataPrefetch::runOnLoop(Loop *L) {
271    bool MadeChange = false;
272  
273    // Only prefetch in the inner-most loop
274    if (!L->isInnermost())
275      return MadeChange;
276  
277    SmallPtrSet<const Value *, 32> EphValues;
278    CodeMetrics::collectEphemeralValues(L, AC, EphValues);
279  
280    // Calculate the number of iterations ahead to prefetch
281    CodeMetrics Metrics;
282    bool HasCall = false;
283    for (const auto BB : L->blocks()) {
284      // If the loop already has prefetches, then assume that the user knows
285      // what they are doing and don't add any more.
286      for (auto &I : *BB) {
287        if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
288          if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
289            if (F->getIntrinsicID() == Intrinsic::prefetch)
290              return MadeChange;
291            if (TTI->isLoweredToCall(F))
292              HasCall = true;
293          } else { // indirect call.
294            HasCall = true;
295          }
296        }
297      }
298      Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
299    }
300    unsigned LoopSize = Metrics.NumInsts;
301    if (!LoopSize)
302      LoopSize = 1;
303  
304    unsigned ItersAhead = getPrefetchDistance() / LoopSize;
305    if (!ItersAhead)
306      ItersAhead = 1;
307  
308    if (ItersAhead > getMaxPrefetchIterationsAhead())
309      return MadeChange;
310  
311    unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
312    if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
313      return MadeChange;
314  
315    unsigned NumMemAccesses = 0;
316    unsigned NumStridedMemAccesses = 0;
317    SmallVector<Prefetch, 16> Prefetches;
318    for (const auto BB : L->blocks())
319      for (auto &I : *BB) {
320        Value *PtrValue;
321        Instruction *MemI;
322  
323        if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
324          MemI = LMemI;
325          PtrValue = LMemI->getPointerOperand();
326        } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
327          if (!doPrefetchWrites()) continue;
328          MemI = SMemI;
329          PtrValue = SMemI->getPointerOperand();
330        } else continue;
331  
332        unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
333        if (PtrAddrSpace)
334          continue;
335        NumMemAccesses++;
336        if (L->isLoopInvariant(PtrValue))
337          continue;
338  
339        const SCEV *LSCEV = SE->getSCEV(PtrValue);
340        const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
341        if (!LSCEVAddRec)
342          continue;
343        NumStridedMemAccesses++;
344  
345        // We don't want to double prefetch individual cache lines. If this
346        // access is known to be within one cache line of some other one that
347        // has already been prefetched, then don't prefetch this one as well.
348        bool DupPref = false;
349        for (auto &Pref : Prefetches) {
350          const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
351          if (const SCEVConstant *ConstPtrDiff =
352              dyn_cast<SCEVConstant>(PtrDiff)) {
353            int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
354            if (PD < (int64_t) TTI->getCacheLineSize()) {
355              Pref.addInstruction(MemI, DT, PD);
356              DupPref = true;
357              break;
358            }
359          }
360        }
361        if (!DupPref)
362          Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
363      }
364  
365    unsigned TargetMinStride =
366      getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
367                           Prefetches.size(), HasCall);
368  
369    LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
370               << " iterations ahead (loop size: " << LoopSize << ") in "
371               << L->getHeader()->getParent()->getName() << ": " << *L);
372    LLVM_DEBUG(dbgs() << "Loop has: "
373               << NumMemAccesses << " memory accesses, "
374               << NumStridedMemAccesses << " strided memory accesses, "
375               << Prefetches.size() << " potential prefetch(es), "
376               << "a minimum stride of " << TargetMinStride << ", "
377               << (HasCall ? "calls" : "no calls") << ".\n");
378  
379    for (auto &P : Prefetches) {
380      // Check if the stride of the accesses is large enough to warrant a
381      // prefetch.
382      if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
383        continue;
384  
385      const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
386        SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
387        P.LSCEVAddRec->getStepRecurrence(*SE)));
388      if (!isSafeToExpand(NextLSCEV, *SE))
389        continue;
390  
391      BasicBlock *BB = P.InsertPt->getParent();
392      Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
393      SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
394      Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);
395  
396      IRBuilder<> Builder(P.InsertPt);
397      Module *M = BB->getParent()->getParent();
398      Type *I32 = Type::getInt32Ty(BB->getContext());
399      Function *PrefetchFunc = Intrinsic::getDeclaration(
400          M, Intrinsic::prefetch, PrefPtrValue->getType());
401      Builder.CreateCall(
402          PrefetchFunc,
403          {PrefPtrValue,
404           ConstantInt::get(I32, P.Writes),
405           ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
406      ++NumPrefetches;
407      LLVM_DEBUG(dbgs() << "  Access: "
408                 << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
409                 << ", SCEV: " << *P.LSCEVAddRec << "\n");
410      ORE->emit([&]() {
411          return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
412            << "prefetched memory access";
413        });
414  
415      MadeChange = true;
416    }
417  
418    return MadeChange;
419  }
420