1 //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a Loop Data Prefetching Pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Transforms/Scalar/LoopDataPrefetch.h" 14 #include "llvm/InitializePasses.h" 15 16 #define DEBUG_TYPE "loop-data-prefetch" 17 #include "llvm/ADT/DepthFirstIterator.h" 18 #include "llvm/ADT/Statistic.h" 19 #include "llvm/Analysis/AssumptionCache.h" 20 #include "llvm/Analysis/CodeMetrics.h" 21 #include "llvm/Analysis/LoopInfo.h" 22 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 23 #include "llvm/Analysis/ScalarEvolution.h" 24 #include "llvm/Analysis/ScalarEvolutionExpander.h" 25 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 26 #include "llvm/Analysis/TargetTransformInfo.h" 27 #include "llvm/IR/CFG.h" 28 #include "llvm/IR/Dominators.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/Module.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Debug.h" 33 #include "llvm/Transforms/Scalar.h" 34 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 35 #include "llvm/Transforms/Utils/ValueMapper.h" 36 using namespace llvm; 37 38 // By default, we limit this to creating 16 PHIs (which is a little over half 39 // of the allocatable register set). 40 static cl::opt<bool> 41 PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false), 42 cl::desc("Prefetch write addresses")); 43 44 static cl::opt<unsigned> 45 PrefetchDistance("prefetch-distance", 46 cl::desc("Number of instructions to prefetch ahead"), 47 cl::Hidden); 48 49 static cl::opt<unsigned> 50 MinPrefetchStride("min-prefetch-stride", 51 cl::desc("Min stride to add prefetches"), cl::Hidden); 52 53 static cl::opt<unsigned> MaxPrefetchIterationsAhead( 54 "max-prefetch-iters-ahead", 55 cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden); 56 57 STATISTIC(NumPrefetches, "Number of prefetches inserted"); 58 59 namespace { 60 61 /// Loop prefetch implementation class. 62 class LoopDataPrefetch { 63 public: 64 LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, 65 const TargetTransformInfo *TTI, 66 OptimizationRemarkEmitter *ORE) 67 : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {} 68 69 bool run(); 70 71 private: 72 bool runOnLoop(Loop *L); 73 74 /// Check if the stride of the accesses is large enough to 75 /// warrant a prefetch. 76 bool isStrideLargeEnough(const SCEVAddRecExpr *AR); 77 78 unsigned getMinPrefetchStride() { 79 if (MinPrefetchStride.getNumOccurrences() > 0) 80 return MinPrefetchStride; 81 return TTI->getMinPrefetchStride(); 82 } 83 84 unsigned getPrefetchDistance() { 85 if (PrefetchDistance.getNumOccurrences() > 0) 86 return PrefetchDistance; 87 return TTI->getPrefetchDistance(); 88 } 89 90 unsigned getMaxPrefetchIterationsAhead() { 91 if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0) 92 return MaxPrefetchIterationsAhead; 93 return TTI->getMaxPrefetchIterationsAhead(); 94 } 95 96 AssumptionCache *AC; 97 LoopInfo *LI; 98 ScalarEvolution *SE; 99 const TargetTransformInfo *TTI; 100 OptimizationRemarkEmitter *ORE; 101 }; 102 103 /// Legacy class for inserting loop data prefetches. 104 class LoopDataPrefetchLegacyPass : public FunctionPass { 105 public: 106 static char ID; // Pass ID, replacement for typeid 107 LoopDataPrefetchLegacyPass() : FunctionPass(ID) { 108 initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry()); 109 } 110 111 void getAnalysisUsage(AnalysisUsage &AU) const override { 112 AU.addRequired<AssumptionCacheTracker>(); 113 AU.addPreserved<DominatorTreeWrapperPass>(); 114 AU.addRequired<LoopInfoWrapperPass>(); 115 AU.addPreserved<LoopInfoWrapperPass>(); 116 AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); 117 AU.addRequired<ScalarEvolutionWrapperPass>(); 118 AU.addPreserved<ScalarEvolutionWrapperPass>(); 119 AU.addRequired<TargetTransformInfoWrapperPass>(); 120 } 121 122 bool runOnFunction(Function &F) override; 123 }; 124 } 125 126 char LoopDataPrefetchLegacyPass::ID = 0; 127 INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch", 128 "Loop Data Prefetch", false, false) 129 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 130 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 131 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 132 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) 133 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 134 INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch", 135 "Loop Data Prefetch", false, false) 136 137 FunctionPass *llvm::createLoopDataPrefetchPass() { 138 return new LoopDataPrefetchLegacyPass(); 139 } 140 141 bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { 142 unsigned TargetMinStride = getMinPrefetchStride(); 143 // No need to check if any stride goes. 144 if (TargetMinStride <= 1) 145 return true; 146 147 const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); 148 // If MinStride is set, don't prefetch unless we can ensure that stride is 149 // larger. 150 if (!ConstStride) 151 return false; 152 153 unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue()); 154 return TargetMinStride <= AbsStride; 155 } 156 157 PreservedAnalyses LoopDataPrefetchPass::run(Function &F, 158 FunctionAnalysisManager &AM) { 159 LoopInfo *LI = &AM.getResult<LoopAnalysis>(F); 160 ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); 161 AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); 162 OptimizationRemarkEmitter *ORE = 163 &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 164 const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F); 165 166 LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); 167 bool Changed = LDP.run(); 168 169 if (Changed) { 170 PreservedAnalyses PA; 171 PA.preserve<DominatorTreeAnalysis>(); 172 PA.preserve<LoopAnalysis>(); 173 return PA; 174 } 175 176 return PreservedAnalyses::all(); 177 } 178 179 bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) { 180 if (skipFunction(F)) 181 return false; 182 183 LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 184 ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 185 AssumptionCache *AC = 186 &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); 187 OptimizationRemarkEmitter *ORE = 188 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); 189 const TargetTransformInfo *TTI = 190 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 191 192 LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); 193 return LDP.run(); 194 } 195 196 bool LoopDataPrefetch::run() { 197 // If PrefetchDistance is not set, don't run the pass. This gives an 198 // opportunity for targets to run this pass for selected subtargets only 199 // (whose TTI sets PrefetchDistance). 200 if (getPrefetchDistance() == 0) 201 return false; 202 assert(TTI->getCacheLineSize() && "Cache line size is not set for target"); 203 204 bool MadeChange = false; 205 206 for (Loop *I : *LI) 207 for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L) 208 MadeChange |= runOnLoop(*L); 209 210 return MadeChange; 211 } 212 213 bool LoopDataPrefetch::runOnLoop(Loop *L) { 214 bool MadeChange = false; 215 216 // Only prefetch in the inner-most loop 217 if (!L->empty()) 218 return MadeChange; 219 220 SmallPtrSet<const Value *, 32> EphValues; 221 CodeMetrics::collectEphemeralValues(L, AC, EphValues); 222 223 // Calculate the number of iterations ahead to prefetch 224 CodeMetrics Metrics; 225 for (const auto BB : L->blocks()) { 226 // If the loop already has prefetches, then assume that the user knows 227 // what they are doing and don't add any more. 228 for (auto &I : *BB) 229 if (CallInst *CI = dyn_cast<CallInst>(&I)) 230 if (Function *F = CI->getCalledFunction()) 231 if (F->getIntrinsicID() == Intrinsic::prefetch) 232 return MadeChange; 233 234 Metrics.analyzeBasicBlock(BB, *TTI, EphValues); 235 } 236 unsigned LoopSize = Metrics.NumInsts; 237 if (!LoopSize) 238 LoopSize = 1; 239 240 unsigned ItersAhead = getPrefetchDistance() / LoopSize; 241 if (!ItersAhead) 242 ItersAhead = 1; 243 244 if (ItersAhead > getMaxPrefetchIterationsAhead()) 245 return MadeChange; 246 247 LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead 248 << " iterations ahead (loop size: " << LoopSize << ") in " 249 << L->getHeader()->getParent()->getName() << ": " << *L); 250 251 SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads; 252 for (const auto BB : L->blocks()) { 253 for (auto &I : *BB) { 254 Value *PtrValue; 255 Instruction *MemI; 256 257 if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) { 258 MemI = LMemI; 259 PtrValue = LMemI->getPointerOperand(); 260 } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) { 261 if (!PrefetchWrites) continue; 262 MemI = SMemI; 263 PtrValue = SMemI->getPointerOperand(); 264 } else continue; 265 266 unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); 267 if (PtrAddrSpace) 268 continue; 269 270 if (L->isLoopInvariant(PtrValue)) 271 continue; 272 273 const SCEV *LSCEV = SE->getSCEV(PtrValue); 274 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV); 275 if (!LSCEVAddRec) 276 continue; 277 278 // Check if the stride of the accesses is large enough to warrant a 279 // prefetch. 280 if (!isStrideLargeEnough(LSCEVAddRec)) 281 continue; 282 283 // We don't want to double prefetch individual cache lines. If this load 284 // is known to be within one cache line of some other load that has 285 // already been prefetched, then don't prefetch this one as well. 286 bool DupPref = false; 287 for (const auto &PrefLoad : PrefLoads) { 288 const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second); 289 if (const SCEVConstant *ConstPtrDiff = 290 dyn_cast<SCEVConstant>(PtrDiff)) { 291 int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue()); 292 if (PD < (int64_t) TTI->getCacheLineSize()) { 293 DupPref = true; 294 break; 295 } 296 } 297 } 298 if (DupPref) 299 continue; 300 301 const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr( 302 SE->getConstant(LSCEVAddRec->getType(), ItersAhead), 303 LSCEVAddRec->getStepRecurrence(*SE))); 304 if (!isSafeToExpand(NextLSCEV, *SE)) 305 continue; 306 307 PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec)); 308 309 Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); 310 SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr"); 311 Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI); 312 313 IRBuilder<> Builder(MemI); 314 Module *M = BB->getParent()->getParent(); 315 Type *I32 = Type::getInt32Ty(BB->getContext()); 316 Function *PrefetchFunc = Intrinsic::getDeclaration( 317 M, Intrinsic::prefetch, PrefPtrValue->getType()); 318 Builder.CreateCall( 319 PrefetchFunc, 320 {PrefPtrValue, 321 ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1), 322 ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); 323 ++NumPrefetches; 324 LLVM_DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV 325 << "\n"); 326 ORE->emit([&]() { 327 return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) 328 << "prefetched memory access"; 329 }); 330 331 MadeChange = true; 332 } 333 } 334 335 return MadeChange; 336 } 337