1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A pass wrapper around the ExtractLoop() scalar transformation to extract each 10 // top-level loop into its own new function. If the loop is the ONLY loop in a 11 // given function, it is not touched. This is a pass most useful for debugging 12 // via bugpoint. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/IPO/LoopExtractor.h" 17 #include "llvm/ADT/Statistic.h" 18 #include "llvm/Analysis/AssumptionCache.h" 19 #include "llvm/Analysis/LoopInfo.h" 20 #include "llvm/IR/Dominators.h" 21 #include "llvm/IR/Instructions.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/PassManager.h" 24 #include "llvm/InitializePasses.h" 25 #include "llvm/Pass.h" 26 #include "llvm/Transforms/IPO.h" 27 #include "llvm/Transforms/Utils.h" 28 #include "llvm/Transforms/Utils/CodeExtractor.h" 29 using namespace llvm; 30 31 #define DEBUG_TYPE "loop-extract" 32 33 STATISTIC(NumExtracted, "Number of loops extracted"); 34 35 namespace { 36 struct LoopExtractorLegacyPass : public ModulePass { 37 static char ID; // Pass identification, replacement for typeid 38 39 unsigned NumLoops; 40 41 explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0) 42 : ModulePass(ID), NumLoops(NumLoops) { 43 initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry()); 44 } 45 46 bool runOnModule(Module &M) override; 47 48 void getAnalysisUsage(AnalysisUsage &AU) const override { 49 AU.addRequiredID(BreakCriticalEdgesID); 50 AU.addRequired<DominatorTreeWrapperPass>(); 51 AU.addRequired<LoopInfoWrapperPass>(); 52 AU.addPreserved<LoopInfoWrapperPass>(); 53 AU.addRequiredID(LoopSimplifyID); 54 AU.addUsedIfAvailable<AssumptionCacheTracker>(); 55 } 56 }; 57 58 struct LoopExtractor { 59 explicit LoopExtractor( 60 unsigned NumLoops, 61 function_ref<DominatorTree &(Function &)> LookupDomTree, 62 function_ref<LoopInfo &(Function &)> LookupLoopInfo, 63 function_ref<AssumptionCache *(Function &)> LookupAssumptionCache) 64 : NumLoops(NumLoops), LookupDomTree(LookupDomTree), 65 LookupLoopInfo(LookupLoopInfo), 66 LookupAssumptionCache(LookupAssumptionCache) {} 67 bool runOnModule(Module &M); 68 69 private: 70 // The number of natural loops to extract from the program into functions. 71 unsigned NumLoops; 72 73 function_ref<DominatorTree &(Function &)> LookupDomTree; 74 function_ref<LoopInfo &(Function &)> LookupLoopInfo; 75 function_ref<AssumptionCache *(Function &)> LookupAssumptionCache; 76 77 bool runOnFunction(Function &F); 78 79 bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI, 80 DominatorTree &DT); 81 bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT); 82 }; 83 } // namespace 84 85 char LoopExtractorLegacyPass::ID = 0; 86 INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract", 87 "Extract loops into new functions", false, false) 88 INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges) 89 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 90 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 91 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 92 INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract", 93 "Extract loops into new functions", false, false) 94 95 namespace { 96 /// SingleLoopExtractor - For bugpoint. 97 struct SingleLoopExtractor : public LoopExtractorLegacyPass { 98 static char ID; // Pass identification, replacement for typeid 99 SingleLoopExtractor() : LoopExtractorLegacyPass(1) {} 100 }; 101 } // End anonymous namespace 102 103 char SingleLoopExtractor::ID = 0; 104 INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", 105 "Extract at most one loop into a new function", false, false) 106 107 // createLoopExtractorPass - This pass extracts all natural loops from the 108 // program into a function if it can. 109 // 110 Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); } 111 112 bool LoopExtractorLegacyPass::runOnModule(Module &M) { 113 if (skipModule(M)) 114 return false; 115 116 bool Changed = false; 117 auto LookupDomTree = [this](Function &F) -> DominatorTree & { 118 return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); 119 }; 120 auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & { 121 return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo(); 122 }; 123 auto LookupACT = [this](Function &F) -> AssumptionCache * { 124 if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>()) 125 return ACT->lookupAssumptionCache(F); 126 return nullptr; 127 }; 128 return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT) 129 .runOnModule(M) || 130 Changed; 131 } 132 133 bool LoopExtractor::runOnModule(Module &M) { 134 if (M.empty()) 135 return false; 136 137 if (!NumLoops) 138 return false; 139 140 bool Changed = false; 141 142 // The end of the function list may change (new functions will be added at the 143 // end), so we run from the first to the current last. 144 auto I = M.begin(), E = --M.end(); 145 while (true) { 146 Function &F = *I; 147 148 Changed |= runOnFunction(F); 149 if (!NumLoops) 150 break; 151 152 // If this is the last function. 153 if (I == E) 154 break; 155 156 ++I; 157 } 158 return Changed; 159 } 160 161 bool LoopExtractor::runOnFunction(Function &F) { 162 // Do not modify `optnone` functions. 163 if (F.hasOptNone()) 164 return false; 165 166 if (F.empty()) 167 return false; 168 169 bool Changed = false; 170 LoopInfo &LI = LookupLoopInfo(F); 171 172 // If there are no loops in the function. 173 if (LI.empty()) 174 return Changed; 175 176 DominatorTree &DT = LookupDomTree(F); 177 178 // If there is more than one top-level loop in this function, extract all of 179 // the loops. 180 if (std::next(LI.begin()) != LI.end()) 181 return Changed | extractLoops(LI.begin(), LI.end(), LI, DT); 182 183 // Otherwise there is exactly one top-level loop. 184 Loop *TLL = *LI.begin(); 185 186 // If the loop is in LoopSimplify form, then extract it only if this function 187 // is more than a minimal wrapper around the loop. 188 if (TLL->isLoopSimplifyForm()) { 189 bool ShouldExtractLoop = false; 190 191 // Extract the loop if the entry block doesn't branch to the loop header. 192 Instruction *EntryTI = F.getEntryBlock().getTerminator(); 193 if (!isa<BranchInst>(EntryTI) || 194 !cast<BranchInst>(EntryTI)->isUnconditional() || 195 EntryTI->getSuccessor(0) != TLL->getHeader()) { 196 ShouldExtractLoop = true; 197 } else { 198 // Check to see if any exits from the loop are more than just return 199 // blocks. 200 SmallVector<BasicBlock *, 8> ExitBlocks; 201 TLL->getExitBlocks(ExitBlocks); 202 for (auto *ExitBlock : ExitBlocks) 203 if (!isa<ReturnInst>(ExitBlock->getTerminator())) { 204 ShouldExtractLoop = true; 205 break; 206 } 207 } 208 209 if (ShouldExtractLoop) 210 return Changed | extractLoop(TLL, LI, DT); 211 } 212 213 // Okay, this function is a minimal container around the specified loop. 214 // If we extract the loop, we will continue to just keep extracting it 215 // infinitely... so don't extract it. However, if the loop contains any 216 // sub-loops, extract them. 217 return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT); 218 } 219 220 bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To, 221 LoopInfo &LI, DominatorTree &DT) { 222 bool Changed = false; 223 SmallVector<Loop *, 8> Loops; 224 225 // Save the list of loops, as it may change. 226 Loops.assign(From, To); 227 for (Loop *L : Loops) { 228 // If LoopSimplify form is not available, stay out of trouble. 229 if (!L->isLoopSimplifyForm()) 230 continue; 231 232 Changed |= extractLoop(L, LI, DT); 233 if (!NumLoops) 234 break; 235 } 236 return Changed; 237 } 238 239 bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) { 240 assert(NumLoops != 0); 241 Function &Func = *L->getHeader()->getParent(); 242 AssumptionCache *AC = LookupAssumptionCache(Func); 243 CodeExtractorAnalysisCache CEAC(Func); 244 CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); 245 if (Extractor.extractCodeRegion(CEAC)) { 246 LI.erase(L); 247 --NumLoops; 248 ++NumExtracted; 249 return true; 250 } 251 return false; 252 } 253 254 // createSingleLoopExtractorPass - This pass extracts one natural loop from the 255 // program into a function if it can. This is used by bugpoint. 256 // 257 Pass *llvm::createSingleLoopExtractorPass() { 258 return new SingleLoopExtractor(); 259 } 260 261 PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) { 262 auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 263 auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { 264 return FAM.getResult<DominatorTreeAnalysis>(F); 265 }; 266 auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & { 267 return FAM.getResult<LoopAnalysis>(F); 268 }; 269 auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * { 270 return FAM.getCachedResult<AssumptionAnalysis>(F); 271 }; 272 if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, 273 LookupAssumptionCache) 274 .runOnModule(M)) 275 return PreservedAnalyses::all(); 276 277 PreservedAnalyses PA; 278 PA.preserve<LoopAnalysis>(); 279 return PA; 280 } 281 282 void LoopExtractorPass::printPipeline( 283 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 284 static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline( 285 OS, MapClassName2PassName); 286 OS << '<'; 287 if (NumLoops == 1) 288 OS << "single"; 289 OS << '>'; 290 } 291