1 //===- llvm-extract.cpp - LLVM function extraction utility ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This utility changes the input module to only contain a single function, 10 // which is primarily used for debugging transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/SmallPtrSet.h" 16 #include "llvm/Bitcode/BitcodeWriterPass.h" 17 #include "llvm/IR/DataLayout.h" 18 #include "llvm/IR/IRPrintingPasses.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/LLVMContext.h" 21 #include "llvm/IR/LegacyPassManager.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IRReader/IRReader.h" 24 #include "llvm/Support/CommandLine.h" 25 #include "llvm/Support/Error.h" 26 #include "llvm/Support/FileSystem.h" 27 #include "llvm/Support/InitLLVM.h" 28 #include "llvm/Support/Regex.h" 29 #include "llvm/Support/SourceMgr.h" 30 #include "llvm/Support/SystemUtils.h" 31 #include "llvm/Support/ToolOutputFile.h" 32 #include "llvm/Transforms/IPO.h" 33 #include <memory> 34 #include <utility> 35 using namespace llvm; 36 37 cl::OptionCategory ExtractCat("llvm-extract Options"); 38 39 // InputFilename - The filename to read from. 40 static cl::opt<std::string> InputFilename(cl::Positional, 41 cl::desc("<input bitcode file>"), 42 cl::init("-"), 43 cl::value_desc("filename")); 44 45 static cl::opt<std::string> OutputFilename("o", 46 cl::desc("Specify output filename"), 47 cl::value_desc("filename"), 48 cl::init("-"), cl::cat(ExtractCat)); 49 50 static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"), 51 cl::cat(ExtractCat)); 52 53 static cl::opt<bool> DeleteFn("delete", 54 cl::desc("Delete specified Globals from Module"), 55 cl::cat(ExtractCat)); 56 57 static cl::opt<bool> KeepConstInit("keep-const-init", 58 cl::desc("Keep initializers of constants"), 59 cl::cat(ExtractCat)); 60 61 static cl::opt<bool> 62 Recursive("recursive", cl::desc("Recursively extract all called functions"), 63 cl::cat(ExtractCat)); 64 65 // ExtractFuncs - The functions to extract from the module. 66 static cl::list<std::string> 67 ExtractFuncs("func", cl::desc("Specify function to extract"), 68 cl::ZeroOrMore, cl::value_desc("function"), 69 cl::cat(ExtractCat)); 70 71 // ExtractRegExpFuncs - The functions, matched via regular expression, to 72 // extract from the module. 73 static cl::list<std::string> 74 ExtractRegExpFuncs("rfunc", 75 cl::desc("Specify function(s) to extract using a " 76 "regular expression"), 77 cl::ZeroOrMore, cl::value_desc("rfunction"), 78 cl::cat(ExtractCat)); 79 80 // ExtractBlocks - The blocks to extract from the module. 81 static cl::list<std::string> ExtractBlocks( 82 "bb", 83 cl::desc( 84 "Specify <function, basic block1[;basic block2...]> pairs to extract.\n" 85 "Each pair will create a function.\n" 86 "If multiple basic blocks are specified in one pair,\n" 87 "the first block in the sequence should dominate the rest.\n" 88 "eg:\n" 89 " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" 90 " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " 91 "with bb2."), 92 cl::ZeroOrMore, cl::value_desc("function:bb1[;bb2...]"), 93 cl::cat(ExtractCat)); 94 95 // ExtractAlias - The alias to extract from the module. 96 static cl::list<std::string> 97 ExtractAliases("alias", cl::desc("Specify alias to extract"), 98 cl::ZeroOrMore, cl::value_desc("alias"), 99 cl::cat(ExtractCat)); 100 101 // ExtractRegExpAliases - The aliases, matched via regular expression, to 102 // extract from the module. 103 static cl::list<std::string> 104 ExtractRegExpAliases("ralias", 105 cl::desc("Specify alias(es) to extract using a " 106 "regular expression"), 107 cl::ZeroOrMore, cl::value_desc("ralias"), 108 cl::cat(ExtractCat)); 109 110 // ExtractGlobals - The globals to extract from the module. 111 static cl::list<std::string> 112 ExtractGlobals("glob", cl::desc("Specify global to extract"), 113 cl::ZeroOrMore, cl::value_desc("global"), 114 cl::cat(ExtractCat)); 115 116 // ExtractRegExpGlobals - The globals, matched via regular expression, to 117 // extract from the module... 118 static cl::list<std::string> 119 ExtractRegExpGlobals("rglob", 120 cl::desc("Specify global(s) to extract using a " 121 "regular expression"), 122 cl::ZeroOrMore, cl::value_desc("rglobal"), 123 cl::cat(ExtractCat)); 124 125 static cl::opt<bool> OutputAssembly("S", 126 cl::desc("Write output as LLVM assembly"), 127 cl::Hidden, cl::cat(ExtractCat)); 128 129 static cl::opt<bool> PreserveBitcodeUseListOrder( 130 "preserve-bc-uselistorder", 131 cl::desc("Preserve use-list order when writing LLVM bitcode."), 132 cl::init(true), cl::Hidden, cl::cat(ExtractCat)); 133 134 static cl::opt<bool> PreserveAssemblyUseListOrder( 135 "preserve-ll-uselistorder", 136 cl::desc("Preserve use-list order when writing LLVM assembly."), 137 cl::init(false), cl::Hidden, cl::cat(ExtractCat)); 138 139 int main(int argc, char **argv) { 140 InitLLVM X(argc, argv); 141 142 LLVMContext Context; 143 cl::HideUnrelatedOptions(ExtractCat); 144 cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n"); 145 146 // Use lazy loading, since we only care about selected global values. 147 SMDiagnostic Err; 148 std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context); 149 150 if (!M.get()) { 151 Err.print(argv[0], errs()); 152 return 1; 153 } 154 155 // Use SetVector to avoid duplicates. 156 SetVector<GlobalValue *> GVs; 157 158 // Figure out which aliases we should extract. 159 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { 160 GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]); 161 if (!GA) { 162 errs() << argv[0] << ": program doesn't contain alias named '" 163 << ExtractAliases[i] << "'!\n"; 164 return 1; 165 } 166 GVs.insert(GA); 167 } 168 169 // Extract aliases via regular expression matching. 170 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { 171 std::string Error; 172 Regex RegEx(ExtractRegExpAliases[i]); 173 if (!RegEx.isValid(Error)) { 174 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " 175 "invalid regex: " << Error; 176 } 177 bool match = false; 178 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); 179 GA != E; GA++) { 180 if (RegEx.match(GA->getName())) { 181 GVs.insert(&*GA); 182 match = true; 183 } 184 } 185 if (!match) { 186 errs() << argv[0] << ": program doesn't contain global named '" 187 << ExtractRegExpAliases[i] << "'!\n"; 188 return 1; 189 } 190 } 191 192 // Figure out which globals we should extract. 193 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { 194 GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]); 195 if (!GV) { 196 errs() << argv[0] << ": program doesn't contain global named '" 197 << ExtractGlobals[i] << "'!\n"; 198 return 1; 199 } 200 GVs.insert(GV); 201 } 202 203 // Extract globals via regular expression matching. 204 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { 205 std::string Error; 206 Regex RegEx(ExtractRegExpGlobals[i]); 207 if (!RegEx.isValid(Error)) { 208 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " 209 "invalid regex: " << Error; 210 } 211 bool match = false; 212 for (auto &GV : M->globals()) { 213 if (RegEx.match(GV.getName())) { 214 GVs.insert(&GV); 215 match = true; 216 } 217 } 218 if (!match) { 219 errs() << argv[0] << ": program doesn't contain global named '" 220 << ExtractRegExpGlobals[i] << "'!\n"; 221 return 1; 222 } 223 } 224 225 // Figure out which functions we should extract. 226 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { 227 GlobalValue *GV = M->getFunction(ExtractFuncs[i]); 228 if (!GV) { 229 errs() << argv[0] << ": program doesn't contain function named '" 230 << ExtractFuncs[i] << "'!\n"; 231 return 1; 232 } 233 GVs.insert(GV); 234 } 235 // Extract functions via regular expression matching. 236 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { 237 std::string Error; 238 StringRef RegExStr = ExtractRegExpFuncs[i]; 239 Regex RegEx(RegExStr); 240 if (!RegEx.isValid(Error)) { 241 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " 242 "invalid regex: " << Error; 243 } 244 bool match = false; 245 for (Module::iterator F = M->begin(), E = M->end(); F != E; 246 F++) { 247 if (RegEx.match(F->getName())) { 248 GVs.insert(&*F); 249 match = true; 250 } 251 } 252 if (!match) { 253 errs() << argv[0] << ": program doesn't contain global named '" 254 << ExtractRegExpFuncs[i] << "'!\n"; 255 return 1; 256 } 257 } 258 259 // Figure out which BasicBlocks we should extract. 260 SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; 261 for (StringRef StrPair : ExtractBlocks) { 262 SmallVector<StringRef, 16> BBNames; 263 auto BBInfo = StrPair.split(':'); 264 // Get the function. 265 Function *F = M->getFunction(BBInfo.first); 266 if (!F) { 267 errs() << argv[0] << ": program doesn't contain a function named '" 268 << BBInfo.first << "'!\n"; 269 return 1; 270 } 271 // Add the function to the materialize list, and store the basic block names 272 // to check after materialization. 273 GVs.insert(F); 274 BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); 275 BBMap.push_back({F, std::move(BBNames)}); 276 } 277 278 // Use *argv instead of argv[0] to work around a wrong GCC warning. 279 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: "); 280 281 if (Recursive) { 282 std::vector<llvm::Function *> Workqueue; 283 for (GlobalValue *GV : GVs) { 284 if (auto *F = dyn_cast<Function>(GV)) { 285 Workqueue.push_back(F); 286 } 287 } 288 while (!Workqueue.empty()) { 289 Function *F = &*Workqueue.back(); 290 Workqueue.pop_back(); 291 ExitOnErr(F->materialize()); 292 for (auto &BB : *F) { 293 for (auto &I : BB) { 294 CallBase *CB = dyn_cast<CallBase>(&I); 295 if (!CB) 296 continue; 297 Function *CF = CB->getCalledFunction(); 298 if (!CF) 299 continue; 300 if (CF->isDeclaration() || GVs.count(CF)) 301 continue; 302 GVs.insert(CF); 303 Workqueue.push_back(CF); 304 } 305 } 306 } 307 } 308 309 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; 310 311 // Materialize requisite global values. 312 if (!DeleteFn) { 313 for (size_t i = 0, e = GVs.size(); i != e; ++i) 314 Materialize(*GVs[i]); 315 } else { 316 // Deleting. Materialize every GV that's *not* in GVs. 317 SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); 318 for (auto &F : *M) { 319 if (!GVSet.count(&F)) 320 Materialize(F); 321 } 322 } 323 324 { 325 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); 326 legacy::PassManager Extract; 327 Extract.add(createGVExtractionPass(Gvs, DeleteFn, KeepConstInit)); 328 Extract.run(*M); 329 330 // Now that we have all the GVs we want, mark the module as fully 331 // materialized. 332 // FIXME: should the GVExtractionPass handle this? 333 ExitOnErr(M->materializeAll()); 334 } 335 336 // Extract the specified basic blocks from the module and erase the existing 337 // functions. 338 if (!ExtractBlocks.empty()) { 339 // Figure out which BasicBlocks we should extract. 340 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs; 341 for (auto &P : BBMap) { 342 SmallVector<BasicBlock *, 16> BBs; 343 for (StringRef BBName : P.second) { 344 // The function has been materialized, so add its matching basic blocks 345 // to the block extractor list, or fail if a name is not found. 346 auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) { 347 return BB.getName().equals(BBName); 348 }); 349 if (Res == P.first->end()) { 350 errs() << argv[0] << ": function " << P.first->getName() 351 << " doesn't contain a basic block named '" << BBName 352 << "'!\n"; 353 return 1; 354 } 355 BBs.push_back(&*Res); 356 } 357 GroupOfBBs.push_back(BBs); 358 } 359 360 legacy::PassManager PM; 361 PM.add(createBlockExtractorPass(GroupOfBBs, true)); 362 PM.run(*M); 363 } 364 365 // In addition to deleting all other functions, we also want to spiff it 366 // up a little bit. Do this now. 367 legacy::PassManager Passes; 368 369 if (!DeleteFn) 370 Passes.add(createGlobalDCEPass()); // Delete unreachable globals 371 Passes.add(createStripDeadDebugInfoPass()); // Remove dead debug info 372 Passes.add(createStripDeadPrototypesPass()); // Remove dead func decls 373 374 std::error_code EC; 375 ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); 376 if (EC) { 377 errs() << EC.message() << '\n'; 378 return 1; 379 } 380 381 if (OutputAssembly) 382 Passes.add( 383 createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); 384 else if (Force || !CheckBitcodeOutputToConsole(Out.os())) 385 Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); 386 387 Passes.run(*M.get()); 388 389 // Declare success. 390 Out.keep(); 391 392 return 0; 393 } 394