xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp (revision b49b6e0f95c89f8dcb5898424c360b46019254b4)
1  //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file implements a Loop Data Prefetching Pass.
10  //
11  //===----------------------------------------------------------------------===//
12  
13  #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
14  #include "llvm/InitializePasses.h"
15  
16  #include "llvm/ADT/DepthFirstIterator.h"
17  #include "llvm/ADT/Statistic.h"
18  #include "llvm/Analysis/AssumptionCache.h"
19  #include "llvm/Analysis/CodeMetrics.h"
20  #include "llvm/Analysis/LoopInfo.h"
21  #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22  #include "llvm/Analysis/ScalarEvolution.h"
23  #include "llvm/Analysis/ScalarEvolutionExpressions.h"
24  #include "llvm/Analysis/TargetTransformInfo.h"
25  #include "llvm/IR/CFG.h"
26  #include "llvm/IR/Dominators.h"
27  #include "llvm/IR/Function.h"
28  #include "llvm/IR/Module.h"
29  #include "llvm/Support/CommandLine.h"
30  #include "llvm/Support/Debug.h"
31  #include "llvm/Transforms/Scalar.h"
32  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
33  #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
34  #include "llvm/Transforms/Utils/ValueMapper.h"
35  
36  #define DEBUG_TYPE "loop-data-prefetch"
37  
38  using namespace llvm;
39  
40  // By default, we limit this to creating 16 PHIs (which is a little over half
41  // of the allocatable register set).
42  static cl::opt<bool>
43  PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
44                 cl::desc("Prefetch write addresses"));
45  
46  static cl::opt<unsigned>
47      PrefetchDistance("prefetch-distance",
48                       cl::desc("Number of instructions to prefetch ahead"),
49                       cl::Hidden);
50  
51  static cl::opt<unsigned>
52      MinPrefetchStride("min-prefetch-stride",
53                        cl::desc("Min stride to add prefetches"), cl::Hidden);
54  
55  static cl::opt<unsigned> MaxPrefetchIterationsAhead(
56      "max-prefetch-iters-ahead",
57      cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);
58  
59  STATISTIC(NumPrefetches, "Number of prefetches inserted");
60  
61  namespace {
62  
63  /// Loop prefetch implementation class.
64  class LoopDataPrefetch {
65  public:
66    LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
67                     ScalarEvolution *SE, const TargetTransformInfo *TTI,
68                     OptimizationRemarkEmitter *ORE)
69        : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
70  
71    bool run();
72  
73  private:
74    bool runOnLoop(Loop *L);
75  
76    /// Check if the stride of the accesses is large enough to
77    /// warrant a prefetch.
78    bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
79  
80    unsigned getMinPrefetchStride(unsigned NumMemAccesses,
81                                  unsigned NumStridedMemAccesses,
82                                  unsigned NumPrefetches,
83                                  bool HasCall) {
84      if (MinPrefetchStride.getNumOccurrences() > 0)
85        return MinPrefetchStride;
86      return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
87                                       NumPrefetches, HasCall);
88    }
89  
90    unsigned getPrefetchDistance() {
91      if (PrefetchDistance.getNumOccurrences() > 0)
92        return PrefetchDistance;
93      return TTI->getPrefetchDistance();
94    }
95  
96    unsigned getMaxPrefetchIterationsAhead() {
97      if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
98        return MaxPrefetchIterationsAhead;
99      return TTI->getMaxPrefetchIterationsAhead();
100    }
101  
102    bool doPrefetchWrites() {
103      if (PrefetchWrites.getNumOccurrences() > 0)
104        return PrefetchWrites;
105      return TTI->enableWritePrefetching();
106    }
107  
108    AssumptionCache *AC;
109    DominatorTree *DT;
110    LoopInfo *LI;
111    ScalarEvolution *SE;
112    const TargetTransformInfo *TTI;
113    OptimizationRemarkEmitter *ORE;
114  };
115  
116  /// Legacy class for inserting loop data prefetches.
117  class LoopDataPrefetchLegacyPass : public FunctionPass {
118  public:
119    static char ID; // Pass ID, replacement for typeid
120    LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
121      initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
122    }
123  
124    void getAnalysisUsage(AnalysisUsage &AU) const override {
125      AU.addRequired<AssumptionCacheTracker>();
126      AU.addRequired<DominatorTreeWrapperPass>();
127      AU.addPreserved<DominatorTreeWrapperPass>();
128      AU.addRequired<LoopInfoWrapperPass>();
129      AU.addPreserved<LoopInfoWrapperPass>();
130      AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
131      AU.addRequired<ScalarEvolutionWrapperPass>();
132      AU.addPreserved<ScalarEvolutionWrapperPass>();
133      AU.addRequired<TargetTransformInfoWrapperPass>();
134    }
135  
136    bool runOnFunction(Function &F) override;
137    };
138  }
139  
140  char LoopDataPrefetchLegacyPass::ID = 0;
141  INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
142                        "Loop Data Prefetch", false, false)
143  INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
144  INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
145  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
146  INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
147  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
148  INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
149                      "Loop Data Prefetch", false, false)
150  
151  FunctionPass *llvm::createLoopDataPrefetchPass() {
152    return new LoopDataPrefetchLegacyPass();
153  }
154  
155  bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
156                                             unsigned TargetMinStride) {
157    // No need to check if any stride goes.
158    if (TargetMinStride <= 1)
159      return true;
160  
161    const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
162    // If MinStride is set, don't prefetch unless we can ensure that stride is
163    // larger.
164    if (!ConstStride)
165      return false;
166  
167    unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
168    return TargetMinStride <= AbsStride;
169  }
170  
171  PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
172                                              FunctionAnalysisManager &AM) {
173    DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
174    LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
175    ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
176    AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
177    OptimizationRemarkEmitter *ORE =
178        &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
179    const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
180  
181    LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
182    bool Changed = LDP.run();
183  
184    if (Changed) {
185      PreservedAnalyses PA;
186      PA.preserve<DominatorTreeAnalysis>();
187      PA.preserve<LoopAnalysis>();
188      return PA;
189    }
190  
191    return PreservedAnalyses::all();
192  }
193  
194  bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
195    if (skipFunction(F))
196      return false;
197  
198    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
199    LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
200    ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
201    AssumptionCache *AC =
202        &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
203    OptimizationRemarkEmitter *ORE =
204        &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
205    const TargetTransformInfo *TTI =
206        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
207  
208    LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
209    return LDP.run();
210  }
211  
212  bool LoopDataPrefetch::run() {
213    // If PrefetchDistance is not set, don't run the pass.  This gives an
214    // opportunity for targets to run this pass for selected subtargets only
215    // (whose TTI sets PrefetchDistance).
216    if (getPrefetchDistance() == 0)
217      return false;
218    assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
219  
220    bool MadeChange = false;
221  
222    for (Loop *I : *LI)
223      for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
224        MadeChange |= runOnLoop(*L);
225  
226    return MadeChange;
227  }
228  
229  /// A record for a potential prefetch made during the initial scan of the
230  /// loop. This is used to let a single prefetch target multiple memory accesses.
231  struct Prefetch {
232    /// The address formula for this prefetch as returned by ScalarEvolution.
233    const SCEVAddRecExpr *LSCEVAddRec;
234    /// The point of insertion for the prefetch instruction.
235    Instruction *InsertPt;
236    /// True if targeting a write memory access.
237    bool Writes;
238    /// The (first seen) prefetched instruction.
239    Instruction *MemI;
240  
241    /// Constructor to create a new Prefetch for \p I.
242    Prefetch(const SCEVAddRecExpr *L, Instruction *I)
243        : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
244      addInstruction(I);
245    };
246  
247    /// Add the instruction \param I to this prefetch. If it's not the first
248    /// one, 'InsertPt' and 'Writes' will be updated as required.
249    /// \param PtrDiff the known constant address difference to the first added
250    /// instruction.
251    void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
252                        int64_t PtrDiff = 0) {
253      if (!InsertPt) {
254        MemI = I;
255        InsertPt = I;
256        Writes = isa<StoreInst>(I);
257      } else {
258        BasicBlock *PrefBB = InsertPt->getParent();
259        BasicBlock *InsBB = I->getParent();
260        if (PrefBB != InsBB) {
261          BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
262          if (DomBB != PrefBB)
263            InsertPt = DomBB->getTerminator();
264        }
265  
266        if (isa<StoreInst>(I) && PtrDiff == 0)
267          Writes = true;
268      }
269    }
270  };
271  
272  bool LoopDataPrefetch::runOnLoop(Loop *L) {
273    bool MadeChange = false;
274  
275    // Only prefetch in the inner-most loop
276    if (!L->isInnermost())
277      return MadeChange;
278  
279    SmallPtrSet<const Value *, 32> EphValues;
280    CodeMetrics::collectEphemeralValues(L, AC, EphValues);
281  
282    // Calculate the number of iterations ahead to prefetch
283    CodeMetrics Metrics;
284    bool HasCall = false;
285    for (const auto BB : L->blocks()) {
286      // If the loop already has prefetches, then assume that the user knows
287      // what they are doing and don't add any more.
288      for (auto &I : *BB) {
289        if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
290          if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
291            if (F->getIntrinsicID() == Intrinsic::prefetch)
292              return MadeChange;
293            if (TTI->isLoweredToCall(F))
294              HasCall = true;
295          } else { // indirect call.
296            HasCall = true;
297          }
298        }
299      }
300      Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
301    }
302    unsigned LoopSize = Metrics.NumInsts;
303    if (!LoopSize)
304      LoopSize = 1;
305  
306    unsigned ItersAhead = getPrefetchDistance() / LoopSize;
307    if (!ItersAhead)
308      ItersAhead = 1;
309  
310    if (ItersAhead > getMaxPrefetchIterationsAhead())
311      return MadeChange;
312  
313    unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
314    if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
315      return MadeChange;
316  
317    unsigned NumMemAccesses = 0;
318    unsigned NumStridedMemAccesses = 0;
319    SmallVector<Prefetch, 16> Prefetches;
320    for (const auto BB : L->blocks())
321      for (auto &I : *BB) {
322        Value *PtrValue;
323        Instruction *MemI;
324  
325        if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
326          MemI = LMemI;
327          PtrValue = LMemI->getPointerOperand();
328        } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
329          if (!doPrefetchWrites()) continue;
330          MemI = SMemI;
331          PtrValue = SMemI->getPointerOperand();
332        } else continue;
333  
334        unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
335        if (PtrAddrSpace)
336          continue;
337        NumMemAccesses++;
338        if (L->isLoopInvariant(PtrValue))
339          continue;
340  
341        const SCEV *LSCEV = SE->getSCEV(PtrValue);
342        const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
343        if (!LSCEVAddRec)
344          continue;
345        NumStridedMemAccesses++;
346  
347        // We don't want to double prefetch individual cache lines. If this
348        // access is known to be within one cache line of some other one that
349        // has already been prefetched, then don't prefetch this one as well.
350        bool DupPref = false;
351        for (auto &Pref : Prefetches) {
352          const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
353          if (const SCEVConstant *ConstPtrDiff =
354              dyn_cast<SCEVConstant>(PtrDiff)) {
355            int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
356            if (PD < (int64_t) TTI->getCacheLineSize()) {
357              Pref.addInstruction(MemI, DT, PD);
358              DupPref = true;
359              break;
360            }
361          }
362        }
363        if (!DupPref)
364          Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
365      }
366  
367    unsigned TargetMinStride =
368      getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
369                           Prefetches.size(), HasCall);
370  
371    LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
372               << " iterations ahead (loop size: " << LoopSize << ") in "
373               << L->getHeader()->getParent()->getName() << ": " << *L);
374    LLVM_DEBUG(dbgs() << "Loop has: "
375               << NumMemAccesses << " memory accesses, "
376               << NumStridedMemAccesses << " strided memory accesses, "
377               << Prefetches.size() << " potential prefetch(es), "
378               << "a minimum stride of " << TargetMinStride << ", "
379               << (HasCall ? "calls" : "no calls") << ".\n");
380  
381    for (auto &P : Prefetches) {
382      // Check if the stride of the accesses is large enough to warrant a
383      // prefetch.
384      if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
385        continue;
386  
387      const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
388        SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
389        P.LSCEVAddRec->getStepRecurrence(*SE)));
390      if (!isSafeToExpand(NextLSCEV, *SE))
391        continue;
392  
393      BasicBlock *BB = P.InsertPt->getParent();
394      Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
395      SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
396      Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);
397  
398      IRBuilder<> Builder(P.InsertPt);
399      Module *M = BB->getParent()->getParent();
400      Type *I32 = Type::getInt32Ty(BB->getContext());
401      Function *PrefetchFunc = Intrinsic::getDeclaration(
402          M, Intrinsic::prefetch, PrefPtrValue->getType());
403      Builder.CreateCall(
404          PrefetchFunc,
405          {PrefPtrValue,
406           ConstantInt::get(I32, P.Writes),
407           ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
408      ++NumPrefetches;
409      LLVM_DEBUG(dbgs() << "  Access: "
410                 << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
411                 << ", SCEV: " << *P.LSCEVAddRec << "\n");
412      ORE->emit([&]() {
413          return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
414            << "prefetched memory access";
415        });
416  
417      MadeChange = true;
418    }
419  
420    return MadeChange;
421  }
422