//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This utility changes the input module to only contain the requested
// functions, global variables, aliases and basic blocks (or, with -delete, to
// contain everything except them), which is primarily used for debugging
// transformations.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Bitcode/BitcodeWriterPass.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/IPO.h"
#include <memory>
using namespace llvm;

// ExtractCat - The option category that groups this tool's own command-line
// options; main() hides every option outside of it.
cl::OptionCategory ExtractCat("llvm-extract Options");

// InputFilename - The filename to read from (positional, defaults to "-").
static cl::opt<std::string> InputFilename(cl::Positional,
                                          cl::desc("<input bitcode file>"),
                                          cl::init("-"),
                                          cl::value_desc("filename"));

// OutputFilename - The filename to write the resulting module to (defaults to
// "-").
static cl::opt<std::string> OutputFilename("o",
                                           cl::desc("Specify output filename"),
                                           cl::value_desc("filename"),
                                           cl::init("-"), cl::cat(ExtractCat));

// Force - Write bitcode even when the console check in main() would otherwise
// suppress it.
static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
                           cl::cat(ExtractCat));

// DeleteFn - Invert the selection: delete the specified globals from the
// module instead of keeping only them.
static cl::opt<bool> DeleteFn("delete",
                              cl::desc("Delete specified Globals from Module"),
                              cl::cat(ExtractCat));

// Recursive - Also extract every defined function that is called directly
// (transitively) from an extracted function.
static cl::opt<bool>
    Recursive("recursive", cl::desc("Recursively extract all called functions"),
              cl::cat(ExtractCat));

// ExtractFuncs - The functions to extract from the module.
static cl::list<std::string>
ExtractFuncs("func", cl::desc("Specify function to extract"),
             cl::ZeroOrMore, cl::value_desc("function"),
             cl::cat(ExtractCat));

// ExtractRegExpFuncs - The functions, matched via regular expression, to
// extract from the module.
static cl::list<std::string>
ExtractRegExpFuncs("rfunc",
                   cl::desc("Specify function(s) to extract using a "
                            "regular expression"),
                   cl::ZeroOrMore, cl::value_desc("rfunction"),
                   cl::cat(ExtractCat));

// ExtractBlocks - The blocks to extract from the module. Each value is a
// function name, a ':', and one or more basic block names separated by ';'
// (this is how main() splits it).
static cl::list<std::string> ExtractBlocks(
    "bb", cl::desc("Specify <function, basic block> pairs to extract"),
    cl::ZeroOrMore, cl::value_desc("function:bb"), cl::cat(ExtractCat));

// ExtractAliases - The aliases to extract from the module.
static cl::list<std::string>
ExtractAliases("alias", cl::desc("Specify alias to extract"),
               cl::ZeroOrMore, cl::value_desc("alias"),
               cl::cat(ExtractCat));

