1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass extracts the specified basic blocks from the module into their 10 // own functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Transforms/IPO/BlockExtractor.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/Statistic.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/IR/Module.h" 19 #include "llvm/IR/PassManager.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/MemoryBuffer.h" 23 #include "llvm/Transforms/IPO.h" 24 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 25 #include "llvm/Transforms/Utils/CodeExtractor.h" 26 27 using namespace llvm; 28 29 #define DEBUG_TYPE "block-extractor" 30 31 STATISTIC(NumExtracted, "Number of basic blocks extracted"); 32 33 static cl::opt<std::string> BlockExtractorFile( 34 "extract-blocks-file", cl::value_desc("filename"), 35 cl::desc("A file containing list of basic blocks to extract"), cl::Hidden); 36 37 static cl::opt<bool> 38 BlockExtractorEraseFuncs("extract-blocks-erase-funcs", 39 cl::desc("Erase the existing functions"), 40 cl::Hidden); 41 namespace { 42 class BlockExtractor { 43 public: 44 BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {} 45 bool runOnModule(Module &M); 46 void 47 init(const std::vector<std::vector<BasicBlock *>> &GroupsOfBlocksToExtract) { 48 GroupsOfBlocks = GroupsOfBlocksToExtract; 49 if (!BlockExtractorFile.empty()) 50 loadFile(); 51 } 52 53 private: 54 std::vector<std::vector<BasicBlock *>> GroupsOfBlocks; 55 bool EraseFunctions; 56 /// Map a function name to groups of blocks. 57 SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> 58 BlocksByName; 59 60 void loadFile(); 61 void splitLandingPadPreds(Function &F); 62 }; 63 64 } // end anonymous namespace 65 66 /// Gets all of the blocks specified in the input file. 67 void BlockExtractor::loadFile() { 68 auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile); 69 if (ErrOrBuf.getError()) 70 report_fatal_error("BlockExtractor couldn't load the file."); 71 // Read the file. 72 auto &Buf = *ErrOrBuf; 73 SmallVector<StringRef, 16> Lines; 74 Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, 75 /*KeepEmpty=*/false); 76 for (const auto &Line : Lines) { 77 SmallVector<StringRef, 4> LineSplit; 78 Line.split(LineSplit, ' ', /*MaxSplit=*/-1, 79 /*KeepEmpty=*/false); 80 if (LineSplit.empty()) 81 continue; 82 if (LineSplit.size()!=2) 83 report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'", 84 /*GenCrashDiag=*/false); 85 SmallVector<StringRef, 4> BBNames; 86 LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, 87 /*KeepEmpty=*/false); 88 if (BBNames.empty()) 89 report_fatal_error("Missing bbs name"); 90 BlocksByName.push_back( 91 {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}}); 92 } 93 } 94 95 /// Extracts the landing pads to make sure all of them have only one 96 /// predecessor. 97 void BlockExtractor::splitLandingPadPreds(Function &F) { 98 for (BasicBlock &BB : F) { 99 for (Instruction &I : BB) { 100 if (!isa<InvokeInst>(&I)) 101 continue; 102 InvokeInst *II = cast<InvokeInst>(&I); 103 BasicBlock *Parent = II->getParent(); 104 BasicBlock *LPad = II->getUnwindDest(); 105 106 // Look through the landing pad's predecessors. If one of them ends in an 107 // 'invoke', then we want to split the landing pad. 108 bool Split = false; 109 for (auto *PredBB : predecessors(LPad)) { 110 if (PredBB->isLandingPad() && PredBB != Parent && 111 isa<InvokeInst>(Parent->getTerminator())) { 112 Split = true; 113 break; 114 } 115 } 116 117 if (!Split) 118 continue; 119 120 SmallVector<BasicBlock *, 2> NewBBs; 121 SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs); 122 } 123 } 124 } 125 126 bool BlockExtractor::runOnModule(Module &M) { 127 bool Changed = false; 128 129 // Get all the functions. 130 SmallVector<Function *, 4> Functions; 131 for (Function &F : M) { 132 splitLandingPadPreds(F); 133 Functions.push_back(&F); 134 } 135 136 // Get all the blocks specified in the input file. 137 unsigned NextGroupIdx = GroupsOfBlocks.size(); 138 GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size()); 139 for (const auto &BInfo : BlocksByName) { 140 Function *F = M.getFunction(BInfo.first); 141 if (!F) 142 report_fatal_error("Invalid function name specified in the input file", 143 /*GenCrashDiag=*/false); 144 for (const auto &BBInfo : BInfo.second) { 145 auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { 146 return BB.getName().equals(BBInfo); 147 }); 148 if (Res == F->end()) 149 report_fatal_error("Invalid block name specified in the input file", 150 /*GenCrashDiag=*/false); 151 GroupsOfBlocks[NextGroupIdx].push_back(&*Res); 152 } 153 ++NextGroupIdx; 154 } 155 156 // Extract each group of basic blocks. 157 for (auto &BBs : GroupsOfBlocks) { 158 SmallVector<BasicBlock *, 32> BlocksToExtractVec; 159 for (BasicBlock *BB : BBs) { 160 // Check if the module contains BB. 161 if (BB->getParent()->getParent() != &M) 162 report_fatal_error("Invalid basic block", /*GenCrashDiag=*/false); 163 LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " 164 << BB->getParent()->getName() << ":" << BB->getName() 165 << "\n"); 166 BlocksToExtractVec.push_back(BB); 167 if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) 168 BlocksToExtractVec.push_back(II->getUnwindDest()); 169 ++NumExtracted; 170 Changed = true; 171 } 172 CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent()); 173 Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC); 174 if (F) 175 LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() 176 << "' in: " << F->getName() << '\n'); 177 else 178 LLVM_DEBUG(dbgs() << "Failed to extract for group '" 179 << (*BBs.begin())->getName() << "'\n"); 180 } 181 182 // Erase the functions. 183 if (EraseFunctions || BlockExtractorEraseFuncs) { 184 for (Function *F : Functions) { 185 LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName() 186 << "\n"); 187 F->deleteBody(); 188 } 189 // Set linkage as ExternalLinkage to avoid erasing unreachable functions. 190 for (Function &F : M) 191 F.setLinkage(GlobalValue::ExternalLinkage); 192 Changed = true; 193 } 194 195 return Changed; 196 } 197 198 BlockExtractorPass::BlockExtractorPass( 199 std::vector<std::vector<BasicBlock *>> &&GroupsOfBlocks, 200 bool EraseFunctions) 201 : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {} 202 203 PreservedAnalyses BlockExtractorPass::run(Module &M, 204 ModuleAnalysisManager &AM) { 205 BlockExtractor BE(EraseFunctions); 206 BE.init(GroupsOfBlocks); 207 return BE.runOnModule(M) ? PreservedAnalyses::none() 208 : PreservedAnalyses::all(); 209 } 210