//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This utility changes the input module to only contain the global values
// selected on the command line (functions, global variables, aliases, or
// basic blocks of a function), or, with -delete, to drop them instead. It is
// primarily used for debugging transformations.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Bitcode/BitcodeWriterPass.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/IPO.h"
#include <memory>
#include <utility>
using namespace llvm;

// ExtractCat - The category attached to this tool's own options; main() hands
// it to cl::HideUnrelatedOptions. It must be defined before the options below
// that reference it.
cl::OptionCategory ExtractCat("llvm-extract Options");

// InputFilename - The filename to read from.
static cl::opt<std::string> InputFilename(cl::Positional,
                                          cl::desc("<input bitcode file>"),
                                          cl::init("-"),
                                          cl::value_desc("filename"));

// OutputFilename - The filename main() opens for the result (defaults to "-").
static cl::opt<std::string> OutputFilename("o",
                                           cl::desc("Specify output filename"),
                                           cl::value_desc("filename"),
                                           cl::init("-"), cl::cat(ExtractCat));

// Force - When set, main() writes bitcode without consulting
// CheckBitcodeOutputToConsole.
static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
                           cl::cat(ExtractCat));

// DeleteFn - Forwarded to the GV extraction pass. When set, main() materializes
// the functions that were *not* selected and skips the GlobalDCE cleanup.
static cl::opt<bool> DeleteFn("delete",
                              cl::desc("Delete specified Globals from Module"),
                              cl::cat(ExtractCat));

// KeepConstInit - Forwarded unchanged to the GV extraction pass.
static cl::opt<bool> KeepConstInit("keep-const-init",
                                   cl::desc("Keep initializers of constants"),
                                   cl::cat(ExtractCat));

// Recursive - When set, main() also selects every defined function that is
// directly called (transitively) from an already selected function.
static cl::opt<bool>
    Recursive("recursive", cl::desc("Recursively extract all called functions"),
              cl::cat(ExtractCat));

// ExtractFuncs - The functions to extract from the module.
static cl::list<std::string>
    ExtractFuncs("func", cl::desc("Specify function to extract"),
                 cl::value_desc("function"), cl::cat(ExtractCat));

// ExtractRegExpFuncs - The functions, matched via regular expression, to
// extract from the module.
static cl::list<std::string>
    ExtractRegExpFuncs("rfunc",
                       cl::desc("Specify function(s) to extract using a "
                                "regular expression"),
                       cl::value_desc("rfunction"), cl::cat(ExtractCat));

// ExtractBlocks - The blocks to extract from the module.
static cl::list<std::string> ExtractBlocks(
    "bb",
    cl::desc(
        "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
        "Each pair will create a function.\n"
        "If multiple basic blocks are specified in one pair,\n"
        "the first block in the sequence should dominate the rest.\n"
        "eg:\n"
        " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
        " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
        "with bb2."),
    cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat));

// ExtractAliases - The aliases to extract from the module.
static cl::list<std::string>
    ExtractAliases("alias", cl::desc("Specify alias to extract"),
                   cl::value_desc("alias"), cl::cat(ExtractCat));

// ExtractRegExpAliases - The aliases, matched via regular expression, to
// extract from the module.
static cl::list<std::string>
    ExtractRegExpAliases("ralias",
                         cl::desc("Specify alias(es) to extract using a "
                                  "regular expression"),
                         cl::value_desc("ralias"), cl::cat(ExtractCat));

// ExtractGlobals - The globals to extract from the module.
static cl::list<std::string>
    ExtractGlobals("glob", cl::desc("Specify global to extract"),
                   cl::value_desc("global"), cl::cat(ExtractCat));

