//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This utility changes the input module to only contain the requested
// functions, globals, aliases and/or basic blocks (or, with -delete, passes
// the requested globals to the extraction pass for deletion instead). It is
// primarily used for debugging transformations.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Bitcode/BitcodeWriterPass.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRPrinter/IRPrintingPasses.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/BlockExtractor.h"
#include "llvm/Transforms/IPO/ExtractGV.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
#include "llvm/Transforms/IPO/StripSymbols.h"
#include <memory>
#include <utility>

using namespace llvm;

// ExtractCat - The option category every llvm-extract option below is
// registered in. main() hands it to cl::HideUnrelatedOptions.
cl::OptionCategory ExtractCat("llvm-extract Options");

// InputFilename - The filename to read from. Positional; defaults to "-".
static cl::opt<std::string> InputFilename(cl::Positional,
                                          cl::desc("<input bitcode file>"),
                                          cl::init("-"),
                                          cl::value_desc("filename"));

// OutputFilename - The filename the resulting module is written to ("-o").
// Defaults to "-".
static cl::opt<std::string> OutputFilename("o",
                                           cl::desc("Specify output filename"),
                                           cl::value_desc("filename"),
                                           cl::init("-"), cl::cat(ExtractCat));

// Force - When set, main() writes bitcode without first consulting
// CheckBitcodeOutputToConsole.
static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
                           cl::cat(ExtractCat));

// DeleteFn - Forwarded to ExtractGVPass as its "delete" flag. When set, main()
// materializes the functions that were *not* selected and does not schedule
// GlobalDCE.
static cl::opt<bool> DeleteFn("delete",
                              cl::desc("Delete specified Globals from Module"),
                              cl::cat(ExtractCat));

// KeepConstInit - Forwarded unchanged to ExtractGVPass.
static cl::opt<bool> KeepConstInit("keep-const-init",
                                   cl::desc("Keep initializers of constants"),
                                   cl::cat(ExtractCat));

// Recursive - When set, main() also selects every non-declaration function
// that is the direct callee of a call site in an already selected function,
// transitively.
static cl::opt<bool>
    Recursive("recursive", cl::desc("Recursively extract all called functions"),
              cl::cat(ExtractCat));

// ExtractFuncs - The functions to extract from the module. Each name must
// resolve via Module::getFunction or main() fails.
static cl::list<std::string>
    ExtractFuncs("func", cl::desc("Specify function to extract"),
                 cl::value_desc("function"), cl::cat(ExtractCat));

// ExtractRegExpFuncs - The functions, matched via regular expression, to
// extract from the module. Each pattern must match at least one function.
static cl::list<std::string>
    ExtractRegExpFuncs("rfunc",
                       cl::desc("Specify function(s) to extract using a "
                                "regular expression"),
                       cl::value_desc("rfunction"), cl::cat(ExtractCat));

// ExtractBlocks - The blocks to extract from the module. Each value is split
// at the first ':' into a function name and a ';'-separated list of basic
// block names; every value becomes one group handed to BlockExtractorPass.
static cl::list<std::string> ExtractBlocks(
    "bb",
    cl::desc(
        "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
        "Each pair will create a function.\n"
        "If multiple basic blocks are specified in one pair,\n"
        "the first block in the sequence should dominate the rest.\n"
        "eg:\n"
        "  --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
        "  --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
        "with bb2."),
    cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat));

// ExtractAliases - The aliases to extract from the module. Each name must
// resolve via Module::getNamedAlias or main() fails.
static cl::list<std::string>
    ExtractAliases("alias", cl::desc("Specify alias to extract"),
                   cl::value_desc("alias"), cl::cat(ExtractCat));

// ExtractRegExpAliases - The aliases, matched via regular expression, to
// extract from the module. Each pattern must match at least one alias.
static cl::list<std::string>
    ExtractRegExpAliases("ralias",
                         cl::desc("Specify alias(es) to extract using a "
                                  "regular expression"),
                         cl::value_desc("ralias"), cl::cat(ExtractCat));

// ExtractGlobals - The globals to extract from the module. Each name must
// resolve via Module::getNamedGlobal or main() fails.
static cl::list<std::string>
    ExtractGlobals("glob", cl::desc("Specify global to extract"),
                   cl::value_desc("global"), cl::cat(ExtractCat));

