1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A pass wrapper around the ExtractLoop() scalar transformation to extract each 10 // top-level loop into its own new function. If the loop is the ONLY loop in a 11 // given function, it is not touched. This is a pass most useful for debugging 12 // via bugpoint. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/IPO/LoopExtractor.h" 17 #include "llvm/ADT/Statistic.h" 18 #include "llvm/Analysis/AssumptionCache.h" 19 #include "llvm/Analysis/LoopInfo.h" 20 #include "llvm/IR/Dominators.h" 21 #include "llvm/IR/Instructions.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/PassManager.h" 24 #include "llvm/InitializePasses.h" 25 #include "llvm/Pass.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Transforms/IPO.h" 28 #include "llvm/Transforms/Scalar.h" 29 #include "llvm/Transforms/Utils.h" 30 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 31 #include "llvm/Transforms/Utils/CodeExtractor.h" 32 #include <fstream> 33 #include <set> 34 using namespace llvm; 35 36 #define DEBUG_TYPE "loop-extract" 37 38 STATISTIC(NumExtracted, "Number of loops extracted"); 39 40 namespace { 41 struct LoopExtractorLegacyPass : public ModulePass { 42 static char ID; // Pass identification, replacement for typeid 43 44 unsigned NumLoops; 45 46 explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0) 47 : ModulePass(ID), NumLoops(NumLoops) { 48 initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry()); 49 } 50 51 bool runOnModule(Module &M) override; 52 53 void getAnalysisUsage(AnalysisUsage &AU) const override { 54 AU.addRequiredID(BreakCriticalEdgesID); 55 AU.addRequired<DominatorTreeWrapperPass>(); 56 AU.addRequired<LoopInfoWrapperPass>(); 57 AU.addPreserved<LoopInfoWrapperPass>(); 58 AU.addRequiredID(LoopSimplifyID); 59 AU.addUsedIfAvailable<AssumptionCacheTracker>(); 60 } 61 }; 62 63 struct LoopExtractor { 64 explicit LoopExtractor( 65 unsigned NumLoops, 66 function_ref<DominatorTree &(Function &)> LookupDomTree, 67 function_ref<LoopInfo &(Function &)> LookupLoopInfo, 68 function_ref<AssumptionCache *(Function &)> LookupAssumptionCache) 69 : NumLoops(NumLoops), LookupDomTree(LookupDomTree), 70 LookupLoopInfo(LookupLoopInfo), 71 LookupAssumptionCache(LookupAssumptionCache) {} 72 bool runOnModule(Module &M); 73 74 private: 75 // The number of natural loops to extract from the program into functions. 76 unsigned NumLoops; 77 78 function_ref<DominatorTree &(Function &)> LookupDomTree; 79 function_ref<LoopInfo &(Function &)> LookupLoopInfo; 80 function_ref<AssumptionCache *(Function &)> LookupAssumptionCache; 81 82 bool runOnFunction(Function &F); 83 84 bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI, 85 DominatorTree &DT); 86 bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT); 87 }; 88 } // namespace 89 90 char LoopExtractorLegacyPass::ID = 0; 91 INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract", 92 "Extract loops into new functions", false, false) 93 INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges) 94 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 95 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 96 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 97 INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract", 98 "Extract loops into new functions", false, false) 99 100 namespace { 101 /// SingleLoopExtractor - For bugpoint. 102 struct SingleLoopExtractor : public LoopExtractorLegacyPass { 103 static char ID; // Pass identification, replacement for typeid 104 SingleLoopExtractor() : LoopExtractorLegacyPass(1) {} 105 }; 106 } // End anonymous namespace 107 108 char SingleLoopExtractor::ID = 0; 109 INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", 110 "Extract at most one loop into a new function", false, false) 111 112 // createLoopExtractorPass - This pass extracts all natural loops from the 113 // program into a function if it can. 114 // 115 Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); } 116 117 bool LoopExtractorLegacyPass::runOnModule(Module &M) { 118 if (skipModule(M)) 119 return false; 120 121 bool Changed = false; 122 auto LookupDomTree = [this](Function &F) -> DominatorTree & { 123 return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); 124 }; 125 auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & { 126 return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo(); 127 }; 128 auto LookupACT = [this](Function &F) -> AssumptionCache * { 129 if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>()) 130 return ACT->lookupAssumptionCache(F); 131 return nullptr; 132 }; 133 return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT) 134 .runOnModule(M) || 135 Changed; 136 } 137 138 bool LoopExtractor::runOnModule(Module &M) { 139 if (M.empty()) 140 return false; 141 142 if (!NumLoops) 143 return false; 144 145 bool Changed = false; 146 147 // The end of the function list may change (new functions will be added at the 148 // end), so we run from the first to the current last. 149 auto I = M.begin(), E = --M.end(); 150 while (true) { 151 Function &F = *I; 152 153 Changed |= runOnFunction(F); 154 if (!NumLoops) 155 break; 156 157 // If this is the last function. 158 if (I == E) 159 break; 160 161 ++I; 162 } 163 return Changed; 164 } 165 166 bool LoopExtractor::runOnFunction(Function &F) { 167 // Do not modify `optnone` functions. 168 if (F.hasOptNone()) 169 return false; 170 171 if (F.empty()) 172 return false; 173 174 bool Changed = false; 175 LoopInfo &LI = LookupLoopInfo(F); 176 177 // If there are no loops in the function. 178 if (LI.empty()) 179 return Changed; 180 181 DominatorTree &DT = LookupDomTree(F); 182 183 // If there is more than one top-level loop in this function, extract all of 184 // the loops. 185 if (std::next(LI.begin()) != LI.end()) 186 return Changed | extractLoops(LI.begin(), LI.end(), LI, DT); 187 188 // Otherwise there is exactly one top-level loop. 189 Loop *TLL = *LI.begin(); 190 191 // If the loop is in LoopSimplify form, then extract it only if this function 192 // is more than a minimal wrapper around the loop. 193 if (TLL->isLoopSimplifyForm()) { 194 bool ShouldExtractLoop = false; 195 196 // Extract the loop if the entry block doesn't branch to the loop header. 197 Instruction *EntryTI = F.getEntryBlock().getTerminator(); 198 if (!isa<BranchInst>(EntryTI) || 199 !cast<BranchInst>(EntryTI)->isUnconditional() || 200 EntryTI->getSuccessor(0) != TLL->getHeader()) { 201 ShouldExtractLoop = true; 202 } else { 203 // Check to see if any exits from the loop are more than just return 204 // blocks. 205 SmallVector<BasicBlock *, 8> ExitBlocks; 206 TLL->getExitBlocks(ExitBlocks); 207 for (auto *ExitBlock : ExitBlocks) 208 if (!isa<ReturnInst>(ExitBlock->getTerminator())) { 209 ShouldExtractLoop = true; 210 break; 211 } 212 } 213 214 if (ShouldExtractLoop) 215 return Changed | extractLoop(TLL, LI, DT); 216 } 217 218 // Okay, this function is a minimal container around the specified loop. 219 // If we extract the loop, we will continue to just keep extracting it 220 // infinitely... so don't extract it. However, if the loop contains any 221 // sub-loops, extract them. 222 return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT); 223 } 224 225 bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To, 226 LoopInfo &LI, DominatorTree &DT) { 227 bool Changed = false; 228 SmallVector<Loop *, 8> Loops; 229 230 // Save the list of loops, as it may change. 231 Loops.assign(From, To); 232 for (Loop *L : Loops) { 233 // If LoopSimplify form is not available, stay out of trouble. 234 if (!L->isLoopSimplifyForm()) 235 continue; 236 237 Changed |= extractLoop(L, LI, DT); 238 if (!NumLoops) 239 break; 240 } 241 return Changed; 242 } 243 244 bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) { 245 assert(NumLoops != 0); 246 Function &Func = *L->getHeader()->getParent(); 247 AssumptionCache *AC = LookupAssumptionCache(Func); 248 CodeExtractorAnalysisCache CEAC(Func); 249 CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); 250 if (Extractor.extractCodeRegion(CEAC)) { 251 LI.erase(L); 252 --NumLoops; 253 ++NumExtracted; 254 return true; 255 } 256 return false; 257 } 258 259 // createSingleLoopExtractorPass - This pass extracts one natural loop from the 260 // program into a function if it can. This is used by bugpoint. 261 // 262 Pass *llvm::createSingleLoopExtractorPass() { 263 return new SingleLoopExtractor(); 264 } 265 266 PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) { 267 auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 268 auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { 269 return FAM.getResult<DominatorTreeAnalysis>(F); 270 }; 271 auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & { 272 return FAM.getResult<LoopAnalysis>(F); 273 }; 274 auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * { 275 return FAM.getCachedResult<AssumptionAnalysis>(F); 276 }; 277 if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, 278 LookupAssumptionCache) 279 .runOnModule(M)) 280 return PreservedAnalyses::all(); 281 282 PreservedAnalyses PA; 283 PA.preserve<LoopAnalysis>(); 284 return PA; 285 } 286 287 void LoopExtractorPass::printPipeline( 288 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 289 static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline( 290 OS, MapClassName2PassName); 291 OS << "<"; 292 if (NumLoops == 1) 293 OS << "single"; 294 OS << ">"; 295 } 296