xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp (revision 29fc4075e69fd27de0cded313ac6000165d99f8b)
1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
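// A pass wrapper around the CodeExtractor utility that extracts each top-level
// loop into its own new function. If a function has exactly one top-level loop
// and is only a minimal wrapper around it (the entry block branches straight to
// the loop header and every loop exit returns), that loop is left in place,
// although its sub-loops may still be extracted. This pass is most useful for
// debugging via bugpoint.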
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Transforms/IPO/LoopExtractor.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/IR/Dominators.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/PassManager.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Transforms/IPO.h"
28 #include "llvm/Transforms/Scalar.h"
29 #include "llvm/Transforms/Utils.h"
30 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
31 #include "llvm/Transforms/Utils/CodeExtractor.h"
32 #include <fstream>
33 #include <set>
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "loop-extract"
37 
38 STATISTIC(NumExtracted, "Number of loops extracted");
39 
40 namespace {
41 struct LoopExtractorLegacyPass : public ModulePass {
42   static char ID; // Pass identification, replacement for typeid
43 
44   unsigned NumLoops;
45 
46   explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0)
47       : ModulePass(ID), NumLoops(NumLoops) {
48     initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry());
49   }
50 
51   bool runOnModule(Module &M) override;
52 
53   void getAnalysisUsage(AnalysisUsage &AU) const override {
54     AU.addRequiredID(BreakCriticalEdgesID);
55     AU.addRequired<DominatorTreeWrapperPass>();
56     AU.addRequired<LoopInfoWrapperPass>();
57     AU.addPreserved<LoopInfoWrapperPass>();
58     AU.addRequiredID(LoopSimplifyID);
59     AU.addUsedIfAvailable<AssumptionCacheTracker>();
60   }
61 };
62 
63 struct LoopExtractor {
64   explicit LoopExtractor(
65       unsigned NumLoops,
66       function_ref<DominatorTree &(Function &)> LookupDomTree,
67       function_ref<LoopInfo &(Function &)> LookupLoopInfo,
68       function_ref<AssumptionCache *(Function &)> LookupAssumptionCache)
69       : NumLoops(NumLoops), LookupDomTree(LookupDomTree),
70         LookupLoopInfo(LookupLoopInfo),
71         LookupAssumptionCache(LookupAssumptionCache) {}
72   bool runOnModule(Module &M);
73 
74 private:
75   // The number of natural loops to extract from the program into functions.
76   unsigned NumLoops;
77 
78   function_ref<DominatorTree &(Function &)> LookupDomTree;
79   function_ref<LoopInfo &(Function &)> LookupLoopInfo;
80   function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
81 
82   bool runOnFunction(Function &F);
83 
84   bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI,
85                     DominatorTree &DT);
86   bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT);
87 };
88 } // namespace
89 
90 char LoopExtractorLegacyPass::ID = 0;
91 INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract",
92                       "Extract loops into new functions", false, false)
93 INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
94 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
95 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
96 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
97 INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract",
98                     "Extract loops into new functions", false, false)
99 
100 namespace {
101   /// SingleLoopExtractor - For bugpoint.
102 struct SingleLoopExtractor : public LoopExtractorLegacyPass {
103   static char ID; // Pass identification, replacement for typeid
104   SingleLoopExtractor() : LoopExtractorLegacyPass(1) {}
105 };
106 } // End anonymous namespace
107 
108 char SingleLoopExtractor::ID = 0;
109 INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
110                 "Extract at most one loop into a new function", false, false)
111 
112 // createLoopExtractorPass - This pass extracts all natural loops from the
113 // program into a function if it can.
114 //
115 Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); }
116 
117 bool LoopExtractorLegacyPass::runOnModule(Module &M) {
118   if (skipModule(M))
119     return false;
120 
121   bool Changed = false;
122   auto LookupDomTree = [this](Function &F) -> DominatorTree & {
123     return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
124   };
125   auto LookupLoopInfo = [this, &Changed](Function &F) -> LoopInfo & {
126     return this->getAnalysis<LoopInfoWrapperPass>(F, &Changed).getLoopInfo();
127   };
128   auto LookupACT = [this](Function &F) -> AssumptionCache * {
129     if (auto *ACT = this->getAnalysisIfAvailable<AssumptionCacheTracker>())
130       return ACT->lookupAssumptionCache(F);
131     return nullptr;
132   };
133   return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT)
134              .runOnModule(M) ||
135          Changed;
136 }
137 
138 bool LoopExtractor::runOnModule(Module &M) {
139   if (M.empty())
140     return false;
141 
142   if (!NumLoops)
143     return false;
144 
145   bool Changed = false;
146 
147   // The end of the function list may change (new functions will be added at the
148   // end), so we run from the first to the current last.
149   auto I = M.begin(), E = --M.end();
150   while (true) {
151     Function &F = *I;
152 
153     Changed |= runOnFunction(F);
154     if (!NumLoops)
155       break;
156 
157     // If this is the last function.
158     if (I == E)
159       break;
160 
161     ++I;
162   }
163   return Changed;
164 }
165 
166 bool LoopExtractor::runOnFunction(Function &F) {
167   // Do not modify `optnone` functions.
168   if (F.hasOptNone())
169     return false;
170 
171   if (F.empty())
172     return false;
173 
174   bool Changed = false;
175   LoopInfo &LI = LookupLoopInfo(F);
176 
177   // If there are no loops in the function.
178   if (LI.empty())
179     return Changed;
180 
181   DominatorTree &DT = LookupDomTree(F);
182 
183   // If there is more than one top-level loop in this function, extract all of
184   // the loops.
185   if (std::next(LI.begin()) != LI.end())
186     return Changed | extractLoops(LI.begin(), LI.end(), LI, DT);
187 
188   // Otherwise there is exactly one top-level loop.
189   Loop *TLL = *LI.begin();
190 
191   // If the loop is in LoopSimplify form, then extract it only if this function
192   // is more than a minimal wrapper around the loop.
193   if (TLL->isLoopSimplifyForm()) {
194     bool ShouldExtractLoop = false;
195 
196     // Extract the loop if the entry block doesn't branch to the loop header.
197     Instruction *EntryTI = F.getEntryBlock().getTerminator();
198     if (!isa<BranchInst>(EntryTI) ||
199         !cast<BranchInst>(EntryTI)->isUnconditional() ||
200         EntryTI->getSuccessor(0) != TLL->getHeader()) {
201       ShouldExtractLoop = true;
202     } else {
203       // Check to see if any exits from the loop are more than just return
204       // blocks.
205       SmallVector<BasicBlock *, 8> ExitBlocks;
206       TLL->getExitBlocks(ExitBlocks);
207       for (auto *ExitBlock : ExitBlocks)
208         if (!isa<ReturnInst>(ExitBlock->getTerminator())) {
209           ShouldExtractLoop = true;
210           break;
211         }
212     }
213 
214     if (ShouldExtractLoop)
215       return Changed | extractLoop(TLL, LI, DT);
216   }
217 
218   // Okay, this function is a minimal container around the specified loop.
219   // If we extract the loop, we will continue to just keep extracting it
220   // infinitely... so don't extract it. However, if the loop contains any
221   // sub-loops, extract them.
222   return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT);
223 }
224 
225 bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
226                                  LoopInfo &LI, DominatorTree &DT) {
227   bool Changed = false;
228   SmallVector<Loop *, 8> Loops;
229 
230   // Save the list of loops, as it may change.
231   Loops.assign(From, To);
232   for (Loop *L : Loops) {
233     // If LoopSimplify form is not available, stay out of trouble.
234     if (!L->isLoopSimplifyForm())
235       continue;
236 
237     Changed |= extractLoop(L, LI, DT);
238     if (!NumLoops)
239       break;
240   }
241   return Changed;
242 }
243 
244 bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
245   assert(NumLoops != 0);
246   Function &Func = *L->getHeader()->getParent();
247   AssumptionCache *AC = LookupAssumptionCache(Func);
248   CodeExtractorAnalysisCache CEAC(Func);
249   CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
250   if (Extractor.extractCodeRegion(CEAC)) {
251     LI.erase(L);
252     --NumLoops;
253     ++NumExtracted;
254     return true;
255   }
256   return false;
257 }
258 
259 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
260 // program into a function if it can.  This is used by bugpoint.
261 //
262 Pass *llvm::createSingleLoopExtractorPass() {
263   return new SingleLoopExtractor();
264 }
265 
266 PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
267   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
268   auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
269     return FAM.getResult<DominatorTreeAnalysis>(F);
270   };
271   auto LookupLoopInfo = [&FAM](Function &F) -> LoopInfo & {
272     return FAM.getResult<LoopAnalysis>(F);
273   };
274   auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
275     return FAM.getCachedResult<AssumptionAnalysis>(F);
276   };
277   if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo,
278                      LookupAssumptionCache)
279            .runOnModule(M))
280     return PreservedAnalyses::all();
281 
282   PreservedAnalyses PA;
283   PA.preserve<LoopAnalysis>();
284   return PA;
285 }
286 
287 void LoopExtractorPass::printPipeline(
288     raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
289   static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline(
290       OS, MapClassName2PassName);
291   OS << "<";
292   if (NumLoops == 1)
293     OS << "single";
294   OS << ">";
295 }
296