1 //===- llvm-extract.cpp - LLVM function extraction utility ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This utility changes the input module to only contain a single function, 10 // which is primarily used for debugging transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/SmallPtrSet.h" 16 #include "llvm/Bitcode/BitcodeWriterPass.h" 17 #include "llvm/IR/DataLayout.h" 18 #include "llvm/IR/IRPrintingPasses.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/LLVMContext.h" 21 #include "llvm/IR/LegacyPassManager.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IRReader/IRReader.h" 24 #include "llvm/Support/CommandLine.h" 25 #include "llvm/Support/Error.h" 26 #include "llvm/Support/FileSystem.h" 27 #include "llvm/Support/InitLLVM.h" 28 #include "llvm/Support/Regex.h" 29 #include "llvm/Support/SourceMgr.h" 30 #include "llvm/Support/SystemUtils.h" 31 #include "llvm/Support/ToolOutputFile.h" 32 #include "llvm/Transforms/IPO.h" 33 #include <memory> 34 using namespace llvm; 35 36 cl::OptionCategory ExtractCat("llvm-extract Options"); 37 38 // InputFilename - The filename to read from. 39 static cl::opt<std::string> InputFilename(cl::Positional, 40 cl::desc("<input bitcode file>"), 41 cl::init("-"), 42 cl::value_desc("filename")); 43 44 static cl::opt<std::string> OutputFilename("o", 45 cl::desc("Specify output filename"), 46 cl::value_desc("filename"), 47 cl::init("-"), cl::cat(ExtractCat)); 48 49 static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"), 50 cl::cat(ExtractCat)); 51 52 static cl::opt<bool> DeleteFn("delete", 53 cl::desc("Delete specified Globals from Module"), 54 cl::cat(ExtractCat)); 55 56 static cl::opt<bool> 57 Recursive("recursive", cl::desc("Recursively extract all called functions"), 58 cl::cat(ExtractCat)); 59 60 // ExtractFuncs - The functions to extract from the module. 61 static cl::list<std::string> 62 ExtractFuncs("func", cl::desc("Specify function to extract"), 63 cl::ZeroOrMore, cl::value_desc("function"), 64 cl::cat(ExtractCat)); 65 66 // ExtractRegExpFuncs - The functions, matched via regular expression, to 67 // extract from the module. 68 static cl::list<std::string> 69 ExtractRegExpFuncs("rfunc", 70 cl::desc("Specify function(s) to extract using a " 71 "regular expression"), 72 cl::ZeroOrMore, cl::value_desc("rfunction"), 73 cl::cat(ExtractCat)); 74 75 // ExtractBlocks - The blocks to extract from the module. 76 static cl::list<std::string> ExtractBlocks( 77 "bb", 78 cl::desc( 79 "Specify <function, basic block1[;basic block2...]> pairs to extract.\n" 80 "Each pair will create a function.\n" 81 "If multiple basic blocks are specified in one pair,\n" 82 "the first block in the sequence should dominate the rest.\n" 83 "eg:\n" 84 " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" 85 " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " 86 "with bb2."), 87 cl::ZeroOrMore, cl::value_desc("function:bb1[;bb2...]"), 88 cl::cat(ExtractCat)); 89 90 // ExtractAlias - The alias to extract from the module. 91 static cl::list<std::string> 92 ExtractAliases("alias", cl::desc("Specify alias to extract"), 93 cl::ZeroOrMore, cl::value_desc("alias"), 94 cl::cat(ExtractCat)); 95 96 // ExtractRegExpAliases - The aliases, matched via regular expression, to 97 // extract from the module. 98 static cl::list<std::string> 99 ExtractRegExpAliases("ralias", 100 cl::desc("Specify alias(es) to extract using a " 101 "regular expression"), 102 cl::ZeroOrMore, cl::value_desc("ralias"), 103 cl::cat(ExtractCat)); 104 105 // ExtractGlobals - The globals to extract from the module. 106 static cl::list<std::string> 107 ExtractGlobals("glob", cl::desc("Specify global to extract"), 108 cl::ZeroOrMore, cl::value_desc("global"), 109 cl::cat(ExtractCat)); 110 111 // ExtractRegExpGlobals - The globals, matched via regular expression, to 112 // extract from the module... 113 static cl::list<std::string> 114 ExtractRegExpGlobals("rglob", 115 cl::desc("Specify global(s) to extract using a " 116 "regular expression"), 117 cl::ZeroOrMore, cl::value_desc("rglobal"), 118 cl::cat(ExtractCat)); 119 120 static cl::opt<bool> OutputAssembly("S", 121 cl::desc("Write output as LLVM assembly"), 122 cl::Hidden, cl::cat(ExtractCat)); 123 124 static cl::opt<bool> PreserveBitcodeUseListOrder( 125 "preserve-bc-uselistorder", 126 cl::desc("Preserve use-list order when writing LLVM bitcode."), 127 cl::init(true), cl::Hidden, cl::cat(ExtractCat)); 128 129 static cl::opt<bool> PreserveAssemblyUseListOrder( 130 "preserve-ll-uselistorder", 131 cl::desc("Preserve use-list order when writing LLVM assembly."), 132 cl::init(false), cl::Hidden, cl::cat(ExtractCat)); 133 134 int main(int argc, char **argv) { 135 InitLLVM X(argc, argv); 136 137 LLVMContext Context; 138 cl::HideUnrelatedOptions(ExtractCat); 139 cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n"); 140 141 // Use lazy loading, since we only care about selected global values. 142 SMDiagnostic Err; 143 std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context); 144 145 if (!M.get()) { 146 Err.print(argv[0], errs()); 147 return 1; 148 } 149 150 // Use SetVector to avoid duplicates. 151 SetVector<GlobalValue *> GVs; 152 153 // Figure out which aliases we should extract. 154 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { 155 GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]); 156 if (!GA) { 157 errs() << argv[0] << ": program doesn't contain alias named '" 158 << ExtractAliases[i] << "'!\n"; 159 return 1; 160 } 161 GVs.insert(GA); 162 } 163 164 // Extract aliases via regular expression matching. 165 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { 166 std::string Error; 167 Regex RegEx(ExtractRegExpAliases[i]); 168 if (!RegEx.isValid(Error)) { 169 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " 170 "invalid regex: " << Error; 171 } 172 bool match = false; 173 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); 174 GA != E; GA++) { 175 if (RegEx.match(GA->getName())) { 176 GVs.insert(&*GA); 177 match = true; 178 } 179 } 180 if (!match) { 181 errs() << argv[0] << ": program doesn't contain global named '" 182 << ExtractRegExpAliases[i] << "'!\n"; 183 return 1; 184 } 185 } 186 187 // Figure out which globals we should extract. 188 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { 189 GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]); 190 if (!GV) { 191 errs() << argv[0] << ": program doesn't contain global named '" 192 << ExtractGlobals[i] << "'!\n"; 193 return 1; 194 } 195 GVs.insert(GV); 196 } 197 198 // Extract globals via regular expression matching. 199 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { 200 std::string Error; 201 Regex RegEx(ExtractRegExpGlobals[i]); 202 if (!RegEx.isValid(Error)) { 203 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " 204 "invalid regex: " << Error; 205 } 206 bool match = false; 207 for (auto &GV : M->globals()) { 208 if (RegEx.match(GV.getName())) { 209 GVs.insert(&GV); 210 match = true; 211 } 212 } 213 if (!match) { 214 errs() << argv[0] << ": program doesn't contain global named '" 215 << ExtractRegExpGlobals[i] << "'!\n"; 216 return 1; 217 } 218 } 219 220 // Figure out which functions we should extract. 221 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { 222 GlobalValue *GV = M->getFunction(ExtractFuncs[i]); 223 if (!GV) { 224 errs() << argv[0] << ": program doesn't contain function named '" 225 << ExtractFuncs[i] << "'!\n"; 226 return 1; 227 } 228 GVs.insert(GV); 229 } 230 // Extract functions via regular expression matching. 231 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { 232 std::string Error; 233 StringRef RegExStr = ExtractRegExpFuncs[i]; 234 Regex RegEx(RegExStr); 235 if (!RegEx.isValid(Error)) { 236 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " 237 "invalid regex: " << Error; 238 } 239 bool match = false; 240 for (Module::iterator F = M->begin(), E = M->end(); F != E; 241 F++) { 242 if (RegEx.match(F->getName())) { 243 GVs.insert(&*F); 244 match = true; 245 } 246 } 247 if (!match) { 248 errs() << argv[0] << ": program doesn't contain global named '" 249 << ExtractRegExpFuncs[i] << "'!\n"; 250 return 1; 251 } 252 } 253 254 // Figure out which BasicBlocks we should extract. 255 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs; 256 for (StringRef StrPair : ExtractBlocks) { 257 auto BBInfo = StrPair.split(':'); 258 // Get the function. 259 Function *F = M->getFunction(BBInfo.first); 260 if (!F) { 261 errs() << argv[0] << ": program doesn't contain a function named '" 262 << BBInfo.first << "'!\n"; 263 return 1; 264 } 265 // Do not materialize this function. 266 GVs.insert(F); 267 // Get the basic blocks. 268 SmallVector<BasicBlock *, 16> BBs; 269 SmallVector<StringRef, 16> BBNames; 270 BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, 271 /*KeepEmpty=*/false); 272 for (StringRef BBName : BBNames) { 273 auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { 274 return BB.getName().equals(BBName); 275 }); 276 if (Res == F->end()) { 277 errs() << argv[0] << ": function " << F->getName() 278 << " doesn't contain a basic block named '" << BBInfo.second 279 << "'!\n"; 280 return 1; 281 } 282 BBs.push_back(&*Res); 283 } 284 GroupOfBBs.push_back(BBs); 285 } 286 287 // Use *argv instead of argv[0] to work around a wrong GCC warning. 288 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: "); 289 290 if (Recursive) { 291 std::vector<llvm::Function *> Workqueue; 292 for (GlobalValue *GV : GVs) { 293 if (auto *F = dyn_cast<Function>(GV)) { 294 Workqueue.push_back(F); 295 } 296 } 297 while (!Workqueue.empty()) { 298 Function *F = &*Workqueue.back(); 299 Workqueue.pop_back(); 300 ExitOnErr(F->materialize()); 301 for (auto &BB : *F) { 302 for (auto &I : BB) { 303 CallBase *CB = dyn_cast<CallBase>(&I); 304 if (!CB) 305 continue; 306 Function *CF = CB->getCalledFunction(); 307 if (!CF) 308 continue; 309 if (CF->isDeclaration() || GVs.count(CF)) 310 continue; 311 GVs.insert(CF); 312 Workqueue.push_back(CF); 313 } 314 } 315 } 316 } 317 318 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; 319 320 // Materialize requisite global values. 321 if (!DeleteFn) { 322 for (size_t i = 0, e = GVs.size(); i != e; ++i) 323 Materialize(*GVs[i]); 324 } else { 325 // Deleting. Materialize every GV that's *not* in GVs. 326 SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); 327 for (auto &F : *M) { 328 if (!GVSet.count(&F)) 329 Materialize(F); 330 } 331 } 332 333 { 334 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); 335 legacy::PassManager Extract; 336 Extract.add(createGVExtractionPass(Gvs, DeleteFn)); 337 Extract.run(*M); 338 339 // Now that we have all the GVs we want, mark the module as fully 340 // materialized. 341 // FIXME: should the GVExtractionPass handle this? 342 ExitOnErr(M->materializeAll()); 343 } 344 345 // Extract the specified basic blocks from the module and erase the existing 346 // functions. 347 if (!ExtractBlocks.empty()) { 348 legacy::PassManager PM; 349 PM.add(createBlockExtractorPass(GroupOfBBs, true)); 350 PM.run(*M); 351 } 352 353 // In addition to deleting all other functions, we also want to spiff it 354 // up a little bit. Do this now. 355 legacy::PassManager Passes; 356 357 if (!DeleteFn) 358 Passes.add(createGlobalDCEPass()); // Delete unreachable globals 359 Passes.add(createStripDeadDebugInfoPass()); // Remove dead debug info 360 Passes.add(createStripDeadPrototypesPass()); // Remove dead func decls 361 362 std::error_code EC; 363 ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); 364 if (EC) { 365 errs() << EC.message() << '\n'; 366 return 1; 367 } 368 369 if (OutputAssembly) 370 Passes.add( 371 createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); 372 else if (Force || !CheckBitcodeOutputToConsole(Out.os(), true)) 373 Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); 374 375 Passes.run(*M.get()); 376 377 // Declare success. 378 Out.keep(); 379 380 return 0; 381 } 382