xref: /freebsd/contrib/llvm-project/llvm/tools/llvm-extract/llvm-extract.cpp (revision 1f1e2261e341e6ca6862f82261066ef1705f0a7a)
1 //===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
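// This utility reduces the input module to a selected set of global values
// (functions, global variables, aliases, or basic blocks outlined into new
// functions), or removes that set when -delete is given. It is primarily used
// for debugging transformations.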
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/Bitcode/BitcodeWriterPass.h"
17 #include "llvm/IR/DataLayout.h"
18 #include "llvm/IR/IRPrintingPasses.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/LLVMContext.h"
21 #include "llvm/IR/LegacyPassManager.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IRReader/IRReader.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Error.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/InitLLVM.h"
29 #include "llvm/Support/Regex.h"
30 #include "llvm/Support/SourceMgr.h"
31 #include "llvm/Support/SystemUtils.h"
32 #include "llvm/Support/ToolOutputFile.h"
33 #include "llvm/Transforms/IPO.h"
34 #include <memory>
35 #include <utility>
36 using namespace llvm;
37 
38 cl::OptionCategory ExtractCat("llvm-extract Options");
39 
40 // InputFilename - The filename to read from.
41 static cl::opt<std::string> InputFilename(cl::Positional,
42                                           cl::desc("<input bitcode file>"),
43                                           cl::init("-"),
44                                           cl::value_desc("filename"));
45 
46 static cl::opt<std::string> OutputFilename("o",
47                                            cl::desc("Specify output filename"),
48                                            cl::value_desc("filename"),
49                                            cl::init("-"), cl::cat(ExtractCat));
50 
51 static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
52                            cl::cat(ExtractCat));
53 
54 static cl::opt<bool> DeleteFn("delete",
55                               cl::desc("Delete specified Globals from Module"),
56                               cl::cat(ExtractCat));
57 
58 static cl::opt<bool> KeepConstInit("keep-const-init",
59                               cl::desc("Keep initializers of constants"),
60                               cl::cat(ExtractCat));
61 
62 static cl::opt<bool>
63     Recursive("recursive", cl::desc("Recursively extract all called functions"),
64               cl::cat(ExtractCat));
65 
66 // ExtractFuncs - The functions to extract from the module.
67 static cl::list<std::string>
68     ExtractFuncs("func", cl::desc("Specify function to extract"),
69                  cl::ZeroOrMore, cl::value_desc("function"),
70                  cl::cat(ExtractCat));
71 
72 // ExtractRegExpFuncs - The functions, matched via regular expression, to
73 // extract from the module.
74 static cl::list<std::string>
75     ExtractRegExpFuncs("rfunc",
76                        cl::desc("Specify function(s) to extract using a "
77                                 "regular expression"),
78                        cl::ZeroOrMore, cl::value_desc("rfunction"),
79                        cl::cat(ExtractCat));
80 
81 // ExtractBlocks - The blocks to extract from the module.
82 static cl::list<std::string> ExtractBlocks(
83     "bb",
84     cl::desc(
85         "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
86         "Each pair will create a function.\n"
87         "If multiple basic blocks are specified in one pair,\n"
88         "the first block in the sequence should dominate the rest.\n"
89         "eg:\n"
90         "  --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
91         "  --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
92         "with bb2."),
93     cl::ZeroOrMore, cl::value_desc("function:bb1[;bb2...]"),
94     cl::cat(ExtractCat));
95 
96 // ExtractAlias - The alias to extract from the module.
97 static cl::list<std::string>
98     ExtractAliases("alias", cl::desc("Specify alias to extract"),
99                    cl::ZeroOrMore, cl::value_desc("alias"),
100                    cl::cat(ExtractCat));
101 
102 // ExtractRegExpAliases - The aliases, matched via regular expression, to
103 // extract from the module.
104 static cl::list<std::string>
105     ExtractRegExpAliases("ralias",
106                          cl::desc("Specify alias(es) to extract using a "
107                                   "regular expression"),
108                          cl::ZeroOrMore, cl::value_desc("ralias"),
109                          cl::cat(ExtractCat));
110 
111 // ExtractGlobals - The globals to extract from the module.
112 static cl::list<std::string>
113     ExtractGlobals("glob", cl::desc("Specify global to extract"),
114                    cl::ZeroOrMore, cl::value_desc("global"),
115                    cl::cat(ExtractCat));
116 
117 // ExtractRegExpGlobals - The globals, matched via regular expression, to
118 // extract from the module...
119 static cl::list<std::string>
120     ExtractRegExpGlobals("rglob",
121                          cl::desc("Specify global(s) to extract using a "
122                                   "regular expression"),
123                          cl::ZeroOrMore, cl::value_desc("rglobal"),
124                          cl::cat(ExtractCat));
125 
126 static cl::opt<bool> OutputAssembly("S",
127                                     cl::desc("Write output as LLVM assembly"),
128                                     cl::Hidden, cl::cat(ExtractCat));
129 
130 static cl::opt<bool> PreserveBitcodeUseListOrder(
131     "preserve-bc-uselistorder",
132     cl::desc("Preserve use-list order when writing LLVM bitcode."),
133     cl::init(true), cl::Hidden, cl::cat(ExtractCat));
134 
135 static cl::opt<bool> PreserveAssemblyUseListOrder(
136     "preserve-ll-uselistorder",
137     cl::desc("Preserve use-list order when writing LLVM assembly."),
138     cl::init(false), cl::Hidden, cl::cat(ExtractCat));
139 
140 int main(int argc, char **argv) {
141   InitLLVM X(argc, argv);
142 
143   LLVMContext Context;
144   cl::HideUnrelatedOptions(ExtractCat);
145   cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
146 
147   // Use lazy loading, since we only care about selected global values.
148   SMDiagnostic Err;
149   std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context);
150 
151   if (!M.get()) {
152     Err.print(argv[0], errs());
153     return 1;
154   }
155 
156   // Use SetVector to avoid duplicates.
157   SetVector<GlobalValue *> GVs;
158 
159   // Figure out which aliases we should extract.
160   for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
161     GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]);
162     if (!GA) {
163       errs() << argv[0] << ": program doesn't contain alias named '"
164              << ExtractAliases[i] << "'!\n";
165       return 1;
166     }
167     GVs.insert(GA);
168   }
169 
170   // Extract aliases via regular expression matching.
171   for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
172     std::string Error;
173     Regex RegEx(ExtractRegExpAliases[i]);
174     if (!RegEx.isValid(Error)) {
175       errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
176         "invalid regex: " << Error;
177     }
178     bool match = false;
179     for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
180          GA != E; GA++) {
181       if (RegEx.match(GA->getName())) {
182         GVs.insert(&*GA);
183         match = true;
184       }
185     }
186     if (!match) {
187       errs() << argv[0] << ": program doesn't contain global named '"
188              << ExtractRegExpAliases[i] << "'!\n";
189       return 1;
190     }
191   }
192 
193   // Figure out which globals we should extract.
194   for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
195     GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
196     if (!GV) {
197       errs() << argv[0] << ": program doesn't contain global named '"
198              << ExtractGlobals[i] << "'!\n";
199       return 1;
200     }
201     GVs.insert(GV);
202   }
203 
204   // Extract globals via regular expression matching.
205   for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
206     std::string Error;
207     Regex RegEx(ExtractRegExpGlobals[i]);
208     if (!RegEx.isValid(Error)) {
209       errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
210         "invalid regex: " << Error;
211     }
212     bool match = false;
213     for (auto &GV : M->globals()) {
214       if (RegEx.match(GV.getName())) {
215         GVs.insert(&GV);
216         match = true;
217       }
218     }
219     if (!match) {
220       errs() << argv[0] << ": program doesn't contain global named '"
221              << ExtractRegExpGlobals[i] << "'!\n";
222       return 1;
223     }
224   }
225 
226   // Figure out which functions we should extract.
227   for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
228     GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
229     if (!GV) {
230       errs() << argv[0] << ": program doesn't contain function named '"
231              << ExtractFuncs[i] << "'!\n";
232       return 1;
233     }
234     GVs.insert(GV);
235   }
236   // Extract functions via regular expression matching.
237   for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
238     std::string Error;
239     StringRef RegExStr = ExtractRegExpFuncs[i];
240     Regex RegEx(RegExStr);
241     if (!RegEx.isValid(Error)) {
242       errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
243         "invalid regex: " << Error;
244     }
245     bool match = false;
246     for (Module::iterator F = M->begin(), E = M->end(); F != E;
247          F++) {
248       if (RegEx.match(F->getName())) {
249         GVs.insert(&*F);
250         match = true;
251       }
252     }
253     if (!match) {
254       errs() << argv[0] << ": program doesn't contain global named '"
255              << ExtractRegExpFuncs[i] << "'!\n";
256       return 1;
257     }
258   }
259 
260   // Figure out which BasicBlocks we should extract.
261   SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap;
262   for (StringRef StrPair : ExtractBlocks) {
263     SmallVector<StringRef, 16> BBNames;
264     auto BBInfo = StrPair.split(':');
265     // Get the function.
266     Function *F = M->getFunction(BBInfo.first);
267     if (!F) {
268       errs() << argv[0] << ": program doesn't contain a function named '"
269              << BBInfo.first << "'!\n";
270       return 1;
271     }
272     // Add the function to the materialize list, and store the basic block names
273     // to check after materialization.
274     GVs.insert(F);
275     BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
276     BBMap.push_back({F, std::move(BBNames)});
277   }
278 
279   // Use *argv instead of argv[0] to work around a wrong GCC warning.
280   ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
281 
282   if (Recursive) {
283     std::vector<llvm::Function *> Workqueue;
284     for (GlobalValue *GV : GVs) {
285       if (auto *F = dyn_cast<Function>(GV)) {
286         Workqueue.push_back(F);
287       }
288     }
289     while (!Workqueue.empty()) {
290       Function *F = &*Workqueue.back();
291       Workqueue.pop_back();
292       ExitOnErr(F->materialize());
293       for (auto &BB : *F) {
294         for (auto &I : BB) {
295           CallBase *CB = dyn_cast<CallBase>(&I);
296           if (!CB)
297             continue;
298           Function *CF = CB->getCalledFunction();
299           if (!CF)
300             continue;
301           if (CF->isDeclaration() || GVs.count(CF))
302             continue;
303           GVs.insert(CF);
304           Workqueue.push_back(CF);
305         }
306       }
307     }
308   }
309 
310   auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
311 
312   // Materialize requisite global values.
313   if (!DeleteFn) {
314     for (size_t i = 0, e = GVs.size(); i != e; ++i)
315       Materialize(*GVs[i]);
316   } else {
317     // Deleting. Materialize every GV that's *not* in GVs.
318     SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end());
319     for (auto &F : *M) {
320       if (!GVSet.count(&F))
321         Materialize(F);
322     }
323   }
324 
325   {
326     std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
327     legacy::PassManager Extract;
328     Extract.add(createGVExtractionPass(Gvs, DeleteFn, KeepConstInit));
329     Extract.run(*M);
330 
331     // Now that we have all the GVs we want, mark the module as fully
332     // materialized.
333     // FIXME: should the GVExtractionPass handle this?
334     ExitOnErr(M->materializeAll());
335   }
336 
337   // Extract the specified basic blocks from the module and erase the existing
338   // functions.
339   if (!ExtractBlocks.empty()) {
340     // Figure out which BasicBlocks we should extract.
341     SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs;
342     for (auto &P : BBMap) {
343       SmallVector<BasicBlock *, 16> BBs;
344       for (StringRef BBName : P.second) {
345         // The function has been materialized, so add its matching basic blocks
346         // to the block extractor list, or fail if a name is not found.
347         auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) {
348           return BB.getName().equals(BBName);
349         });
350         if (Res == P.first->end()) {
351           errs() << argv[0] << ": function " << P.first->getName()
352                  << " doesn't contain a basic block named '" << BBName
353                  << "'!\n";
354           return 1;
355         }
356         BBs.push_back(&*Res);
357       }
358       GroupOfBBs.push_back(BBs);
359     }
360 
361     legacy::PassManager PM;
362     PM.add(createBlockExtractorPass(GroupOfBBs, true));
363     PM.run(*M);
364   }
365 
366   // In addition to deleting all other functions, we also want to spiff it
367   // up a little bit.  Do this now.
368   legacy::PassManager Passes;
369 
370   if (!DeleteFn)
371     Passes.add(createGlobalDCEPass());           // Delete unreachable globals
372   Passes.add(createStripDeadDebugInfoPass());    // Remove dead debug info
373   Passes.add(createStripDeadPrototypesPass());   // Remove dead func decls
374 
375   std::error_code EC;
376   ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
377   if (EC) {
378     errs() << EC.message() << '\n';
379     return 1;
380   }
381 
382   if (OutputAssembly)
383     Passes.add(
384         createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
385   else if (Force || !CheckBitcodeOutputToConsole(Out.os()))
386     Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
387 
388   Passes.run(*M.get());
389 
390   // Declare success.
391   Out.keep();
392 
393   return 0;
394 }
395