xref: /freebsd/contrib/llvm-project/llvm/tools/llvm-extract/llvm-extract.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
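// This utility changes the input module to only contain the selected
// functions, globals, aliases and/or basic blocks (or, with -delete, to
// remove the selected globals instead), which is primarily used for debugging
// transformations.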
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/Bitcode/BitcodeWriterPass.h"
17 #include "llvm/IR/DataLayout.h"
18 #include "llvm/IR/IRPrintingPasses.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/LLVMContext.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/IRPrinter/IRPrintingPasses.h"
23 #include "llvm/IRReader/IRReader.h"
24 #include "llvm/Passes/PassBuilder.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Error.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/InitLLVM.h"
29 #include "llvm/Support/Regex.h"
30 #include "llvm/Support/SourceMgr.h"
31 #include "llvm/Support/SystemUtils.h"
32 #include "llvm/Support/ToolOutputFile.h"
33 #include "llvm/Transforms/IPO.h"
34 #include "llvm/Transforms/IPO/BlockExtractor.h"
35 #include "llvm/Transforms/IPO/ExtractGV.h"
36 #include "llvm/Transforms/IPO/GlobalDCE.h"
37 #include "llvm/Transforms/IPO/StripDeadPrototypes.h"
38 #include "llvm/Transforms/IPO/StripSymbols.h"
39 #include <memory>
40 #include <utility>
41 
42 using namespace llvm;
43 
44 static cl::OptionCategory ExtractCat("llvm-extract Options");
45 
46 // InputFilename - The filename to read from.
47 static cl::opt<std::string> InputFilename(cl::Positional,
48                                           cl::desc("<input bitcode file>"),
49                                           cl::init("-"),
50                                           cl::value_desc("filename"));
51 
52 static cl::opt<std::string> OutputFilename("o",
53                                            cl::desc("Specify output filename"),
54                                            cl::value_desc("filename"),
55                                            cl::init("-"), cl::cat(ExtractCat));
56 
57 static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
58                            cl::cat(ExtractCat));
59 
60 static cl::opt<bool> DeleteFn("delete",
61                               cl::desc("Delete specified Globals from Module"),
62                               cl::cat(ExtractCat));
63 
64 static cl::opt<bool> KeepConstInit("keep-const-init",
65                               cl::desc("Keep initializers of constants"),
66                               cl::cat(ExtractCat));
67 
68 static cl::opt<bool>
69     Recursive("recursive", cl::desc("Recursively extract all called functions"),
70               cl::cat(ExtractCat));
71 
72 // ExtractFuncs - The functions to extract from the module.
73 static cl::list<std::string>
74     ExtractFuncs("func", cl::desc("Specify function to extract"),
75                  cl::value_desc("function"), cl::cat(ExtractCat));
76 
77 // ExtractRegExpFuncs - The functions, matched via regular expression, to
78 // extract from the module.
79 static cl::list<std::string>
80     ExtractRegExpFuncs("rfunc",
81                        cl::desc("Specify function(s) to extract using a "
82                                 "regular expression"),
83                        cl::value_desc("rfunction"), cl::cat(ExtractCat));
84 
85 // ExtractBlocks - The blocks to extract from the module.
86 static cl::list<std::string> ExtractBlocks(
87     "bb",
88     cl::desc(
89         "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
90         "Each pair will create a function.\n"
91         "If multiple basic blocks are specified in one pair,\n"
92         "the first block in the sequence should dominate the rest.\n"
93         "If an unnamed basic block is to be extracted,\n"
94         "'%' should be added before the basic block variable names.\n"
95         "eg:\n"
96         "  --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
97         "  --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
98         "with bb2.\n"
99         "  --bb=f:%1 will extract one function with basic block 1;"),
100     cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat));
101 
102 // ExtractAlias - The alias to extract from the module.
103 static cl::list<std::string>
104     ExtractAliases("alias", cl::desc("Specify alias to extract"),
105                    cl::value_desc("alias"), cl::cat(ExtractCat));
106 
107 // ExtractRegExpAliases - The aliases, matched via regular expression, to
108 // extract from the module.
109 static cl::list<std::string>
110     ExtractRegExpAliases("ralias",
111                          cl::desc("Specify alias(es) to extract using a "
112                                   "regular expression"),
113                          cl::value_desc("ralias"), cl::cat(ExtractCat));
114 
115 // ExtractGlobals - The globals to extract from the module.
116 static cl::list<std::string>
117     ExtractGlobals("glob", cl::desc("Specify global to extract"),
118                    cl::value_desc("global"), cl::cat(ExtractCat));
119 
120 // ExtractRegExpGlobals - The globals, matched via regular expression, to
121 // extract from the module...
122 static cl::list<std::string>
123     ExtractRegExpGlobals("rglob",
124                          cl::desc("Specify global(s) to extract using a "
125                                   "regular expression"),
126                          cl::value_desc("rglobal"), cl::cat(ExtractCat));
127 
128 static cl::opt<bool> OutputAssembly("S",
129                                     cl::desc("Write output as LLVM assembly"),
130                                     cl::Hidden, cl::cat(ExtractCat));
131 
132 static cl::opt<bool> PreserveBitcodeUseListOrder(
133     "preserve-bc-uselistorder",
134     cl::desc("Preserve use-list order when writing LLVM bitcode."),
135     cl::init(true), cl::Hidden, cl::cat(ExtractCat));
136 
137 static cl::opt<bool> PreserveAssemblyUseListOrder(
138     "preserve-ll-uselistorder",
139     cl::desc("Preserve use-list order when writing LLVM assembly."),
140     cl::init(false), cl::Hidden, cl::cat(ExtractCat));
141 
main(int argc,char ** argv)142 int main(int argc, char **argv) {
143   InitLLVM X(argc, argv);
144 
145   LLVMContext Context;
146   cl::HideUnrelatedOptions(ExtractCat);
147   cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
148 
149   // Use lazy loading, since we only care about selected global values.
150   SMDiagnostic Err;
151   std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context);
152 
153   if (!M) {
154     Err.print(argv[0], errs());
155     return 1;
156   }
157 
158   // Use SetVector to avoid duplicates.
159   SetVector<GlobalValue *> GVs;
160 
161   // Figure out which aliases we should extract.
162   for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
163     GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]);
164     if (!GA) {
165       errs() << argv[0] << ": program doesn't contain alias named '"
166              << ExtractAliases[i] << "'!\n";
167       return 1;
168     }
169     GVs.insert(GA);
170   }
171 
172   // Extract aliases via regular expression matching.
173   for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
174     std::string Error;
175     Regex RegEx(ExtractRegExpAliases[i]);
176     if (!RegEx.isValid(Error)) {
177       errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
178         "invalid regex: " << Error;
179     }
180     bool match = false;
181     for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
182          GA != E; GA++) {
183       if (RegEx.match(GA->getName())) {
184         GVs.insert(&*GA);
185         match = true;
186       }
187     }
188     if (!match) {
189       errs() << argv[0] << ": program doesn't contain global named '"
190              << ExtractRegExpAliases[i] << "'!\n";
191       return 1;
192     }
193   }
194 
195   // Figure out which globals we should extract.
196   for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
197     GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
198     if (!GV) {
199       errs() << argv[0] << ": program doesn't contain global named '"
200              << ExtractGlobals[i] << "'!\n";
201       return 1;
202     }
203     GVs.insert(GV);
204   }
205 
206   // Extract globals via regular expression matching.
207   for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
208     std::string Error;
209     Regex RegEx(ExtractRegExpGlobals[i]);
210     if (!RegEx.isValid(Error)) {
211       errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
212         "invalid regex: " << Error;
213     }
214     bool match = false;
215     for (auto &GV : M->globals()) {
216       if (RegEx.match(GV.getName())) {
217         GVs.insert(&GV);
218         match = true;
219       }
220     }
221     if (!match) {
222       errs() << argv[0] << ": program doesn't contain global named '"
223              << ExtractRegExpGlobals[i] << "'!\n";
224       return 1;
225     }
226   }
227 
228   // Figure out which functions we should extract.
229   for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
230     GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
231     if (!GV) {
232       errs() << argv[0] << ": program doesn't contain function named '"
233              << ExtractFuncs[i] << "'!\n";
234       return 1;
235     }
236     GVs.insert(GV);
237   }
238   // Extract functions via regular expression matching.
239   for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
240     std::string Error;
241     StringRef RegExStr = ExtractRegExpFuncs[i];
242     Regex RegEx(RegExStr);
243     if (!RegEx.isValid(Error)) {
244       errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
245         "invalid regex: " << Error;
246     }
247     bool match = false;
248     for (Module::iterator F = M->begin(), E = M->end(); F != E;
249          F++) {
250       if (RegEx.match(F->getName())) {
251         GVs.insert(&*F);
252         match = true;
253       }
254     }
255     if (!match) {
256       errs() << argv[0] << ": program doesn't contain global named '"
257              << ExtractRegExpFuncs[i] << "'!\n";
258       return 1;
259     }
260   }
261 
262   // Figure out which BasicBlocks we should extract.
263   SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap;
264   for (StringRef StrPair : ExtractBlocks) {
265     SmallVector<StringRef, 16> BBNames;
266     auto BBInfo = StrPair.split(':');
267     // Get the function.
268     Function *F = M->getFunction(BBInfo.first);
269     if (!F) {
270       errs() << argv[0] << ": program doesn't contain a function named '"
271              << BBInfo.first << "'!\n";
272       return 1;
273     }
274     // Add the function to the materialize list, and store the basic block names
275     // to check after materialization.
276     GVs.insert(F);
277     BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
278     BBMap.push_back({F, std::move(BBNames)});
279   }
280 
281   // Use *argv instead of argv[0] to work around a wrong GCC warning.
282   ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
283 
284   if (Recursive) {
285     std::vector<llvm::Function *> Workqueue;
286     for (GlobalValue *GV : GVs) {
287       if (auto *F = dyn_cast<Function>(GV)) {
288         Workqueue.push_back(F);
289       }
290     }
291     while (!Workqueue.empty()) {
292       Function *F = &*Workqueue.back();
293       Workqueue.pop_back();
294       ExitOnErr(F->materialize());
295       for (auto &BB : *F) {
296         for (auto &I : BB) {
297           CallBase *CB = dyn_cast<CallBase>(&I);
298           if (!CB)
299             continue;
300           Function *CF = CB->getCalledFunction();
301           if (!CF)
302             continue;
303           if (CF->isDeclaration() || !GVs.insert(CF))
304             continue;
305           Workqueue.push_back(CF);
306         }
307       }
308     }
309   }
310 
311   auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
312 
313   // Materialize requisite global values.
314   if (!DeleteFn) {
315     for (size_t i = 0, e = GVs.size(); i != e; ++i)
316       Materialize(*GVs[i]);
317   } else {
318     // Deleting. Materialize every GV that's *not* in GVs.
319     SmallPtrSet<GlobalValue *, 8> GVSet(llvm::from_range, GVs);
320     for (auto &F : *M) {
321       if (!GVSet.count(&F))
322         Materialize(F);
323     }
324   }
325 
326   {
327     std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
328     LoopAnalysisManager LAM;
329     FunctionAnalysisManager FAM;
330     CGSCCAnalysisManager CGAM;
331     ModuleAnalysisManager MAM;
332 
333     PassBuilder PB;
334 
335     PB.registerModuleAnalyses(MAM);
336     PB.registerCGSCCAnalyses(CGAM);
337     PB.registerFunctionAnalyses(FAM);
338     PB.registerLoopAnalyses(LAM);
339     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
340 
341     ModulePassManager PM;
342     PM.addPass(ExtractGVPass(Gvs, DeleteFn, KeepConstInit));
343     PM.run(*M, MAM);
344 
345     // Now that we have all the GVs we want, mark the module as fully
346     // materialized.
347     // FIXME: should the GVExtractionPass handle this?
348     ExitOnErr(M->materializeAll());
349   }
350 
351   // Extract the specified basic blocks from the module and erase the existing
352   // functions.
353   if (!ExtractBlocks.empty()) {
354     // Figure out which BasicBlocks we should extract.
355     std::vector<std::vector<BasicBlock *>> GroupOfBBs;
356     for (auto &P : BBMap) {
357       std::vector<BasicBlock *> BBs;
358       for (StringRef BBName : P.second) {
359         // The function has been materialized, so add its matching basic blocks
360         // to the block extractor list, or fail if a name is not found.
361         auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) {
362           return BB.getNameOrAsOperand() == BBName;
363         });
364         if (Res == P.first->end()) {
365           errs() << argv[0] << ": function " << P.first->getName()
366                  << " doesn't contain a basic block named '" << BBName
367                  << "'!\n";
368           return 1;
369         }
370         BBs.push_back(&*Res);
371       }
372       GroupOfBBs.push_back(BBs);
373     }
374 
375     LoopAnalysisManager LAM;
376     FunctionAnalysisManager FAM;
377     CGSCCAnalysisManager CGAM;
378     ModuleAnalysisManager MAM;
379 
380     PassBuilder PB;
381 
382     PB.registerModuleAnalyses(MAM);
383     PB.registerCGSCCAnalyses(CGAM);
384     PB.registerFunctionAnalyses(FAM);
385     PB.registerLoopAnalyses(LAM);
386     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
387 
388     ModulePassManager PM;
389     PM.addPass(BlockExtractorPass(std::move(GroupOfBBs), true));
390     PM.run(*M, MAM);
391   }
392 
393   // In addition to deleting all other functions, we also want to spiff it
394   // up a little bit.  Do this now.
395 
396   LoopAnalysisManager LAM;
397   FunctionAnalysisManager FAM;
398   CGSCCAnalysisManager CGAM;
399   ModuleAnalysisManager MAM;
400 
401   PassBuilder PB;
402 
403   PB.registerModuleAnalyses(MAM);
404   PB.registerCGSCCAnalyses(CGAM);
405   PB.registerFunctionAnalyses(FAM);
406   PB.registerLoopAnalyses(LAM);
407   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
408 
409   ModulePassManager PM;
410   if (!DeleteFn)
411     PM.addPass(GlobalDCEPass());
412   PM.addPass(StripDeadDebugInfoPass());
413   PM.addPass(StripDeadPrototypesPass());
414   PM.addPass(StripDeadCGProfilePass());
415 
416   std::error_code EC;
417   ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
418   if (EC) {
419     errs() << EC.message() << '\n';
420     return 1;
421   }
422 
423   if (OutputAssembly)
424     PM.addPass(PrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
425   else if (Force || !CheckBitcodeOutputToConsole(Out.os()))
426     PM.addPass(BitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
427 
428   PM.run(*M, MAM);
429 
430   // Declare success.
431   Out.keep();
432 
433   return 0;
434 }
435