1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass extracts the specified basic blocks from the module into their 10 // own functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/STLExtras.h" 15 #include "llvm/ADT/Statistic.h" 16 #include "llvm/IR/Instructions.h" 17 #include "llvm/IR/Module.h" 18 #include "llvm/InitializePasses.h" 19 #include "llvm/Pass.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/MemoryBuffer.h" 23 #include "llvm/Transforms/IPO.h" 24 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 25 #include "llvm/Transforms/Utils/CodeExtractor.h" 26 27 using namespace llvm; 28 29 #define DEBUG_TYPE "block-extractor" 30 31 STATISTIC(NumExtracted, "Number of basic blocks extracted"); 32 33 static cl::opt<std::string> BlockExtractorFile( 34 "extract-blocks-file", cl::value_desc("filename"), 35 cl::desc("A file containing list of basic blocks to extract"), cl::Hidden); 36 37 cl::opt<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs", 38 cl::desc("Erase the existing functions"), 39 cl::Hidden); 40 namespace { 41 class BlockExtractor : public ModulePass { 42 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; 43 bool EraseFunctions; 44 /// Map a function name to groups of blocks. 45 SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> 46 BlocksByName; 47 48 void init(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> 49 &GroupsOfBlocksToExtract) { 50 for (const SmallVectorImpl<BasicBlock *> &GroupOfBlocks : 51 GroupsOfBlocksToExtract) { 52 SmallVector<BasicBlock *, 16> NewGroup; 53 NewGroup.append(GroupOfBlocks.begin(), GroupOfBlocks.end()); 54 GroupsOfBlocks.emplace_back(NewGroup); 55 } 56 if (!BlockExtractorFile.empty()) 57 loadFile(); 58 } 59 60 public: 61 static char ID; 62 BlockExtractor(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, 63 bool EraseFunctions) 64 : ModulePass(ID), EraseFunctions(EraseFunctions) { 65 // We want one group per element of the input list. 66 SmallVector<SmallVector<BasicBlock *, 16>, 4> MassagedGroupsOfBlocks; 67 for (BasicBlock *BB : BlocksToExtract) { 68 SmallVector<BasicBlock *, 16> NewGroup; 69 NewGroup.push_back(BB); 70 MassagedGroupsOfBlocks.push_back(NewGroup); 71 } 72 init(MassagedGroupsOfBlocks); 73 } 74 75 BlockExtractor(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> 76 &GroupsOfBlocksToExtract, 77 bool EraseFunctions) 78 : ModulePass(ID), EraseFunctions(EraseFunctions) { 79 init(GroupsOfBlocksToExtract); 80 } 81 82 BlockExtractor() : BlockExtractor(SmallVector<BasicBlock *, 0>(), false) {} 83 bool runOnModule(Module &M) override; 84 85 private: 86 void loadFile(); 87 void splitLandingPadPreds(Function &F); 88 }; 89 } // end anonymous namespace 90 91 char BlockExtractor::ID = 0; 92 INITIALIZE_PASS(BlockExtractor, "extract-blocks", 93 "Extract basic blocks from module", false, false) 94 95 ModulePass *llvm::createBlockExtractorPass() { return new BlockExtractor(); } 96 ModulePass *llvm::createBlockExtractorPass( 97 const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { 98 return new BlockExtractor(BlocksToExtract, EraseFunctions); 99 } 100 ModulePass *llvm::createBlockExtractorPass( 101 const SmallVectorImpl<SmallVector<BasicBlock *, 16>> 102 &GroupsOfBlocksToExtract, 103 bool EraseFunctions) { 104 return new BlockExtractor(GroupsOfBlocksToExtract, EraseFunctions); 105 } 106 107 /// Gets all of the blocks specified in the input file. 108 void BlockExtractor::loadFile() { 109 auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile); 110 if (ErrOrBuf.getError()) 111 report_fatal_error("BlockExtractor couldn't load the file."); 112 // Read the file. 113 auto &Buf = *ErrOrBuf; 114 SmallVector<StringRef, 16> Lines; 115 Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, 116 /*KeepEmpty=*/false); 117 for (const auto &Line : Lines) { 118 SmallVector<StringRef, 4> LineSplit; 119 Line.split(LineSplit, ' ', /*MaxSplit=*/-1, 120 /*KeepEmpty=*/false); 121 if (LineSplit.empty()) 122 continue; 123 if (LineSplit.size()!=2) 124 report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'"); 125 SmallVector<StringRef, 4> BBNames; 126 LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, 127 /*KeepEmpty=*/false); 128 if (BBNames.empty()) 129 report_fatal_error("Missing bbs name"); 130 BlocksByName.push_back( 131 {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}}); 132 } 133 } 134 135 /// Extracts the landing pads to make sure all of them have only one 136 /// predecessor. 137 void BlockExtractor::splitLandingPadPreds(Function &F) { 138 for (BasicBlock &BB : F) { 139 for (Instruction &I : BB) { 140 if (!isa<InvokeInst>(&I)) 141 continue; 142 InvokeInst *II = cast<InvokeInst>(&I); 143 BasicBlock *Parent = II->getParent(); 144 BasicBlock *LPad = II->getUnwindDest(); 145 146 // Look through the landing pad's predecessors. If one of them ends in an 147 // 'invoke', then we want to split the landing pad. 148 bool Split = false; 149 for (auto PredBB : predecessors(LPad)) { 150 if (PredBB->isLandingPad() && PredBB != Parent && 151 isa<InvokeInst>(Parent->getTerminator())) { 152 Split = true; 153 break; 154 } 155 } 156 157 if (!Split) 158 continue; 159 160 SmallVector<BasicBlock *, 2> NewBBs; 161 SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs); 162 } 163 } 164 } 165 166 bool BlockExtractor::runOnModule(Module &M) { 167 168 bool Changed = false; 169 170 // Get all the functions. 171 SmallVector<Function *, 4> Functions; 172 for (Function &F : M) { 173 splitLandingPadPreds(F); 174 Functions.push_back(&F); 175 } 176 177 // Get all the blocks specified in the input file. 178 unsigned NextGroupIdx = GroupsOfBlocks.size(); 179 GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size()); 180 for (const auto &BInfo : BlocksByName) { 181 Function *F = M.getFunction(BInfo.first); 182 if (!F) 183 report_fatal_error("Invalid function name specified in the input file"); 184 for (const auto &BBInfo : BInfo.second) { 185 auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { 186 return BB.getName().equals(BBInfo); 187 }); 188 if (Res == F->end()) 189 report_fatal_error("Invalid block name specified in the input file"); 190 GroupsOfBlocks[NextGroupIdx].push_back(&*Res); 191 } 192 ++NextGroupIdx; 193 } 194 195 // Extract each group of basic blocks. 196 for (auto &BBs : GroupsOfBlocks) { 197 SmallVector<BasicBlock *, 32> BlocksToExtractVec; 198 for (BasicBlock *BB : BBs) { 199 // Check if the module contains BB. 200 if (BB->getParent()->getParent() != &M) 201 report_fatal_error("Invalid basic block"); 202 LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " 203 << BB->getParent()->getName() << ":" << BB->getName() 204 << "\n"); 205 BlocksToExtractVec.push_back(BB); 206 if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) 207 BlocksToExtractVec.push_back(II->getUnwindDest()); 208 ++NumExtracted; 209 Changed = true; 210 } 211 CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent()); 212 Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC); 213 if (F) 214 LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() 215 << "' in: " << F->getName() << '\n'); 216 else 217 LLVM_DEBUG(dbgs() << "Failed to extract for group '" 218 << (*BBs.begin())->getName() << "'\n"); 219 } 220 221 // Erase the functions. 222 if (EraseFunctions || BlockExtractorEraseFuncs) { 223 for (Function *F : Functions) { 224 LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName() 225 << "\n"); 226 F->deleteBody(); 227 } 228 // Set linkage as ExternalLinkage to avoid erasing unreachable functions. 229 for (Function &F : M) 230 F.setLinkage(GlobalValue::ExternalLinkage); 231 Changed = true; 232 } 233 234 return Changed; 235 } 236