//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities for propagating synthetic counts.
//
//===----------------------------------------------------------------------===//
120b57cec5SDimitry Andric
130b57cec5SDimitry Andric #include "llvm/Analysis/SyntheticCountsUtils.h"
140b57cec5SDimitry Andric #include "llvm/ADT/DenseSet.h"
150b57cec5SDimitry Andric #include "llvm/ADT/SCCIterator.h"
160b57cec5SDimitry Andric #include "llvm/Analysis/CallGraph.h"
170b57cec5SDimitry Andric #include "llvm/IR/ModuleSummaryIndex.h"
180b57cec5SDimitry Andric
190b57cec5SDimitry Andric using namespace llvm;
200b57cec5SDimitry Andric
210b57cec5SDimitry Andric // Given an SCC, propagate entry counts along the edge of the SCC nodes.
220b57cec5SDimitry Andric template <typename CallGraphType>
propagateFromSCC(const SccTy & SCC,GetProfCountTy GetProfCount,AddCountTy AddCount)230b57cec5SDimitry Andric void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
240b57cec5SDimitry Andric const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
250b57cec5SDimitry Andric
260b57cec5SDimitry Andric DenseSet<NodeRef> SCCNodes;
270b57cec5SDimitry Andric SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
280b57cec5SDimitry Andric
290b57cec5SDimitry Andric for (auto &Node : SCC)
300b57cec5SDimitry Andric SCCNodes.insert(Node);
310b57cec5SDimitry Andric
320b57cec5SDimitry Andric // Partition the edges coming out of the SCC into those whose destination is
330b57cec5SDimitry Andric // in the SCC and the rest.
340b57cec5SDimitry Andric for (const auto &Node : SCCNodes) {
350b57cec5SDimitry Andric for (auto &E : children_edges<CallGraphType>(Node)) {
360b57cec5SDimitry Andric if (SCCNodes.count(CGT::edge_dest(E)))
370b57cec5SDimitry Andric SCCEdges.emplace_back(Node, E);
380b57cec5SDimitry Andric else
390b57cec5SDimitry Andric NonSCCEdges.emplace_back(Node, E);
400b57cec5SDimitry Andric }
410b57cec5SDimitry Andric }
420b57cec5SDimitry Andric
430b57cec5SDimitry Andric // For nodes in the same SCC, update the counts in two steps:
440b57cec5SDimitry Andric // 1. Compute the additional count for each node by propagating the counts
450b57cec5SDimitry Andric // along all incoming edges to the node that originate from within the same
460b57cec5SDimitry Andric // SCC and summing them up.
470b57cec5SDimitry Andric // 2. Add the additional counts to the nodes in the SCC.
480b57cec5SDimitry Andric // This ensures that the order of
490b57cec5SDimitry Andric // traversal of nodes within the SCC doesn't affect the final result.
500b57cec5SDimitry Andric
510b57cec5SDimitry Andric DenseMap<NodeRef, Scaled64> AdditionalCounts;
520b57cec5SDimitry Andric for (auto &E : SCCEdges) {
530b57cec5SDimitry Andric auto OptProfCount = GetProfCount(E.first, E.second);
540b57cec5SDimitry Andric if (!OptProfCount)
550b57cec5SDimitry Andric continue;
560b57cec5SDimitry Andric auto Callee = CGT::edge_dest(E.second);
57*81ad6265SDimitry Andric AdditionalCounts[Callee] += *OptProfCount;
580b57cec5SDimitry Andric }
590b57cec5SDimitry Andric
600b57cec5SDimitry Andric // Update the counts for the nodes in the SCC.
610b57cec5SDimitry Andric for (auto &Entry : AdditionalCounts)
620b57cec5SDimitry Andric AddCount(Entry.first, Entry.second);
630b57cec5SDimitry Andric
640b57cec5SDimitry Andric // Now update the counts for nodes outside the SCC.
650b57cec5SDimitry Andric for (auto &E : NonSCCEdges) {
660b57cec5SDimitry Andric auto OptProfCount = GetProfCount(E.first, E.second);
670b57cec5SDimitry Andric if (!OptProfCount)
680b57cec5SDimitry Andric continue;
690b57cec5SDimitry Andric auto Callee = CGT::edge_dest(E.second);
70*81ad6265SDimitry Andric AddCount(Callee, *OptProfCount);
710b57cec5SDimitry Andric }
720b57cec5SDimitry Andric }
730b57cec5SDimitry Andric
740b57cec5SDimitry Andric /// Propgate synthetic entry counts on a callgraph \p CG.
750b57cec5SDimitry Andric ///
760b57cec5SDimitry Andric /// This performs a reverse post-order traversal of the callgraph SCC. For each
770b57cec5SDimitry Andric /// SCC, it first propagates the entry counts to the nodes within the SCC
780b57cec5SDimitry Andric /// through call edges and updates them in one shot. Then the entry counts are
790b57cec5SDimitry Andric /// propagated to nodes outside the SCC. This requires \p GraphTraits
800b57cec5SDimitry Andric /// to have a specialization for \p CallGraphType.
810b57cec5SDimitry Andric
820b57cec5SDimitry Andric template <typename CallGraphType>
propagate(const CallGraphType & CG,GetProfCountTy GetProfCount,AddCountTy AddCount)830b57cec5SDimitry Andric void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
840b57cec5SDimitry Andric GetProfCountTy GetProfCount,
850b57cec5SDimitry Andric AddCountTy AddCount) {
860b57cec5SDimitry Andric std::vector<SccTy> SCCs;
870b57cec5SDimitry Andric
880b57cec5SDimitry Andric // Collect all the SCCs.
890b57cec5SDimitry Andric for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
900b57cec5SDimitry Andric SCCs.push_back(*I);
910b57cec5SDimitry Andric
920b57cec5SDimitry Andric // The callgraph-scc needs to be visited in top-down order for propagation.
930b57cec5SDimitry Andric // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
940b57cec5SDimitry Andric // and call propagateFromSCC.
950b57cec5SDimitry Andric for (auto &SCC : reverse(SCCs))
960b57cec5SDimitry Andric propagateFromSCC(SCC, GetProfCount, AddCount);
970b57cec5SDimitry Andric }
980b57cec5SDimitry Andric
// The member templates above are defined in this .cpp file, so the class is
// explicitly instantiated here for the two graph types named below: a
// `const CallGraph *` and a `ModuleSummaryIndex *`.
990b57cec5SDimitry Andric template class llvm::SyntheticCountsUtils<const CallGraph *>;
1000b57cec5SDimitry Andric template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>;
101