xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/CallPrinter.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines '-dot-callgraph', which emits a <module-id>.callgraph.dot
// file (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is set)
// containing the call graph of a module.
11 //
12 // There is also a pass available to directly call dotty ('-view-callgraph').
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/CallPrinter.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Analysis/BlockFrequencyInfo.h"
20 #include "llvm/Analysis/CallGraph.h"
21 #include "llvm/Analysis/HeatUtils.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/DOTGraphTraits.h"
27 #include "llvm/Support/GraphWriter.h"
28 
29 using namespace llvm;
30 
namespace llvm {
// Forward declaration of the GraphTraits template, which is specialized below
// for CallGraphDOTInfo *. NOTE(review): probably redundant given the
// CallGraph.h / GraphWriter.h includes above; confirm before removing.
template <class GraphType> struct GraphTraits;
} // namespace llvm
34 
35 // This option shows static (relative) call counts.
36 // FIXME:
37 // Need to show real counts when profile data is available
38 static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
39                                     cl::Hidden,
40                                     cl::desc("Show heat colors in call-graph"));
41 
42 static cl::opt<bool>
43     ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
44                        cl::desc("Show edges labeled with weights"));
45 
46 static cl::opt<bool>
47     CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
48             cl::desc("Show call-multigraph (do not remove parallel edges)"));
49 
50 static cl::opt<std::string> CallGraphDotFilenamePrefix(
51     "callgraph-dot-filename-prefix", cl::Hidden,
52     cl::desc("The prefix used for the CallGraph dot file names."));
53 
54 namespace llvm {
55 
56 class CallGraphDOTInfo {
57 private:
58   Module *M;
59   CallGraph *CG;
60   DenseMap<const Function *, uint64_t> Freq;
61   uint64_t MaxFreq;
62 
63 public:
64   std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
65 
66   CallGraphDOTInfo(Module *M, CallGraph *CG,
67                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
68       : M(M), CG(CG), LookupBFI(LookupBFI) {
69     MaxFreq = 0;
70 
71     for (Function &F : M->getFunctionList()) {
72       uint64_t localSumFreq = 0;
73       SmallSet<Function *, 16> Callers;
74       for (User *U : F.users())
75         if (isa<CallInst>(U))
76           Callers.insert(cast<Instruction>(U)->getFunction());
77       for (Function *Caller : Callers)
78         localSumFreq += getNumOfCalls(*Caller, F);
79       if (localSumFreq >= MaxFreq)
80         MaxFreq = localSumFreq;
81       Freq[&F] = localSumFreq;
82     }
83     if (!CallMultiGraph)
84       removeParallelEdges();
85   }
86 
87   Module *getModule() const { return M; }
88 
89   CallGraph *getCallGraph() const { return CG; }
90 
91   uint64_t getFreq(const Function *F) { return Freq[F]; }
92 
93   uint64_t getMaxFreq() { return MaxFreq; }
94 
95 private:
96   void removeParallelEdges() {
97     for (auto &I : (*CG)) {
98       CallGraphNode *Node = I.second.get();
99 
100       bool FoundParallelEdge = true;
101       while (FoundParallelEdge) {
102         SmallSet<Function *, 16> Visited;
103         FoundParallelEdge = false;
104         for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
105           if (!(Visited.insert(CI->second->getFunction())).second) {
106             FoundParallelEdge = true;
107             Node->removeCallEdge(CI);
108             break;
109           }
110         }
111       }
112     }
113   }
114 };
115 
116 template <>
117 struct GraphTraits<CallGraphDOTInfo *>
118     : public GraphTraits<const CallGraphNode *> {
119   static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
120     // Start at the external node!
121     return CGInfo->getCallGraph()->getExternalCallingNode();
122   }
123 
124   typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
125       PairTy;
126   static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
127     return P.second.get();
128   }
129 
130   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
131   typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
132       nodes_iterator;
133 
134   static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
135     return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
136   }
137   static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
138     return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
139   }
140 };
141 
142 template <>
143 struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
144 
145   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
146 
147   static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
148     return "Call graph: " +
149            std::string(CGInfo->getModule()->getModuleIdentifier());
150   }
151 
152   static bool isNodeHidden(const CallGraphNode *Node,
153                            const CallGraphDOTInfo *CGInfo) {
154     if (CallMultiGraph || Node->getFunction())
155       return false;
156     return true;
157   }
158 
159   std::string getNodeLabel(const CallGraphNode *Node,
160                            CallGraphDOTInfo *CGInfo) {
161     if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
162       return "external caller";
163     if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
164       return "external callee";
165 
166     if (Function *Func = Node->getFunction())
167       return std::string(Func->getName());
168     return "external node";
169   }
170   static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
171     return P.second;
172   }
173 
174   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
175   typedef mapped_iterator<CallGraphNode::const_iterator,
176                           decltype(&CGGetValuePtr)>
177       nodes_iterator;
178 
179   std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
180                                 CallGraphDOTInfo *CGInfo) {
181     if (!ShowEdgeWeight)
182       return "";
183 
184     Function *Caller = Node->getFunction();
185     if (Caller == nullptr || Caller->isDeclaration())
186       return "";
187 
188     Function *Callee = (*I)->getFunction();
189     if (Callee == nullptr)
190       return "";
191 
192     uint64_t Counter = getNumOfCalls(*Caller, *Callee);
193     double Width =
194         1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
195     std::string Attrs = "label=\"" + std::to_string(Counter) +
196                         "\" penwidth=" + std::to_string(Width);
197     return Attrs;
198   }
199 
200   std::string getNodeAttributes(const CallGraphNode *Node,
201                                 CallGraphDOTInfo *CGInfo) {
202     Function *F = Node->getFunction();
203     if (F == nullptr)
204       return "";
205     std::string attrs;
206     if (ShowHeatColors) {
207       uint64_t freq = CGInfo->getFreq(F);
208       std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
209       std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
210                                   ? getHeatColor(0)
211                                   : getHeatColor(1);
212       attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
213               color + "80\"";
214     }
215     return attrs;
216   }
217 };
218 
219 } // namespace llvm
220 
221 namespace {
222 void doCallGraphDOTPrinting(
223     Module &M, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
224   std::string Filename;
225   if (!CallGraphDotFilenamePrefix.empty())
226     Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
227   else
228     Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
229   errs() << "Writing '" << Filename << "'...";
230 
231   std::error_code EC;
232   raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
233 
234   CallGraph CG(M);
235   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
236 
237   if (!EC)
238     WriteGraph(File, &CFGInfo);
239   else
240     errs() << "  error opening file for writing!";
241   errs() << "\n";
242 }
243 
244 void viewCallGraph(Module &M,
245                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
246   CallGraph CG(M);
247   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
248 
249   std::string Title =
250       DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
251   ViewGraph(&CFGInfo, "callgraph", true, Title);
252 }
253 } // namespace
254 
255 namespace llvm {
256 PreservedAnalyses CallGraphDOTPrinterPass::run(Module &M,
257                                                ModuleAnalysisManager &AM) {
258   FunctionAnalysisManager &FAM =
259       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
260 
261   auto LookupBFI = [&FAM](Function &F) {
262     return &FAM.getResult<BlockFrequencyAnalysis>(F);
263   };
264 
265   doCallGraphDOTPrinting(M, LookupBFI);
266 
267   return PreservedAnalyses::all();
268 }
269 
270 PreservedAnalyses CallGraphViewerPass::run(Module &M,
271                                            ModuleAnalysisManager &AM) {
272 
273   FunctionAnalysisManager &FAM =
274       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
275 
276   auto LookupBFI = [&FAM](Function &F) {
277     return &FAM.getResult<BlockFrequencyAnalysis>(F);
278   };
279 
280   viewCallGraph(M, LookupBFI);
281 
282   return PreservedAnalyses::all();
283 }
284 } // namespace llvm
285 
286 namespace {
287 // Viewer
288 class CallGraphViewer : public ModulePass {
289 public:
290   static char ID;
291   CallGraphViewer() : ModulePass(ID) {}
292 
293   void getAnalysisUsage(AnalysisUsage &AU) const override;
294   bool runOnModule(Module &M) override;
295 };
296 
297 void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
298   ModulePass::getAnalysisUsage(AU);
299   AU.addRequired<BlockFrequencyInfoWrapperPass>();
300   AU.setPreservesAll();
301 }
302 
303 bool CallGraphViewer::runOnModule(Module &M) {
304   auto LookupBFI = [this](Function &F) {
305     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
306   };
307 
308   viewCallGraph(M, LookupBFI);
309 
310   return false;
311 }
312 
313 // DOT Printer
314 
315 class CallGraphDOTPrinter : public ModulePass {
316 public:
317   static char ID;
318   CallGraphDOTPrinter() : ModulePass(ID) {}
319 
320   void getAnalysisUsage(AnalysisUsage &AU) const override;
321   bool runOnModule(Module &M) override;
322 };
323 
324 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
325   ModulePass::getAnalysisUsage(AU);
326   AU.addRequired<BlockFrequencyInfoWrapperPass>();
327   AU.setPreservesAll();
328 }
329 
330 bool CallGraphDOTPrinter::runOnModule(Module &M) {
331   auto LookupBFI = [this](Function &F) {
332     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
333   };
334 
335   doCallGraphDOTPrinting(M, LookupBFI);
336 
337   return false;
338 }
339 
340 } // end anonymous namespace
341 
// Storage for each legacy pass's static ID, which the pass constructors hand
// to ModulePass(ID). INITIALIZE_PASS registers each pass under its
// command-line name ("view-callgraph" / "dot-callgraph") and description.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only / is-analysis flags of INITIALIZE_PASS; confirm in PassSupport.h.
char CallGraphViewer::ID = 0;
INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
                false)

char CallGraphDOTPrinter::ID = 0;
INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
                "Print call graph to 'dot' file", false, false)
349 
// Factory functions exported from this file so that
// "include/llvm/LinkAllPasses.h" can reference them. Without such references
// the linker (e.g. under link-time optimization) could discard the passes.
353 
354 ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
355 
356 ModulePass *llvm::createCallGraphDOTPrinterPass() {
357   return new CallGraphDOTPrinter();
358 }
359