xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp (revision f126890ac5386406dadf7c4cfa9566cbb56537c5)
1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass extracts the specified basic blocks from the module into their
10 // own functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/Transforms/IPO/BlockExtractor.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
26 
27 using namespace llvm;
28 
29 #define DEBUG_TYPE "block-extractor"
30 
31 STATISTIC(NumExtracted, "Number of basic blocks extracted");
32 
33 static cl::opt<std::string> BlockExtractorFile(
34     "extract-blocks-file", cl::value_desc("filename"),
35     cl::desc("A file containing list of basic blocks to extract"), cl::Hidden);
36 
37 static cl::opt<bool>
38     BlockExtractorEraseFuncs("extract-blocks-erase-funcs",
39                              cl::desc("Erase the existing functions"),
40                              cl::Hidden);
41 namespace {
42 class BlockExtractor {
43 public:
44   BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {}
45   bool runOnModule(Module &M);
46   void
47   init(const std::vector<std::vector<BasicBlock *>> &GroupsOfBlocksToExtract) {
48     GroupsOfBlocks = GroupsOfBlocksToExtract;
49     if (!BlockExtractorFile.empty())
50       loadFile();
51   }
52 
53 private:
54   std::vector<std::vector<BasicBlock *>> GroupsOfBlocks;
55   bool EraseFunctions;
56   /// Map a function name to groups of blocks.
57   SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4>
58       BlocksByName;
59 
60   void loadFile();
61   void splitLandingPadPreds(Function &F);
62 };
63 
64 } // end anonymous namespace
65 
66 /// Gets all of the blocks specified in the input file.
67 void BlockExtractor::loadFile() {
68   auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile);
69   if (ErrOrBuf.getError())
70     report_fatal_error("BlockExtractor couldn't load the file.");
71   // Read the file.
72   auto &Buf = *ErrOrBuf;
73   SmallVector<StringRef, 16> Lines;
74   Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1,
75                          /*KeepEmpty=*/false);
76   for (const auto &Line : Lines) {
77     SmallVector<StringRef, 4> LineSplit;
78     Line.split(LineSplit, ' ', /*MaxSplit=*/-1,
79                /*KeepEmpty=*/false);
80     if (LineSplit.empty())
81       continue;
82     if (LineSplit.size()!=2)
83       report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'",
84                          /*GenCrashDiag=*/false);
85     SmallVector<StringRef, 4> BBNames;
86     LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1,
87                        /*KeepEmpty=*/false);
88     if (BBNames.empty())
89       report_fatal_error("Missing bbs name");
90     BlocksByName.push_back(
91         {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}});
92   }
93 }
94 
95 /// Extracts the landing pads to make sure all of them have only one
96 /// predecessor.
97 void BlockExtractor::splitLandingPadPreds(Function &F) {
98   for (BasicBlock &BB : F) {
99     for (Instruction &I : BB) {
100       if (!isa<InvokeInst>(&I))
101         continue;
102       InvokeInst *II = cast<InvokeInst>(&I);
103       BasicBlock *Parent = II->getParent();
104       BasicBlock *LPad = II->getUnwindDest();
105 
106       // Look through the landing pad's predecessors. If one of them ends in an
107       // 'invoke', then we want to split the landing pad.
108       bool Split = false;
109       for (auto *PredBB : predecessors(LPad)) {
110         if (PredBB->isLandingPad() && PredBB != Parent &&
111             isa<InvokeInst>(Parent->getTerminator())) {
112           Split = true;
113           break;
114         }
115       }
116 
117       if (!Split)
118         continue;
119 
120       SmallVector<BasicBlock *, 2> NewBBs;
121       SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs);
122     }
123   }
124 }
125 
126 bool BlockExtractor::runOnModule(Module &M) {
127   bool Changed = false;
128 
129   // Get all the functions.
130   SmallVector<Function *, 4> Functions;
131   for (Function &F : M) {
132     splitLandingPadPreds(F);
133     Functions.push_back(&F);
134   }
135 
136   // Get all the blocks specified in the input file.
137   unsigned NextGroupIdx = GroupsOfBlocks.size();
138   GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size());
139   for (const auto &BInfo : BlocksByName) {
140     Function *F = M.getFunction(BInfo.first);
141     if (!F)
142       report_fatal_error("Invalid function name specified in the input file",
143                          /*GenCrashDiag=*/false);
144     for (const auto &BBInfo : BInfo.second) {
145       auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) {
146         return BB.getName().equals(BBInfo);
147       });
148       if (Res == F->end())
149         report_fatal_error("Invalid block name specified in the input file",
150                            /*GenCrashDiag=*/false);
151       GroupsOfBlocks[NextGroupIdx].push_back(&*Res);
152     }
153     ++NextGroupIdx;
154   }
155 
156   // Extract each group of basic blocks.
157   for (auto &BBs : GroupsOfBlocks) {
158     SmallVector<BasicBlock *, 32> BlocksToExtractVec;
159     for (BasicBlock *BB : BBs) {
160       // Check if the module contains BB.
161       if (BB->getParent()->getParent() != &M)
162         report_fatal_error("Invalid basic block", /*GenCrashDiag=*/false);
163       LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting "
164                         << BB->getParent()->getName() << ":" << BB->getName()
165                         << "\n");
166       BlocksToExtractVec.push_back(BB);
167       if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()))
168         BlocksToExtractVec.push_back(II->getUnwindDest());
169       ++NumExtracted;
170       Changed = true;
171     }
172     CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent());
173     Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC);
174     if (F)
175       LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName()
176                         << "' in: " << F->getName() << '\n');
177     else
178       LLVM_DEBUG(dbgs() << "Failed to extract for group '"
179                         << (*BBs.begin())->getName() << "'\n");
180   }
181 
182   // Erase the functions.
183   if (EraseFunctions || BlockExtractorEraseFuncs) {
184     for (Function *F : Functions) {
185       LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName()
186                         << "\n");
187       F->deleteBody();
188     }
189     // Set linkage as ExternalLinkage to avoid erasing unreachable functions.
190     for (Function &F : M)
191       F.setLinkage(GlobalValue::ExternalLinkage);
192     Changed = true;
193   }
194 
195   return Changed;
196 }
197 
198 BlockExtractorPass::BlockExtractorPass(
199     std::vector<std::vector<BasicBlock *>> &&GroupsOfBlocks,
200     bool EraseFunctions)
201     : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {}
202 
203 PreservedAnalyses BlockExtractorPass::run(Module &M,
204                                           ModuleAnalysisManager &AM) {
205   BlockExtractor BE(EraseFunctions);
206   BE.init(GroupsOfBlocks);
207   return BE.runOnModule(M) ? PreservedAnalyses::none()
208                            : PreservedAnalyses::all();
209 }
210