xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp (revision e40139ff33b48b56a24c808b166b04b8ee6f5b21)
1 //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a Loop Data Prefetching Pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
14 
15 #define DEBUG_TYPE "loop-data-prefetch"
16 #include "llvm/ADT/DepthFirstIterator.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/ScalarEvolution.h"
23 #include "llvm/Analysis/ScalarEvolutionExpander.h"
24 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/IR/CFG.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Transforms/Scalar.h"
33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34 #include "llvm/Transforms/Utils/ValueMapper.h"
35 using namespace llvm;
36 
37 // By default, we limit this to creating 16 PHIs (which is a little over half
38 // of the allocatable register set).
39 static cl::opt<bool>
40 PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
41                cl::desc("Prefetch write addresses"));
42 
43 static cl::opt<unsigned>
44     PrefetchDistance("prefetch-distance",
45                      cl::desc("Number of instructions to prefetch ahead"),
46                      cl::Hidden);
47 
48 static cl::opt<unsigned>
49     MinPrefetchStride("min-prefetch-stride",
50                       cl::desc("Min stride to add prefetches"), cl::Hidden);
51 
52 static cl::opt<unsigned> MaxPrefetchIterationsAhead(
53     "max-prefetch-iters-ahead",
54     cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);
55 
56 STATISTIC(NumPrefetches, "Number of prefetches inserted");
57 
58 namespace {
59 
60 /// Loop prefetch implementation class.
61 class LoopDataPrefetch {
62 public:
63   LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE,
64                    const TargetTransformInfo *TTI,
65                    OptimizationRemarkEmitter *ORE)
66       : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
67 
68   bool run();
69 
70 private:
71   bool runOnLoop(Loop *L);
72 
73   /// Check if the stride of the accesses is large enough to
74   /// warrant a prefetch.
75   bool isStrideLargeEnough(const SCEVAddRecExpr *AR);
76 
77   unsigned getMinPrefetchStride() {
78     if (MinPrefetchStride.getNumOccurrences() > 0)
79       return MinPrefetchStride;
80     return TTI->getMinPrefetchStride();
81   }
82 
83   unsigned getPrefetchDistance() {
84     if (PrefetchDistance.getNumOccurrences() > 0)
85       return PrefetchDistance;
86     return TTI->getPrefetchDistance();
87   }
88 
89   unsigned getMaxPrefetchIterationsAhead() {
90     if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
91       return MaxPrefetchIterationsAhead;
92     return TTI->getMaxPrefetchIterationsAhead();
93   }
94 
95   AssumptionCache *AC;
96   LoopInfo *LI;
97   ScalarEvolution *SE;
98   const TargetTransformInfo *TTI;
99   OptimizationRemarkEmitter *ORE;
100 };
101 
102 /// Legacy class for inserting loop data prefetches.
103 class LoopDataPrefetchLegacyPass : public FunctionPass {
104 public:
105   static char ID; // Pass ID, replacement for typeid
106   LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
107     initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
108   }
109 
110   void getAnalysisUsage(AnalysisUsage &AU) const override {
111     AU.addRequired<AssumptionCacheTracker>();
112     AU.addPreserved<DominatorTreeWrapperPass>();
113     AU.addRequired<LoopInfoWrapperPass>();
114     AU.addPreserved<LoopInfoWrapperPass>();
115     AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
116     AU.addRequired<ScalarEvolutionWrapperPass>();
117     AU.addPreserved<ScalarEvolutionWrapperPass>();
118     AU.addRequired<TargetTransformInfoWrapperPass>();
119   }
120 
121   bool runOnFunction(Function &F) override;
122   };
123 }
124 
125 char LoopDataPrefetchLegacyPass::ID = 0;
126 INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
127                       "Loop Data Prefetch", false, false)
128 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
129 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
130 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
131 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
132 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
133 INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
134                     "Loop Data Prefetch", false, false)
135 
136 FunctionPass *llvm::createLoopDataPrefetchPass() {
137   return new LoopDataPrefetchLegacyPass();
138 }
139 
140 bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
141   unsigned TargetMinStride = getMinPrefetchStride();
142   // No need to check if any stride goes.
143   if (TargetMinStride <= 1)
144     return true;
145 
146   const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
147   // If MinStride is set, don't prefetch unless we can ensure that stride is
148   // larger.
149   if (!ConstStride)
150     return false;
151 
152   unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
153   return TargetMinStride <= AbsStride;
154 }
155 
156 PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
157                                             FunctionAnalysisManager &AM) {
158   LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
159   ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
160   AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
161   OptimizationRemarkEmitter *ORE =
162       &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
163   const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
164 
165   LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
166   bool Changed = LDP.run();
167 
168   if (Changed) {
169     PreservedAnalyses PA;
170     PA.preserve<DominatorTreeAnalysis>();
171     PA.preserve<LoopAnalysis>();
172     return PA;
173   }
174 
175   return PreservedAnalyses::all();
176 }
177 
178 bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
179   if (skipFunction(F))
180     return false;
181 
182   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
183   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
184   AssumptionCache *AC =
185       &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
186   OptimizationRemarkEmitter *ORE =
187       &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
188   const TargetTransformInfo *TTI =
189       &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
190 
191   LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
192   return LDP.run();
193 }
194 
195 bool LoopDataPrefetch::run() {
196   // If PrefetchDistance is not set, don't run the pass.  This gives an
197   // opportunity for targets to run this pass for selected subtargets only
198   // (whose TTI sets PrefetchDistance).
199   if (getPrefetchDistance() == 0)
200     return false;
201   assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
202 
203   bool MadeChange = false;
204 
205   for (Loop *I : *LI)
206     for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
207       MadeChange |= runOnLoop(*L);
208 
209   return MadeChange;
210 }
211 
212 bool LoopDataPrefetch::runOnLoop(Loop *L) {
213   bool MadeChange = false;
214 
215   // Only prefetch in the inner-most loop
216   if (!L->empty())
217     return MadeChange;
218 
219   SmallPtrSet<const Value *, 32> EphValues;
220   CodeMetrics::collectEphemeralValues(L, AC, EphValues);
221 
222   // Calculate the number of iterations ahead to prefetch
223   CodeMetrics Metrics;
224   for (const auto BB : L->blocks()) {
225     // If the loop already has prefetches, then assume that the user knows
226     // what they are doing and don't add any more.
227     for (auto &I : *BB)
228       if (CallInst *CI = dyn_cast<CallInst>(&I))
229         if (Function *F = CI->getCalledFunction())
230           if (F->getIntrinsicID() == Intrinsic::prefetch)
231             return MadeChange;
232 
233     Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
234   }
235   unsigned LoopSize = Metrics.NumInsts;
236   if (!LoopSize)
237     LoopSize = 1;
238 
239   unsigned ItersAhead = getPrefetchDistance() / LoopSize;
240   if (!ItersAhead)
241     ItersAhead = 1;
242 
243   if (ItersAhead > getMaxPrefetchIterationsAhead())
244     return MadeChange;
245 
246   LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
247                     << " iterations ahead (loop size: " << LoopSize << ") in "
248                     << L->getHeader()->getParent()->getName() << ": " << *L);
249 
250   SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
251   for (const auto BB : L->blocks()) {
252     for (auto &I : *BB) {
253       Value *PtrValue;
254       Instruction *MemI;
255 
256       if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
257         MemI = LMemI;
258         PtrValue = LMemI->getPointerOperand();
259       } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
260         if (!PrefetchWrites) continue;
261         MemI = SMemI;
262         PtrValue = SMemI->getPointerOperand();
263       } else continue;
264 
265       unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
266       if (PtrAddrSpace)
267         continue;
268 
269       if (L->isLoopInvariant(PtrValue))
270         continue;
271 
272       const SCEV *LSCEV = SE->getSCEV(PtrValue);
273       const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
274       if (!LSCEVAddRec)
275         continue;
276 
277       // Check if the stride of the accesses is large enough to warrant a
278       // prefetch.
279       if (!isStrideLargeEnough(LSCEVAddRec))
280         continue;
281 
282       // We don't want to double prefetch individual cache lines. If this load
283       // is known to be within one cache line of some other load that has
284       // already been prefetched, then don't prefetch this one as well.
285       bool DupPref = false;
286       for (const auto &PrefLoad : PrefLoads) {
287         const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second);
288         if (const SCEVConstant *ConstPtrDiff =
289             dyn_cast<SCEVConstant>(PtrDiff)) {
290           int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
291           if (PD < (int64_t) TTI->getCacheLineSize()) {
292             DupPref = true;
293             break;
294           }
295         }
296       }
297       if (DupPref)
298         continue;
299 
300       const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
301         SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
302         LSCEVAddRec->getStepRecurrence(*SE)));
303       if (!isSafeToExpand(NextLSCEV, *SE))
304         continue;
305 
306       PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
307 
308       Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
309       SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr");
310       Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
311 
312       IRBuilder<> Builder(MemI);
313       Module *M = BB->getParent()->getParent();
314       Type *I32 = Type::getInt32Ty(BB->getContext());
315       Function *PrefetchFunc = Intrinsic::getDeclaration(
316           M, Intrinsic::prefetch, PrefPtrValue->getType());
317       Builder.CreateCall(
318           PrefetchFunc,
319           {PrefPtrValue,
320            ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
321            ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
322       ++NumPrefetches;
323       LLVM_DEBUG(dbgs() << "  Access: " << *PtrValue << ", SCEV: " << *LSCEV
324                         << "\n");
325       ORE->emit([&]() {
326         return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI)
327                << "prefetched memory access";
328       });
329 
330       MadeChange = true;
331     }
332   }
333 
334   return MadeChange;
335 }
336