xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/CallPrinter.cpp (revision f9fd7337f63698f33239c58c07bf430198235a22)
1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines '-dot-callgraph', which emits a <module-id>.callgraph.dot
// file (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is set)
// containing the call graph of a module.
11 //
12 // There is also a pass available to directly call dotty ('-view-callgraph').
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/CallPrinter.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/BranchProbabilityInfo.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/DOTGraphTraitsPass.h"
21 #include "llvm/Analysis/HeatUtils.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/SmallSet.h"
26 
27 using namespace llvm;
28 
29 // This option shows static (relative) call counts.
30 // FIXME:
31 // Need to show real counts when profile data is available
32 static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
33                                     cl::Hidden,
34                                     cl::desc("Show heat colors in call-graph"));
35 
36 static cl::opt<bool>
37     ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
38                        cl::desc("Show edges labeled with weights"));
39 
40 static cl::opt<bool>
41     CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
42             cl::desc("Show call-multigraph (do not remove parallel edges)"));
43 
44 static cl::opt<std::string> CallGraphDotFilenamePrefix(
45     "callgraph-dot-filename-prefix", cl::Hidden,
46     cl::desc("The prefix used for the CallGraph dot file names."));
47 
48 namespace llvm {
49 
50 class CallGraphDOTInfo {
51 private:
52   Module *M;
53   CallGraph *CG;
54   DenseMap<const Function *, uint64_t> Freq;
55   uint64_t MaxFreq;
56 
57 public:
58   std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
59 
60   CallGraphDOTInfo(Module *M, CallGraph *CG,
61                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
62       : M(M), CG(CG), LookupBFI(LookupBFI) {
63     MaxFreq = 0;
64 
65     for (auto F = M->getFunctionList().begin(); F != M->getFunctionList().end(); ++F) {
66       uint64_t localSumFreq = 0;
67       SmallSet<Function *, 16> Callers;
68       for (User *U : (*F).users())
69         if (isa<CallInst>(U))
70           Callers.insert(cast<Instruction>(U)->getFunction());
71       for (auto iter = Callers.begin() ; iter != Callers.end() ; ++iter)
72         localSumFreq += getNumOfCalls((**iter), *F);
73       if (localSumFreq >= MaxFreq)
74         MaxFreq = localSumFreq;
75       Freq[&*F] = localSumFreq;
76     }
77     if (!CallMultiGraph)
78       removeParallelEdges();
79   }
80 
81   Module *getModule() const { return M; }
82 
83   CallGraph *getCallGraph() const { return CG; }
84 
85   uint64_t getFreq(const Function *F) { return Freq[F]; }
86 
87   uint64_t getMaxFreq() { return MaxFreq; }
88 
89 private:
90   void removeParallelEdges() {
91     for (auto &I : (*CG)) {
92       CallGraphNode *Node = I.second.get();
93 
94       bool FoundParallelEdge = true;
95       while (FoundParallelEdge) {
96         SmallSet<Function *, 16> Visited;
97         FoundParallelEdge = false;
98         for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
99           if (!(Visited.insert(CI->second->getFunction())).second) {
100             FoundParallelEdge = true;
101             Node->removeCallEdge(CI);
102             break;
103           }
104         }
105       }
106     }
107   }
108 };
109 
110 template <>
111 struct GraphTraits<CallGraphDOTInfo *>
112     : public GraphTraits<const CallGraphNode *> {
113   static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
114     // Start at the external node!
115     return CGInfo->getCallGraph()->getExternalCallingNode();
116   }
117 
118   typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
119       PairTy;
120   static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
121     return P.second.get();
122   }
123 
124   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
125   typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
126       nodes_iterator;
127 
128   static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
129     return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
130   }
131   static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
132     return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
133   }
134 };
135 
136 template <>
137 struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
138 
139   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
140 
141   static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
142     return "Call graph: " +
143            std::string(CGInfo->getModule()->getModuleIdentifier());
144   }
145 
146   static bool isNodeHidden(const CallGraphNode *Node) {
147     if (CallMultiGraph || Node->getFunction())
148       return false;
149     return true;
150   }
151 
152   std::string getNodeLabel(const CallGraphNode *Node,
153                            CallGraphDOTInfo *CGInfo) {
154     if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
155       return "external caller";
156     if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
157       return "external callee";
158 
159     if (Function *Func = Node->getFunction())
160       return std::string(Func->getName());
161     return "external node";
162   }
163   static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
164     return P.second;
165   }
166 
167   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
168   typedef mapped_iterator<CallGraphNode::const_iterator,
169                           decltype(&CGGetValuePtr)>
170       nodes_iterator;
171 
172   std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
173                                 CallGraphDOTInfo *CGInfo) {
174     if (!ShowEdgeWeight)
175       return "";
176 
177     Function *Caller = Node->getFunction();
178     if (Caller == nullptr || Caller->isDeclaration())
179       return "";
180 
181     Function *Callee = (*I)->getFunction();
182     if (Callee == nullptr)
183       return "";
184 
185     uint64_t Counter = getNumOfCalls(*Caller, *Callee);
186     double Width =
187         1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
188     std::string Attrs = "label=\"" + std::to_string(Counter) +
189                         "\" penwidth=" + std::to_string(Width);
190     return Attrs;
191   }
192 
193   std::string getNodeAttributes(const CallGraphNode *Node,
194                                 CallGraphDOTInfo *CGInfo) {
195     Function *F = Node->getFunction();
196     if (F == nullptr)
197       return "";
198     std::string attrs = "";
199     if (ShowHeatColors) {
200       uint64_t freq = CGInfo->getFreq(F);
201       std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
202       std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
203                                   ? getHeatColor(0)
204                                   : getHeatColor(1);
205       attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
206               color + "80\"";
207     }
208     return attrs;
209   }
210 };
211 
212 } // end llvm namespace
213 
214 namespace {
215 // Viewer
216 class CallGraphViewer : public ModulePass {
217 public:
218   static char ID;
219   CallGraphViewer() : ModulePass(ID) {}
220 
221   void getAnalysisUsage(AnalysisUsage &AU) const override;
222   bool runOnModule(Module &M) override;
223 };
224 
225 void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
226   ModulePass::getAnalysisUsage(AU);
227   AU.addRequired<BlockFrequencyInfoWrapperPass>();
228   AU.setPreservesAll();
229 }
230 
231 bool CallGraphViewer::runOnModule(Module &M) {
232   auto LookupBFI = [this](Function &F) {
233     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
234   };
235 
236   CallGraph CG(M);
237   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
238 
239   std::string Title =
240       DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
241   ViewGraph(&CFGInfo, "callgraph", true, Title);
242 
243   return false;
244 }
245 
246 // DOT Printer
247 
248 class CallGraphDOTPrinter : public ModulePass {
249 public:
250   static char ID;
251   CallGraphDOTPrinter() : ModulePass(ID) {}
252 
253   void getAnalysisUsage(AnalysisUsage &AU) const override;
254   bool runOnModule(Module &M) override;
255 };
256 
257 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
258   ModulePass::getAnalysisUsage(AU);
259   AU.addRequired<BlockFrequencyInfoWrapperPass>();
260   AU.setPreservesAll();
261 }
262 
263 bool CallGraphDOTPrinter::runOnModule(Module &M) {
264   auto LookupBFI = [this](Function &F) {
265     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
266   };
267 
268   std::string Filename;
269   if (!CallGraphDotFilenamePrefix.empty())
270     Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
271   else
272     Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
273   errs() << "Writing '" << Filename << "'...";
274 
275   std::error_code EC;
276   raw_fd_ostream File(Filename, EC, sys::fs::F_Text);
277 
278   CallGraph CG(M);
279   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
280 
281   if (!EC)
282     WriteGraph(File, &CFGInfo);
283   else
284     errs() << "  error opening file for writing!";
285   errs() << "\n";
286 
287   return false;
288 }
289 
290 } // end anonymous namespace
291 
292 char CallGraphViewer::ID = 0;
293 INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
294                 false)
295 
296 char CallGraphDOTPrinter::ID = 0;
297 INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
298                 "Print call graph to 'dot' file", false, false)
299 
// Factory functions exported from this file so that the passes can be
// referenced from "include/llvm/LinkAllPasses.h"; otherwise they could be
// discarded by link-time optimization.
303 
304 ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
305 
306 ModulePass *llvm::createCallGraphDOTPrinterPass() {
307   return new CallGraphDOTPrinter();
308 }
309