// ExtractRegExpGlobals - The globals, matched via regular expression, to
// extract from the module.
113 static cl::list<std::string> 114 ExtractRegExpGlobals("rglob", 115 cl::desc("Specify global(s) to extract using a " 116 "regular expression"), 117 cl::value_desc("rglobal"), cl::cat(ExtractCat)); 118 119 static cl::opt<bool> OutputAssembly("S", 120 cl::desc("Write output as LLVM assembly"), 121 cl::Hidden, cl::cat(ExtractCat)); 122 123 static cl::opt<bool> PreserveBitcodeUseListOrder( 124 "preserve-bc-uselistorder", 125 cl::desc("Preserve use-list order when writing LLVM bitcode."), 126 cl::init(true), cl::Hidden, cl::cat(ExtractCat)); 127 128 static cl::opt<bool> PreserveAssemblyUseListOrder( 129 "preserve-ll-uselistorder", 130 cl::desc("Preserve use-list order when writing LLVM assembly."), 131 cl::init(false), cl::Hidden, cl::cat(ExtractCat)); 132 133 int main(int argc, char **argv) { 134 InitLLVM X(argc, argv); 135 136 LLVMContext Context; 137 cl::HideUnrelatedOptions(ExtractCat); 138 cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n"); 139 140 // Use lazy loading, since we only care about selected global values. 141 SMDiagnostic Err; 142 std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context); 143 144 if (!M.get()) { 145 Err.print(argv[0], errs()); 146 return 1; 147 } 148 149 // Use SetVector to avoid duplicates. 150 SetVector<GlobalValue *> GVs; 151 152 // Figure out which aliases we should extract. 153 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { 154 GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]); 155 if (!GA) { 156 errs() << argv[0] << ": program doesn't contain alias named '" 157 << ExtractAliases[i] << "'!\n"; 158 return 1; 159 } 160 GVs.insert(GA); 161 } 162 163 // Extract aliases via regular expression matching. 
164 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { 165 std::string Error; 166 Regex RegEx(ExtractRegExpAliases[i]); 167 if (!RegEx.isValid(Error)) { 168 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " 169 "invalid regex: " << Error; 170 } 171 bool match = false; 172 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); 173 GA != E; GA++) { 174 if (RegEx.match(GA->getName())) { 175 GVs.insert(&*GA); 176 match = true; 177 } 178 } 179 if (!match) { 180 errs() << argv[0] << ": program doesn't contain global named '" 181 << ExtractRegExpAliases[i] << "'!\n"; 182 return 1; 183 } 184 } 185 186 // Figure out which globals we should extract. 187 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { 188 GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]); 189 if (!GV) { 190 errs() << argv[0] << ": program doesn't contain global named '" 191 << ExtractGlobals[i] << "'!\n"; 192 return 1; 193 } 194 GVs.insert(GV); 195 } 196 197 // Extract globals via regular expression matching. 198 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { 199 std::string Error; 200 Regex RegEx(ExtractRegExpGlobals[i]); 201 if (!RegEx.isValid(Error)) { 202 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " 203 "invalid regex: " << Error; 204 } 205 bool match = false; 206 for (auto &GV : M->globals()) { 207 if (RegEx.match(GV.getName())) { 208 GVs.insert(&GV); 209 match = true; 210 } 211 } 212 if (!match) { 213 errs() << argv[0] << ": program doesn't contain global named '" 214 << ExtractRegExpGlobals[i] << "'!\n"; 215 return 1; 216 } 217 } 218 219 // Figure out which functions we should extract. 
220 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { 221 GlobalValue *GV = M->getFunction(ExtractFuncs[i]); 222 if (!GV) { 223 errs() << argv[0] << ": program doesn't contain function named '" 224 << ExtractFuncs[i] << "'!\n"; 225 return 1; 226 } 227 GVs.insert(GV); 228 } 229 // Extract functions via regular expression matching. 230 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { 231 std::string Error; 232 StringRef RegExStr = ExtractRegExpFuncs[i]; 233 Regex RegEx(RegExStr); 234 if (!RegEx.isValid(Error)) { 235 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " 236 "invalid regex: " << Error; 237 } 238 bool match = false; 239 for (Module::iterator F = M->begin(), E = M->end(); F != E; 240 F++) { 241 if (RegEx.match(F->getName())) { 242 GVs.insert(&*F); 243 match = true; 244 } 245 } 246 if (!match) { 247 errs() << argv[0] << ": program doesn't contain global named '" 248 << ExtractRegExpFuncs[i] << "'!\n"; 249 return 1; 250 } 251 } 252 253 // Figure out which BasicBlocks we should extract. 254 SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; 255 for (StringRef StrPair : ExtractBlocks) { 256 SmallVector<StringRef, 16> BBNames; 257 auto BBInfo = StrPair.split(':'); 258 // Get the function. 259 Function *F = M->getFunction(BBInfo.first); 260 if (!F) { 261 errs() << argv[0] << ": program doesn't contain a function named '" 262 << BBInfo.first << "'!\n"; 263 return 1; 264 } 265 // Add the function to the materialize list, and store the basic block names 266 // to check after materialization. 267 GVs.insert(F); 268 BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); 269 BBMap.push_back({F, std::move(BBNames)}); 270 } 271 272 // Use *argv instead of argv[0] to work around a wrong GCC warning. 
273 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: "); 274 275 if (Recursive) { 276 std::vector<llvm::Function *> Workqueue; 277 for (GlobalValue *GV : GVs) { 278 if (auto *F = dyn_cast<Function>(GV)) { 279 Workqueue.push_back(F); 280 } 281 } 282 while (!Workqueue.empty()) { 283 Function *F = &*Workqueue.back(); 284 Workqueue.pop_back(); 285 ExitOnErr(F->materialize()); 286 for (auto &BB : *F) { 287 for (auto &I : BB) { 288 CallBase *CB = dyn_cast<CallBase>(&I); 289 if (!CB) 290 continue; 291 Function *CF = CB->getCalledFunction(); 292 if (!CF) 293 continue; 294 if (CF->isDeclaration() || GVs.count(CF)) 295 continue; 296 GVs.insert(CF); 297 Workqueue.push_back(CF); 298 } 299 } 300 } 301 } 302 303 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; 304 305 // Materialize requisite global values. 306 if (!DeleteFn) { 307 for (size_t i = 0, e = GVs.size(); i != e; ++i) 308 Materialize(*GVs[i]); 309 } else { 310 // Deleting. Materialize every GV that's *not* in GVs. 311 SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); 312 for (auto &F : *M) { 313 if (!GVSet.count(&F)) 314 Materialize(F); 315 } 316 } 317 318 { 319 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); 320 legacy::PassManager Extract; 321 Extract.add(createGVExtractionPass(Gvs, DeleteFn, KeepConstInit)); 322 Extract.run(*M); 323 324 // Now that we have all the GVs we want, mark the module as fully 325 // materialized. 326 // FIXME: should the GVExtractionPass handle this? 327 ExitOnErr(M->materializeAll()); 328 } 329 330 // Extract the specified basic blocks from the module and erase the existing 331 // functions. 332 if (!ExtractBlocks.empty()) { 333 // Figure out which BasicBlocks we should extract. 
334 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs; 335 for (auto &P : BBMap) { 336 SmallVector<BasicBlock *, 16> BBs; 337 for (StringRef BBName : P.second) { 338 // The function has been materialized, so add its matching basic blocks 339 // to the block extractor list, or fail if a name is not found. 340 auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) { 341 return BB.getName().equals(BBName); 342 }); 343 if (Res == P.first->end()) { 344 errs() << argv[0] << ": function " << P.first->getName() 345 << " doesn't contain a basic block named '" << BBName 346 << "'!\n"; 347 return 1; 348 } 349 BBs.push_back(&*Res); 350 } 351 GroupOfBBs.push_back(BBs); 352 } 353 354 legacy::PassManager PM; 355 PM.add(createBlockExtractorPass(GroupOfBBs, true)); 356 PM.run(*M); 357 } 358 359 // In addition to deleting all other functions, we also want to spiff it 360 // up a little bit. Do this now. 361 legacy::PassManager Passes; 362 363 if (!DeleteFn) 364 Passes.add(createGlobalDCEPass()); // Delete unreachable globals 365 Passes.add(createStripDeadDebugInfoPass()); // Remove dead debug info 366 Passes.add(createStripDeadPrototypesPass()); // Remove dead func decls 367 368 std::error_code EC; 369 ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); 370 if (EC) { 371 errs() << EC.message() << '\n'; 372 return 1; 373 } 374 375 if (OutputAssembly) 376 Passes.add( 377 createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); 378 else if (Force || !CheckBitcodeOutputToConsole(Out.os())) 379 Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); 380 381 Passes.run(*M.get()); 382 383 // Declare success. 384 Out.keep(); 385 386 return 0; 387 } 388