// ExtractRegExpGlobals - The globals, matched via regular expression, to
// extract from the module.
119 static cl::list<std::string> 120 ExtractRegExpGlobals("rglob", 121 cl::desc("Specify global(s) to extract using a " 122 "regular expression"), 123 cl::value_desc("rglobal"), cl::cat(ExtractCat)); 124 125 static cl::opt<bool> OutputAssembly("S", 126 cl::desc("Write output as LLVM assembly"), 127 cl::Hidden, cl::cat(ExtractCat)); 128 129 static cl::opt<bool> PreserveBitcodeUseListOrder( 130 "preserve-bc-uselistorder", 131 cl::desc("Preserve use-list order when writing LLVM bitcode."), 132 cl::init(true), cl::Hidden, cl::cat(ExtractCat)); 133 134 static cl::opt<bool> PreserveAssemblyUseListOrder( 135 "preserve-ll-uselistorder", 136 cl::desc("Preserve use-list order when writing LLVM assembly."), 137 cl::init(false), cl::Hidden, cl::cat(ExtractCat)); 138 139 int main(int argc, char **argv) { 140 InitLLVM X(argc, argv); 141 142 LLVMContext Context; 143 cl::HideUnrelatedOptions(ExtractCat); 144 cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n"); 145 146 // Use lazy loading, since we only care about selected global values. 147 SMDiagnostic Err; 148 std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context); 149 150 if (!M.get()) { 151 Err.print(argv[0], errs()); 152 return 1; 153 } 154 155 // Use SetVector to avoid duplicates. 156 SetVector<GlobalValue *> GVs; 157 158 // Figure out which aliases we should extract. 159 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { 160 GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]); 161 if (!GA) { 162 errs() << argv[0] << ": program doesn't contain alias named '" 163 << ExtractAliases[i] << "'!\n"; 164 return 1; 165 } 166 GVs.insert(GA); 167 } 168 169 // Extract aliases via regular expression matching. 
170 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { 171 std::string Error; 172 Regex RegEx(ExtractRegExpAliases[i]); 173 if (!RegEx.isValid(Error)) { 174 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " 175 "invalid regex: " << Error; 176 } 177 bool match = false; 178 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); 179 GA != E; GA++) { 180 if (RegEx.match(GA->getName())) { 181 GVs.insert(&*GA); 182 match = true; 183 } 184 } 185 if (!match) { 186 errs() << argv[0] << ": program doesn't contain global named '" 187 << ExtractRegExpAliases[i] << "'!\n"; 188 return 1; 189 } 190 } 191 192 // Figure out which globals we should extract. 193 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { 194 GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]); 195 if (!GV) { 196 errs() << argv[0] << ": program doesn't contain global named '" 197 << ExtractGlobals[i] << "'!\n"; 198 return 1; 199 } 200 GVs.insert(GV); 201 } 202 203 // Extract globals via regular expression matching. 204 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { 205 std::string Error; 206 Regex RegEx(ExtractRegExpGlobals[i]); 207 if (!RegEx.isValid(Error)) { 208 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " 209 "invalid regex: " << Error; 210 } 211 bool match = false; 212 for (auto &GV : M->globals()) { 213 if (RegEx.match(GV.getName())) { 214 GVs.insert(&GV); 215 match = true; 216 } 217 } 218 if (!match) { 219 errs() << argv[0] << ": program doesn't contain global named '" 220 << ExtractRegExpGlobals[i] << "'!\n"; 221 return 1; 222 } 223 } 224 225 // Figure out which functions we should extract. 
226 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { 227 GlobalValue *GV = M->getFunction(ExtractFuncs[i]); 228 if (!GV) { 229 errs() << argv[0] << ": program doesn't contain function named '" 230 << ExtractFuncs[i] << "'!\n"; 231 return 1; 232 } 233 GVs.insert(GV); 234 } 235 // Extract functions via regular expression matching. 236 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { 237 std::string Error; 238 StringRef RegExStr = ExtractRegExpFuncs[i]; 239 Regex RegEx(RegExStr); 240 if (!RegEx.isValid(Error)) { 241 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " 242 "invalid regex: " << Error; 243 } 244 bool match = false; 245 for (Module::iterator F = M->begin(), E = M->end(); F != E; 246 F++) { 247 if (RegEx.match(F->getName())) { 248 GVs.insert(&*F); 249 match = true; 250 } 251 } 252 if (!match) { 253 errs() << argv[0] << ": program doesn't contain global named '" 254 << ExtractRegExpFuncs[i] << "'!\n"; 255 return 1; 256 } 257 } 258 259 // Figure out which BasicBlocks we should extract. 260 SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; 261 for (StringRef StrPair : ExtractBlocks) { 262 SmallVector<StringRef, 16> BBNames; 263 auto BBInfo = StrPair.split(':'); 264 // Get the function. 265 Function *F = M->getFunction(BBInfo.first); 266 if (!F) { 267 errs() << argv[0] << ": program doesn't contain a function named '" 268 << BBInfo.first << "'!\n"; 269 return 1; 270 } 271 // Add the function to the materialize list, and store the basic block names 272 // to check after materialization. 273 GVs.insert(F); 274 BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); 275 BBMap.push_back({F, std::move(BBNames)}); 276 } 277 278 // Use *argv instead of argv[0] to work around a wrong GCC warning. 
279 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: "); 280 281 if (Recursive) { 282 std::vector<llvm::Function *> Workqueue; 283 for (GlobalValue *GV : GVs) { 284 if (auto *F = dyn_cast<Function>(GV)) { 285 Workqueue.push_back(F); 286 } 287 } 288 while (!Workqueue.empty()) { 289 Function *F = &*Workqueue.back(); 290 Workqueue.pop_back(); 291 ExitOnErr(F->materialize()); 292 for (auto &BB : *F) { 293 for (auto &I : BB) { 294 CallBase *CB = dyn_cast<CallBase>(&I); 295 if (!CB) 296 continue; 297 Function *CF = CB->getCalledFunction(); 298 if (!CF) 299 continue; 300 if (CF->isDeclaration() || GVs.count(CF)) 301 continue; 302 GVs.insert(CF); 303 Workqueue.push_back(CF); 304 } 305 } 306 } 307 } 308 309 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; 310 311 // Materialize requisite global values. 312 if (!DeleteFn) { 313 for (size_t i = 0, e = GVs.size(); i != e; ++i) 314 Materialize(*GVs[i]); 315 } else { 316 // Deleting. Materialize every GV that's *not* in GVs. 317 SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); 318 for (auto &F : *M) { 319 if (!GVSet.count(&F)) 320 Materialize(F); 321 } 322 } 323 324 { 325 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); 326 LoopAnalysisManager LAM; 327 FunctionAnalysisManager FAM; 328 CGSCCAnalysisManager CGAM; 329 ModuleAnalysisManager MAM; 330 331 PassBuilder PB; 332 333 PB.registerModuleAnalyses(MAM); 334 PB.registerCGSCCAnalyses(CGAM); 335 PB.registerFunctionAnalyses(FAM); 336 PB.registerLoopAnalyses(LAM); 337 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 338 339 ModulePassManager PM; 340 PM.addPass(ExtractGVPass(Gvs, DeleteFn, KeepConstInit)); 341 PM.run(*M, MAM); 342 343 // Now that we have all the GVs we want, mark the module as fully 344 // materialized. 345 // FIXME: should the GVExtractionPass handle this? 346 ExitOnErr(M->materializeAll()); 347 } 348 349 // Extract the specified basic blocks from the module and erase the existing 350 // functions. 
351 if (!ExtractBlocks.empty()) { 352 // Figure out which BasicBlocks we should extract. 353 std::vector<std::vector<BasicBlock *>> GroupOfBBs; 354 for (auto &P : BBMap) { 355 std::vector<BasicBlock *> BBs; 356 for (StringRef BBName : P.second) { 357 // The function has been materialized, so add its matching basic blocks 358 // to the block extractor list, or fail if a name is not found. 359 auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) { 360 return BB.getName().equals(BBName); 361 }); 362 if (Res == P.first->end()) { 363 errs() << argv[0] << ": function " << P.first->getName() 364 << " doesn't contain a basic block named '" << BBName 365 << "'!\n"; 366 return 1; 367 } 368 BBs.push_back(&*Res); 369 } 370 GroupOfBBs.push_back(BBs); 371 } 372 373 LoopAnalysisManager LAM; 374 FunctionAnalysisManager FAM; 375 CGSCCAnalysisManager CGAM; 376 ModuleAnalysisManager MAM; 377 378 PassBuilder PB; 379 380 PB.registerModuleAnalyses(MAM); 381 PB.registerCGSCCAnalyses(CGAM); 382 PB.registerFunctionAnalyses(FAM); 383 PB.registerLoopAnalyses(LAM); 384 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 385 386 ModulePassManager PM; 387 PM.addPass(BlockExtractorPass(std::move(GroupOfBBs), true)); 388 PM.run(*M, MAM); 389 } 390 391 // In addition to deleting all other functions, we also want to spiff it 392 // up a little bit. Do this now. 
393 394 LoopAnalysisManager LAM; 395 FunctionAnalysisManager FAM; 396 CGSCCAnalysisManager CGAM; 397 ModuleAnalysisManager MAM; 398 399 PassBuilder PB; 400 401 PB.registerModuleAnalyses(MAM); 402 PB.registerCGSCCAnalyses(CGAM); 403 PB.registerFunctionAnalyses(FAM); 404 PB.registerLoopAnalyses(LAM); 405 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 406 407 ModulePassManager PM; 408 if (!DeleteFn) 409 PM.addPass(GlobalDCEPass()); 410 PM.addPass(StripDeadDebugInfoPass()); 411 PM.addPass(StripDeadPrototypesPass()); 412 413 std::error_code EC; 414 ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); 415 if (EC) { 416 errs() << EC.message() << '\n'; 417 return 1; 418 } 419 420 if (OutputAssembly) 421 PM.addPass(PrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); 422 else if (Force || !CheckBitcodeOutputToConsole(Out.os())) 423 PM.addPass(BitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); 424 425 PM.run(*M, MAM); 426 427 // Declare success. 428 Out.keep(); 429 430 return 0; 431 } 432