//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Loop Data Prefetching Pass.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"

#define DEBUG_TYPE "loop-data-prefetch"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

// Prefetching of store addresses is off by default; it can be enabled with
// -loop-prefetch-writes.
static cl::opt<bool>
PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
               cl::desc("Prefetch write addresses"));

static cl::opt<unsigned>
    PrefetchDistance("prefetch-distance",
                     cl::desc("Number of instructions to prefetch ahead"),
                     cl::Hidden);

static cl::opt<unsigned>
    MinPrefetchStride("min-prefetch-stride",
                      cl::desc("Min stride to add prefetches"), cl::Hidden);

static cl::opt<unsigned> MaxPrefetchIterationsAhead(
    "max-prefetch-iters-ahead",
    cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);

STATISTIC(NumPrefetches, "Number of prefetches inserted");

namespace {

/// Loop prefetch implementation class.
class LoopDataPrefetch {
public:
  LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE,
                   const TargetTransformInfo *TTI,
                   OptimizationRemarkEmitter *ORE)
      : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}

  bool run();

private:
  bool runOnLoop(Loop *L);

  /// Check if the stride of the accesses is large enough to
  /// warrant a prefetch.
  bool isStrideLargeEnough(const SCEVAddRecExpr *AR);

  unsigned getMinPrefetchStride() {
    if (MinPrefetchStride.getNumOccurrences() > 0)
      return MinPrefetchStride;
    return TTI->getMinPrefetchStride();
  }

  unsigned getPrefetchDistance() {
    if (PrefetchDistance.getNumOccurrences() > 0)
      return PrefetchDistance;
    return TTI->getPrefetchDistance();
  }

  unsigned getMaxPrefetchIterationsAhead() {
    if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
      return MaxPrefetchIterationsAhead;
    return TTI->getMaxPrefetchIterationsAhead();
  }

  AssumptionCache *AC;
  LoopInfo *LI;
  ScalarEvolution *SE;
  const TargetTransformInfo *TTI;
  OptimizationRemarkEmitter *ORE;
};

/// Legacy class for inserting loop data prefetches.
class LoopDataPrefetchLegacyPass : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid
  LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
    initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};
} // end anonymous namespace

char LoopDataPrefetchLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                      "Loop Data Prefetch", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                    "Loop Data Prefetch", false, false)

FunctionPass *llvm::createLoopDataPrefetchPass() {
  return new LoopDataPrefetchLegacyPass();
}

bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
  unsigned TargetMinStride = getMinPrefetchStride();
  // If the target accepts any stride, there is nothing to check.
  if (TargetMinStride <= 1)
    return true;

  const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  // If MinStride is set, don't prefetch unless we can ensure that the stride
  // is larger.
  if (!ConstStride)
    return false;

  unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
  return TargetMinStride <= AbsStride;
}

PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
  LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
  ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
  OptimizationRemarkEmitter *ORE =
      &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);

  LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
  bool Changed = LDP.run();

  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<LoopAnalysis>();
    return PA;
  }

  return PreservedAnalyses::all();
}

bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  AssumptionCache *AC =
      &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  OptimizationRemarkEmitter *ORE =
      &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  const TargetTransformInfo *TTI =
      &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
  return LDP.run();
}

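// Insert prefetches into every innermost loop of the function. Returns true
// if any prefetch intrinsics were added.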
bool LoopDataPrefetch::run() {
  // If PrefetchDistance is not set, don't run the pass. This gives an
  // opportunity for targets to run this pass for selected subtargets only
  // (whose TTI sets PrefetchDistance).
  if (getPrefetchDistance() == 0)
    return false;
  assert(TTI->getCacheLineSize() && "Cache line size is not set for target");

  bool MadeChange = false;

  for (Loop *I : *LI)
    for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
      MadeChange |= runOnLoop(*L);

  return MadeChange;
}

bool LoopDataPrefetch::runOnLoop(Loop *L) {
  bool MadeChange = false;

  // Only prefetch in the inner-most loop.
  if (!L->empty())
    return MadeChange;

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC, EphValues);

  // Calculate the number of iterations ahead to prefetch.
  CodeMetrics Metrics;
  for (const auto BB : L->blocks()) {
    // If the loop already has prefetches, then assume that the user knows
    // what they are doing and don't add any more.
    for (auto &I : *BB)
      if (CallInst *CI = dyn_cast<CallInst>(&I))
        if (Function *F = CI->getCalledFunction())
          if (F->getIntrinsicID() == Intrinsic::prefetch)
            return MadeChange;

    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
  }
  unsigned LoopSize = Metrics.NumInsts;
  if (!LoopSize)
    LoopSize = 1;

  // The prefetch distance is measured in instructions; dividing it by the
  // size of one loop iteration gives the number of iterations ahead that the
  // prefetched data will be needed.
  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
  if (!ItersAhead)
    ItersAhead = 1;

  if (ItersAhead > getMaxPrefetchIterationsAhead())
    return MadeChange;

  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
                    << " iterations ahead (loop size: " << LoopSize << ") in "
                    << L->getHeader()->getParent()->getName() << ": " << *L);

  SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
  for (const auto BB : L->blocks()) {
    for (auto &I : *BB) {
      Value *PtrValue;
      Instruction *MemI;

      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
        MemI = LMemI;
        PtrValue = LMemI->getPointerOperand();
      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
        if (!PrefetchWrites) continue;
        MemI = SMemI;
        PtrValue = SMemI->getPointerOperand();
      } else continue;

      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
      if (PtrAddrSpace)
        continue;

      // A loop-invariant address does not need a prefetch on every iteration.
      if (L->isLoopInvariant(PtrValue))
        continue;

      const SCEV *LSCEV = SE->getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec)
        continue;

      // Check if the stride of the accesses is large enough to warrant a
      // prefetch.
      if (!isStrideLargeEnough(LSCEVAddRec))
        continue;

      // We don't want to double prefetch individual cache lines. If this load
      // is known to be within one cache line of some other load that has
      // already been prefetched, then don't prefetch this one as well.
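      // Two accesses count as duplicates when SCEV can fold the difference of
      // their address recurrences to a constant smaller than the cache line
      // size.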
      bool DupPref = false;
      for (const auto &PrefLoad : PrefLoads) {
        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second);
        if (const SCEVConstant *ConstPtrDiff =
                dyn_cast<SCEVConstant>(PtrDiff)) {
          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
          if (PD < (int64_t) TTI->getCacheLineSize()) {
            DupPref = true;
            break;
          }
        }
      }
      if (DupPref)
        continue;

      // The address to prefetch is the current recurrence advanced by
      // ItersAhead iterations' worth of stride.
      const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
          SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
          LSCEVAddRec->getStepRecurrence(*SE)));
      if (!isSafeToExpand(NextLSCEV, *SE))
        continue;

      PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));

      Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
      SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr");
      Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);

      IRBuilder<> Builder(MemI);
      Module *M = BB->getParent()->getParent();
      Type *I32 = Type::getInt32Ty(BB->getContext());
      Function *PrefetchFunc = Intrinsic::getDeclaration(
          M, Intrinsic::prefetch, PrefPtrValue->getType());
      // llvm.prefetch operands: address, rw (0 = read, 1 = write),
      // locality (3 = high), cache type (1 = data cache).
      Builder.CreateCall(
          PrefetchFunc,
          {PrefPtrValue,
           ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
           ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
      ++NumPrefetches;
      LLVM_DEBUG(dbgs() << "  Access: " << *PtrValue << ", SCEV: " << *LSCEV
                        << "\n");
      ORE->emit([&]() {
        return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI)
               << "prefetched memory access";
      });

      MadeChange = true;
    }
  }

  return MadeChange;
}