xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/CallPrinter.cpp (revision 924226fba12cc9a228c73b956e1b7fa24c60b055)
1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines '-dot-callgraph', which emits a <module>.callgraph.dot file
// (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is given)
// containing the call graph of a module.
11 //
12 // There is also a pass available to directly call dotty ('-view-callgraph').
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/CallPrinter.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/BranchProbabilityInfo.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/DOTGraphTraitsPass.h"
21 #include "llvm/Analysis/HeatUtils.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/SmallSet.h"
26 
27 using namespace llvm;
28 
29 // This option shows static (relative) call counts.
30 // FIXME:
31 // Need to show real counts when profile data is available
32 static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
33                                     cl::Hidden,
34                                     cl::desc("Show heat colors in call-graph"));
35 
36 static cl::opt<bool>
37     ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
38                        cl::desc("Show edges labeled with weights"));
39 
40 static cl::opt<bool>
41     CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
42             cl::desc("Show call-multigraph (do not remove parallel edges)"));
43 
44 static cl::opt<std::string> CallGraphDotFilenamePrefix(
45     "callgraph-dot-filename-prefix", cl::Hidden,
46     cl::desc("The prefix used for the CallGraph dot file names."));
47 
48 namespace llvm {
49 
50 class CallGraphDOTInfo {
51 private:
52   Module *M;
53   CallGraph *CG;
54   DenseMap<const Function *, uint64_t> Freq;
55   uint64_t MaxFreq;
56 
57 public:
58   std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
59 
60   CallGraphDOTInfo(Module *M, CallGraph *CG,
61                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
62       : M(M), CG(CG), LookupBFI(LookupBFI) {
63     MaxFreq = 0;
64 
65     for (Function &F : M->getFunctionList()) {
66       uint64_t localSumFreq = 0;
67       SmallSet<Function *, 16> Callers;
68       for (User *U : F.users())
69         if (isa<CallInst>(U))
70           Callers.insert(cast<Instruction>(U)->getFunction());
71       for (Function *Caller : Callers)
72         localSumFreq += getNumOfCalls(*Caller, F);
73       if (localSumFreq >= MaxFreq)
74         MaxFreq = localSumFreq;
75       Freq[&F] = localSumFreq;
76     }
77     if (!CallMultiGraph)
78       removeParallelEdges();
79   }
80 
81   Module *getModule() const { return M; }
82 
83   CallGraph *getCallGraph() const { return CG; }
84 
85   uint64_t getFreq(const Function *F) { return Freq[F]; }
86 
87   uint64_t getMaxFreq() { return MaxFreq; }
88 
89 private:
90   void removeParallelEdges() {
91     for (auto &I : (*CG)) {
92       CallGraphNode *Node = I.second.get();
93 
94       bool FoundParallelEdge = true;
95       while (FoundParallelEdge) {
96         SmallSet<Function *, 16> Visited;
97         FoundParallelEdge = false;
98         for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
99           if (!(Visited.insert(CI->second->getFunction())).second) {
100             FoundParallelEdge = true;
101             Node->removeCallEdge(CI);
102             break;
103           }
104         }
105       }
106     }
107   }
108 };
109 
110 template <>
111 struct GraphTraits<CallGraphDOTInfo *>
112     : public GraphTraits<const CallGraphNode *> {
113   static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
114     // Start at the external node!
115     return CGInfo->getCallGraph()->getExternalCallingNode();
116   }
117 
118   typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
119       PairTy;
120   static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
121     return P.second.get();
122   }
123 
124   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
125   typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
126       nodes_iterator;
127 
128   static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
129     return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
130   }
131   static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
132     return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
133   }
134 };
135 
136 template <>
137 struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
138 
139   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
140 
141   static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
142     return "Call graph: " +
143            std::string(CGInfo->getModule()->getModuleIdentifier());
144   }
145 
146   static bool isNodeHidden(const CallGraphNode *Node,
147                            const CallGraphDOTInfo *CGInfo) {
148     if (CallMultiGraph || Node->getFunction())
149       return false;
150     return true;
151   }
152 
153   std::string getNodeLabel(const CallGraphNode *Node,
154                            CallGraphDOTInfo *CGInfo) {
155     if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
156       return "external caller";
157     if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
158       return "external callee";
159 
160     if (Function *Func = Node->getFunction())
161       return std::string(Func->getName());
162     return "external node";
163   }
164   static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
165     return P.second;
166   }
167 
168   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
169   typedef mapped_iterator<CallGraphNode::const_iterator,
170                           decltype(&CGGetValuePtr)>
171       nodes_iterator;
172 
173   std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
174                                 CallGraphDOTInfo *CGInfo) {
175     if (!ShowEdgeWeight)
176       return "";
177 
178     Function *Caller = Node->getFunction();
179     if (Caller == nullptr || Caller->isDeclaration())
180       return "";
181 
182     Function *Callee = (*I)->getFunction();
183     if (Callee == nullptr)
184       return "";
185 
186     uint64_t Counter = getNumOfCalls(*Caller, *Callee);
187     double Width =
188         1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
189     std::string Attrs = "label=\"" + std::to_string(Counter) +
190                         "\" penwidth=" + std::to_string(Width);
191     return Attrs;
192   }
193 
194   std::string getNodeAttributes(const CallGraphNode *Node,
195                                 CallGraphDOTInfo *CGInfo) {
196     Function *F = Node->getFunction();
197     if (F == nullptr)
198       return "";
199     std::string attrs;
200     if (ShowHeatColors) {
201       uint64_t freq = CGInfo->getFreq(F);
202       std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
203       std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
204                                   ? getHeatColor(0)
205                                   : getHeatColor(1);
206       attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
207               color + "80\"";
208     }
209     return attrs;
210   }
211 };
212 
213 } // end llvm namespace
214 
215 namespace {
216 // Viewer
217 class CallGraphViewer : public ModulePass {
218 public:
219   static char ID;
220   CallGraphViewer() : ModulePass(ID) {}
221 
222   void getAnalysisUsage(AnalysisUsage &AU) const override;
223   bool runOnModule(Module &M) override;
224 };
225 
226 void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
227   ModulePass::getAnalysisUsage(AU);
228   AU.addRequired<BlockFrequencyInfoWrapperPass>();
229   AU.setPreservesAll();
230 }
231 
232 bool CallGraphViewer::runOnModule(Module &M) {
233   auto LookupBFI = [this](Function &F) {
234     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
235   };
236 
237   CallGraph CG(M);
238   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
239 
240   std::string Title =
241       DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
242   ViewGraph(&CFGInfo, "callgraph", true, Title);
243 
244   return false;
245 }
246 
247 // DOT Printer
248 
249 class CallGraphDOTPrinter : public ModulePass {
250 public:
251   static char ID;
252   CallGraphDOTPrinter() : ModulePass(ID) {}
253 
254   void getAnalysisUsage(AnalysisUsage &AU) const override;
255   bool runOnModule(Module &M) override;
256 };
257 
258 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
259   ModulePass::getAnalysisUsage(AU);
260   AU.addRequired<BlockFrequencyInfoWrapperPass>();
261   AU.setPreservesAll();
262 }
263 
264 bool CallGraphDOTPrinter::runOnModule(Module &M) {
265   auto LookupBFI = [this](Function &F) {
266     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
267   };
268 
269   std::string Filename;
270   if (!CallGraphDotFilenamePrefix.empty())
271     Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
272   else
273     Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
274   errs() << "Writing '" << Filename << "'...";
275 
276   std::error_code EC;
277   raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
278 
279   CallGraph CG(M);
280   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
281 
282   if (!EC)
283     WriteGraph(File, &CFGInfo);
284   else
285     errs() << "  error opening file for writing!";
286   errs() << "\n";
287 
288   return false;
289 }
290 
291 } // end anonymous namespace
292 
293 char CallGraphViewer::ID = 0;
294 INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
295                 false)
296 
297 char CallGraphDOTPrinter::ID = 0;
298 INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
299                 "Print call graph to 'dot' file", false, false)
300 
// Factory functions exported from this file so that
// "include/llvm/LinkAllPasses.h" can reference them; otherwise link-time
// optimization could discard the passes.
304 
305 ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
306 
307 ModulePass *llvm::createCallGraphDOTPrinterPass() {
308   return new CallGraphDOTPrinter();
309 }
310