// ExtractRegExpAliases - The aliases, matched via regular expression, to
// extract from the module.
88 static cl::list<std::string> 89 ExtractRegExpAliases("ralias", 90 cl::desc("Specify alias(es) to extract using a " 91 "regular expression"), 92 cl::ZeroOrMore, cl::value_desc("ralias"), 93 cl::cat(ExtractCat)); 94 95 // ExtractGlobals - The globals to extract from the module. 96 static cl::list<std::string> 97 ExtractGlobals("glob", cl::desc("Specify global to extract"), 98 cl::ZeroOrMore, cl::value_desc("global"), 99 cl::cat(ExtractCat)); 100 101 // ExtractRegExpGlobals - The globals, matched via regular expression, to 102 // extract from the module... 103 static cl::list<std::string> 104 ExtractRegExpGlobals("rglob", 105 cl::desc("Specify global(s) to extract using a " 106 "regular expression"), 107 cl::ZeroOrMore, cl::value_desc("rglobal"), 108 cl::cat(ExtractCat)); 109 110 static cl::opt<bool> OutputAssembly("S", 111 cl::desc("Write output as LLVM assembly"), 112 cl::Hidden, cl::cat(ExtractCat)); 113 114 static cl::opt<bool> PreserveBitcodeUseListOrder( 115 "preserve-bc-uselistorder", 116 cl::desc("Preserve use-list order when writing LLVM bitcode."), 117 cl::init(true), cl::Hidden, cl::cat(ExtractCat)); 118 119 static cl::opt<bool> PreserveAssemblyUseListOrder( 120 "preserve-ll-uselistorder", 121 cl::desc("Preserve use-list order when writing LLVM assembly."), 122 cl::init(false), cl::Hidden, cl::cat(ExtractCat)); 123 124 int main(int argc, char **argv) { 125 InitLLVM X(argc, argv); 126 127 LLVMContext Context; 128 cl::HideUnrelatedOptions(ExtractCat); 129 cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n"); 130 131 // Use lazy loading, since we only care about selected global values. 132 SMDiagnostic Err; 133 std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context); 134 135 if (!M.get()) { 136 Err.print(argv[0], errs()); 137 return 1; 138 } 139 140 // Use SetVector to avoid duplicates. 141 SetVector<GlobalValue *> GVs; 142 143 // Figure out which aliases we should extract. 
144 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { 145 GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]); 146 if (!GA) { 147 errs() << argv[0] << ": program doesn't contain alias named '" 148 << ExtractAliases[i] << "'!\n"; 149 return 1; 150 } 151 GVs.insert(GA); 152 } 153 154 // Extract aliases via regular expression matching. 155 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { 156 std::string Error; 157 Regex RegEx(ExtractRegExpAliases[i]); 158 if (!RegEx.isValid(Error)) { 159 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " 160 "invalid regex: " << Error; 161 } 162 bool match = false; 163 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); 164 GA != E; GA++) { 165 if (RegEx.match(GA->getName())) { 166 GVs.insert(&*GA); 167 match = true; 168 } 169 } 170 if (!match) { 171 errs() << argv[0] << ": program doesn't contain global named '" 172 << ExtractRegExpAliases[i] << "'!\n"; 173 return 1; 174 } 175 } 176 177 // Figure out which globals we should extract. 178 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { 179 GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]); 180 if (!GV) { 181 errs() << argv[0] << ": program doesn't contain global named '" 182 << ExtractGlobals[i] << "'!\n"; 183 return 1; 184 } 185 GVs.insert(GV); 186 } 187 188 // Extract globals via regular expression matching. 
189 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { 190 std::string Error; 191 Regex RegEx(ExtractRegExpGlobals[i]); 192 if (!RegEx.isValid(Error)) { 193 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " 194 "invalid regex: " << Error; 195 } 196 bool match = false; 197 for (auto &GV : M->globals()) { 198 if (RegEx.match(GV.getName())) { 199 GVs.insert(&GV); 200 match = true; 201 } 202 } 203 if (!match) { 204 errs() << argv[0] << ": program doesn't contain global named '" 205 << ExtractRegExpGlobals[i] << "'!\n"; 206 return 1; 207 } 208 } 209 210 // Figure out which functions we should extract. 211 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { 212 GlobalValue *GV = M->getFunction(ExtractFuncs[i]); 213 if (!GV) { 214 errs() << argv[0] << ": program doesn't contain function named '" 215 << ExtractFuncs[i] << "'!\n"; 216 return 1; 217 } 218 GVs.insert(GV); 219 } 220 // Extract functions via regular expression matching. 221 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { 222 std::string Error; 223 StringRef RegExStr = ExtractRegExpFuncs[i]; 224 Regex RegEx(RegExStr); 225 if (!RegEx.isValid(Error)) { 226 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " 227 "invalid regex: " << Error; 228 } 229 bool match = false; 230 for (Module::iterator F = M->begin(), E = M->end(); F != E; 231 F++) { 232 if (RegEx.match(F->getName())) { 233 GVs.insert(&*F); 234 match = true; 235 } 236 } 237 if (!match) { 238 errs() << argv[0] << ": program doesn't contain global named '" 239 << ExtractRegExpFuncs[i] << "'!\n"; 240 return 1; 241 } 242 } 243 244 // Figure out which BasicBlocks we should extract. 245 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs; 246 for (StringRef StrPair : ExtractBlocks) { 247 auto BBInfo = StrPair.split(':'); 248 // Get the function. 
249 Function *F = M->getFunction(BBInfo.first); 250 if (!F) { 251 errs() << argv[0] << ": program doesn't contain a function named '" 252 << BBInfo.first << "'!\n"; 253 return 1; 254 } 255 // Do not materialize this function. 256 GVs.insert(F); 257 // Get the basic blocks. 258 SmallVector<BasicBlock *, 16> BBs; 259 SmallVector<StringRef, 16> BBNames; 260 BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, 261 /*KeepEmpty=*/false); 262 for (StringRef BBName : BBNames) { 263 auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { 264 return BB.getName().equals(BBName); 265 }); 266 if (Res == F->end()) { 267 errs() << argv[0] << ": function " << F->getName() 268 << " doesn't contain a basic block named '" << BBInfo.second 269 << "'!\n"; 270 return 1; 271 } 272 BBs.push_back(&*Res); 273 } 274 GroupOfBBs.push_back(BBs); 275 } 276 277 // Use *argv instead of argv[0] to work around a wrong GCC warning. 278 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: "); 279 280 if (Recursive) { 281 std::vector<llvm::Function *> Workqueue; 282 for (GlobalValue *GV : GVs) { 283 if (auto *F = dyn_cast<Function>(GV)) { 284 Workqueue.push_back(F); 285 } 286 } 287 while (!Workqueue.empty()) { 288 Function *F = &*Workqueue.back(); 289 Workqueue.pop_back(); 290 ExitOnErr(F->materialize()); 291 for (auto &BB : *F) { 292 for (auto &I : BB) { 293 CallBase *CB = dyn_cast<CallBase>(&I); 294 if (!CB) 295 continue; 296 Function *CF = CB->getCalledFunction(); 297 if (!CF) 298 continue; 299 if (CF->isDeclaration() || GVs.count(CF)) 300 continue; 301 GVs.insert(CF); 302 Workqueue.push_back(CF); 303 } 304 } 305 } 306 } 307 308 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; 309 310 // Materialize requisite global values. 311 if (!DeleteFn) { 312 for (size_t i = 0, e = GVs.size(); i != e; ++i) 313 Materialize(*GVs[i]); 314 } else { 315 // Deleting. Materialize every GV that's *not* in GVs. 
316 SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); 317 for (auto &F : *M) { 318 if (!GVSet.count(&F)) 319 Materialize(F); 320 } 321 } 322 323 { 324 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); 325 legacy::PassManager Extract; 326 Extract.add(createGVExtractionPass(Gvs, DeleteFn)); 327 Extract.run(*M); 328 329 // Now that we have all the GVs we want, mark the module as fully 330 // materialized. 331 // FIXME: should the GVExtractionPass handle this? 332 ExitOnErr(M->materializeAll()); 333 } 334 335 // Extract the specified basic blocks from the module and erase the existing 336 // functions. 337 if (!ExtractBlocks.empty()) { 338 legacy::PassManager PM; 339 PM.add(createBlockExtractorPass(GroupOfBBs, true)); 340 PM.run(*M); 341 } 342 343 // In addition to deleting all other functions, we also want to spiff it 344 // up a little bit. Do this now. 345 legacy::PassManager Passes; 346 347 if (!DeleteFn) 348 Passes.add(createGlobalDCEPass()); // Delete unreachable globals 349 Passes.add(createStripDeadDebugInfoPass()); // Remove dead debug info 350 Passes.add(createStripDeadPrototypesPass()); // Remove dead func decls 351 352 std::error_code EC; 353 ToolOutputFile Out(OutputFilename, EC, sys::fs::F_None); 354 if (EC) { 355 errs() << EC.message() << '\n'; 356 return 1; 357 } 358 359 if (OutputAssembly) 360 Passes.add( 361 createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); 362 else if (Force || !CheckBitcodeOutputToConsole(Out.os(), true)) 363 Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); 364 365 Passes.run(*M.get()); 366 367 // Declare success. 368 Out.keep(); 369 370 return 0; 371 } 372