1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass extracts the specified basic blocks from the module into their 10 // own functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Transforms/IPO/BlockExtractor.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/Statistic.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/IR/Module.h" 19 #include "llvm/IR/PassManager.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/MemoryBuffer.h" 23 #include "llvm/Transforms/IPO.h" 24 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 25 #include "llvm/Transforms/Utils/CodeExtractor.h" 26 27 using namespace llvm; 28 29 #define DEBUG_TYPE "block-extractor" 30 31 STATISTIC(NumExtracted, "Number of basic blocks extracted"); 32 33 static cl::opt<std::string> BlockExtractorFile( 34 "extract-blocks-file", cl::value_desc("filename"), 35 cl::desc("A file containing list of basic blocks to extract"), cl::Hidden); 36 37 static cl::opt<bool> 38 BlockExtractorEraseFuncs("extract-blocks-erase-funcs", 39 cl::desc("Erase the existing functions"), 40 cl::Hidden); 41 namespace { 42 class BlockExtractor { 43 public: 44 BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {} 45 bool runOnModule(Module &M); 46 void 47 init(const std::vector<std::vector<BasicBlock *>> &GroupsOfBlocksToExtract) { 48 GroupsOfBlocks = GroupsOfBlocksToExtract; 49 if (!BlockExtractorFile.empty()) 50 loadFile(); 51 } 52 53 private: 54 std::vector<std::vector<BasicBlock *>> GroupsOfBlocks; 55 bool EraseFunctions; 56 /// Map a function name to groups of blocks. 57 SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> 58 BlocksByName; 59 60 void loadFile(); 61 void splitLandingPadPreds(Function &F); 62 }; 63 64 } // end anonymous namespace 65 66 /// Gets all of the blocks specified in the input file. 67 void BlockExtractor::loadFile() { 68 auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile); 69 if (ErrOrBuf.getError()) 70 report_fatal_error("BlockExtractor couldn't load the file."); 71 // Read the file. 72 auto &Buf = *ErrOrBuf; 73 SmallVector<StringRef, 16> Lines; 74 Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, 75 /*KeepEmpty=*/false); 76 for (const auto &Line : Lines) { 77 SmallVector<StringRef, 4> LineSplit; 78 Line.split(LineSplit, ' ', /*MaxSplit=*/-1, 79 /*KeepEmpty=*/false); 80 if (LineSplit.empty()) 81 continue; 82 if (LineSplit.size()!=2) 83 reportFatalUsageError( 84 "Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'"); 85 SmallVector<StringRef, 4> BBNames; 86 LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, 87 /*KeepEmpty=*/false); 88 if (BBNames.empty()) 89 report_fatal_error("Missing bbs name"); 90 BlocksByName.push_back( 91 {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}}); 92 } 93 } 94 95 /// Extracts the landing pads to make sure all of them have only one 96 /// predecessor. 97 void BlockExtractor::splitLandingPadPreds(Function &F) { 98 for (BasicBlock &BB : F) { 99 for (Instruction &I : BB) { 100 if (!isa<InvokeInst>(&I)) 101 continue; 102 InvokeInst *II = cast<InvokeInst>(&I); 103 BasicBlock *Parent = II->getParent(); 104 BasicBlock *LPad = II->getUnwindDest(); 105 106 // Look through the landing pad's predecessors. If one of them ends in an 107 // 'invoke', then we want to split the landing pad. 108 bool Split = false; 109 for (auto *PredBB : predecessors(LPad)) { 110 if (PredBB->isLandingPad() && PredBB != Parent && 111 isa<InvokeInst>(Parent->getTerminator())) { 112 Split = true; 113 break; 114 } 115 } 116 117 if (!Split) 118 continue; 119 120 SmallVector<BasicBlock *, 2> NewBBs; 121 SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs); 122 } 123 } 124 } 125 126 bool BlockExtractor::runOnModule(Module &M) { 127 bool Changed = false; 128 129 // Get all the functions. 130 SmallVector<Function *, 4> Functions; 131 for (Function &F : M) { 132 splitLandingPadPreds(F); 133 Functions.push_back(&F); 134 } 135 136 // Get all the blocks specified in the input file. 137 unsigned NextGroupIdx = GroupsOfBlocks.size(); 138 GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size()); 139 for (const auto &BInfo : BlocksByName) { 140 Function *F = M.getFunction(BInfo.first); 141 if (!F) 142 reportFatalUsageError( 143 "Invalid function name specified in the input file"); 144 for (const auto &BBInfo : BInfo.second) { 145 auto Res = llvm::find_if( 146 *F, [&](const BasicBlock &BB) { return BB.getName() == BBInfo; }); 147 if (Res == F->end()) 148 reportFatalUsageError("Invalid block name specified in the input file"); 149 GroupsOfBlocks[NextGroupIdx].push_back(&*Res); 150 } 151 ++NextGroupIdx; 152 } 153 154 // Extract each group of basic blocks. 155 for (auto &BBs : GroupsOfBlocks) { 156 SmallVector<BasicBlock *, 32> BlocksToExtractVec; 157 for (BasicBlock *BB : BBs) { 158 // Check if the module contains BB. 159 if (BB->getParent()->getParent() != &M) 160 reportFatalUsageError("Invalid basic block"); 161 LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " 162 << BB->getParent()->getName() << ":" << BB->getName() 163 << "\n"); 164 BlocksToExtractVec.push_back(BB); 165 if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) 166 BlocksToExtractVec.push_back(II->getUnwindDest()); 167 ++NumExtracted; 168 Changed = true; 169 } 170 CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent()); 171 Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC); 172 if (F) 173 LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() 174 << "' in: " << F->getName() << '\n'); 175 else 176 LLVM_DEBUG(dbgs() << "Failed to extract for group '" 177 << (*BBs.begin())->getName() << "'\n"); 178 } 179 180 // Erase the functions. 181 if (EraseFunctions || BlockExtractorEraseFuncs) { 182 for (Function *F : Functions) { 183 LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName() 184 << "\n"); 185 F->deleteBody(); 186 } 187 // Set linkage as ExternalLinkage to avoid erasing unreachable functions. 188 for (Function &F : M) 189 F.setLinkage(GlobalValue::ExternalLinkage); 190 Changed = true; 191 } 192 193 return Changed; 194 } 195 196 BlockExtractorPass::BlockExtractorPass( 197 std::vector<std::vector<BasicBlock *>> &&GroupsOfBlocks, 198 bool EraseFunctions) 199 : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {} 200 201 PreservedAnalyses BlockExtractorPass::run(Module &M, 202 ModuleAnalysisManager &AM) { 203 BlockExtractor BE(EraseFunctions); 204 BE.init(GroupsOfBlocks); 205 return BE.runOnModule(M) ? PreservedAnalyses::none() 206 : PreservedAnalyses::all(); 207 } 208