xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/CtxProfAnalysis.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- CtxProfAnalysis.cpp - contextual profile analysis ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of the contextual profile analysis, which maintains contextual
10 // profiling info through IPO passes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/CtxProfAnalysis.h"
15 #include "llvm/ADT/APInt.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Analysis/CFG.h"
18 #include "llvm/IR/Analysis.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/IR/PassManager.h"
23 #include "llvm/ProfileData/PGOCtxProfReader.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/MemoryBuffer.h"
26 #include "llvm/Support/Path.h"
27 #include <deque>
28 #include <memory>
29 
#define DEBUG_TYPE "ctx_prof"

using namespace llvm;
// Path of the contextual profile file to load; empty (the default) means no
// profile. Non-static: presumably referenced from other TUs - confirm.
cl::opt<std::string>
    UseCtxProfile("use-ctx-profile", cl::init(""), cl::Hidden,
                  cl::desc("Use the specified contextual profile file"));

// Verbosity of CtxProfAnalysisPrinterPass: "everything" also prints per-function
// info and the flat profile; "yaml" prints only the YAML form of the profile.
static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
    "ctx-profile-printer-level",
    cl::init(CtxProfAnalysisPrinterPass::PrintMode::YAML), cl::Hidden,
    cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
                          "everything", "print everything - most verbose"),
               clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::YAML, "yaml",
                          "just the yaml representation of the profile")),
    cl::desc("Verbosity level of the contextual profile printer pass."));

// Override for PGOContextualProfile::isInSpecializedModule(); when passed on
// the command line it takes precedence over the computed value (see its use
// below).
static cl::opt<bool> ForceIsInSpecializedModule(
    "ctx-profile-force-is-specialized", cl::init(false),
    cl::desc("Treat the given module as-if it were containing the "
             "post-thinlink module containing the root"));

// Name of the function-level metadata node AssignGUIDPass uses to persist a
// function's GUID (see AssignGUIDPass::run / getGUID below).
const char *AssignGUIDPass::GUIDMetadataName = "guid";
52 
53 namespace llvm {
54 class ProfileAnnotatorImpl final {
55   friend class ProfileAnnotator;
56   class BBInfo;
57   struct EdgeInfo {
58     BBInfo *const Src;
59     BBInfo *const Dest;
60     std::optional<uint64_t> Count;
61 
62     explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
63   };
64 
65   class BBInfo {
66     std::optional<uint64_t> Count;
67     // OutEdges is dimensioned to match the number of terminator operands.
68     // Entries in the vector match the index in the terminator operand list. In
69     // some cases - see `shouldExcludeEdge` and its implementation - an entry
70     // will be nullptr.
71     // InEdges doesn't have the above constraint.
72     SmallVector<EdgeInfo *> OutEdges;
73     SmallVector<EdgeInfo *> InEdges;
74     size_t UnknownCountOutEdges = 0;
75     size_t UnknownCountInEdges = 0;
76 
77     // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
78     // because all the edge counters must be known.
79     // Return std::nullopt if there were no edges to sum. The user can decide
80     // how to interpret that.
81     std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
82                                        bool AssumeAllKnown) const {
83       std::optional<uint64_t> Sum;
84       for (const auto *E : Edges) {
85         // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
86         if (E) {
87           if (!Sum.has_value())
88             Sum = 0;
89           *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U));
90         }
91       }
92       return Sum;
93     }
94 
95     bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
96       assert(!Count.has_value());
97       Count = getEdgeSum(Edges, true);
98       return Count.has_value();
99     }
100 
101     void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
102       uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U);
103       uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
104       EdgeInfo *E = nullptr;
105       for (auto *I : Edges)
106         if (I && !I->Count.has_value()) {
107           E = I;
108 #ifdef NDEBUG
109           break;
110 #else
111           assert((!E || E == I) &&
112                  "Expected exactly one edge to have an unknown count, "
113                  "found a second one");
114           continue;
115 #endif
116         }
117       assert(E && "Expected exactly one edge to have an unknown count");
118       assert(!E->Count.has_value());
119       E->Count = EdgeVal;
120       assert(E->Src->UnknownCountOutEdges > 0);
121       assert(E->Dest->UnknownCountInEdges > 0);
122       --E->Src->UnknownCountOutEdges;
123       --E->Dest->UnknownCountInEdges;
124     }
125 
126   public:
127     BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
128         : Count(Count) {
129       // For in edges, we just want to pre-allocate enough space, since we know
130       // it at this stage. For out edges, we will insert edges at the indices
131       // corresponding to positions in this BB's terminator instruction, so we
132       // construct a default (nullptr values)-initialized vector. A nullptr edge
133       // corresponds to those that are excluded (see shouldExcludeEdge).
134       InEdges.reserve(NumInEdges);
135       OutEdges.resize(NumOutEdges);
136     }
137 
138     bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
139       if (!UnknownCountOutEdges) {
140         return computeCountFrom(OutEdges);
141       }
142       return false;
143     }
144 
145     bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
146       if (!UnknownCountInEdges) {
147         return computeCountFrom(InEdges);
148       }
149       return false;
150     }
151 
152     void addInEdge(EdgeInfo &Info) {
153       InEdges.push_back(&Info);
154       ++UnknownCountInEdges;
155     }
156 
157     // For the out edges, we care about the position we place them in, which is
158     // the position in terminator instruction's list (at construction). Later,
159     // we build branch_weights metadata with edge frequency values matching
160     // these positions.
161     void addOutEdge(size_t Index, EdgeInfo &Info) {
162       OutEdges[Index] = &Info;
163       ++UnknownCountOutEdges;
164     }
165 
166     bool hasCount() const { return Count.has_value(); }
167 
168     uint64_t getCount() const { return *Count; }
169 
170     bool trySetSingleUnknownInEdgeCount() {
171       if (UnknownCountInEdges == 1) {
172         setSingleUnknownEdgeCount(InEdges);
173         return true;
174       }
175       return false;
176     }
177 
178     bool trySetSingleUnknownOutEdgeCount() {
179       if (UnknownCountOutEdges == 1) {
180         setSingleUnknownEdgeCount(OutEdges);
181         return true;
182       }
183       return false;
184     }
185     size_t getNumOutEdges() const { return OutEdges.size(); }
186 
187     uint64_t getEdgeCount(size_t Index) const {
188       if (auto *E = OutEdges[Index])
189         return *E->Count;
190       return 0U;
191     }
192   };
193 
194   const Function &F;
195   ArrayRef<uint64_t> Counters;
196   // To be accessed through getBBInfo() after construction.
197   std::map<const BasicBlock *, BBInfo> BBInfos;
198   std::vector<EdgeInfo> EdgeInfos;
199 
200   // The only criteria for exclusion is faux suspend -> exit edges in presplit
201   // coroutines. The API serves for readability, currently.
202   bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
203     return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
204   }
205 
206   BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; }
207 
208   const BBInfo &getBBInfo(const BasicBlock &BB) const {
209     return BBInfos.find(&BB)->second;
210   }
211 
212   // validation function after we propagate the counters: all BBs and edges'
213   // counters must have a value.
214   bool allCountersAreAssigned() const {
215     for (const auto &BBInfo : BBInfos)
216       if (!BBInfo.second.hasCount())
217         return false;
218     for (const auto &EdgeInfo : EdgeInfos)
219       if (!EdgeInfo.Count.has_value())
220         return false;
221     return true;
222   }
223 
224   /// Check that all paths from the entry basic block that use edges with
225   /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
226   bool allTakenPathsExit() const {
227     std::deque<const BasicBlock *> Worklist;
228     DenseSet<const BasicBlock *> Visited;
229     Worklist.push_back(&F.getEntryBlock());
230     bool HitExit = false;
231     while (!Worklist.empty()) {
232       const auto *BB = Worklist.front();
233       Worklist.pop_front();
234       if (!Visited.insert(BB).second)
235         continue;
236       if (succ_size(BB) == 0) {
237         if (isa<UnreachableInst>(BB->getTerminator()))
238           return false;
239         HitExit = true;
240         continue;
241       }
242       if (succ_size(BB) == 1) {
243         Worklist.push_back(BB->getUniqueSuccessor());
244         continue;
245       }
246       const auto &BBInfo = getBBInfo(*BB);
247       bool HasAWayOut = false;
248       for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
249         const auto *Succ = BB->getTerminator()->getSuccessor(I);
250         if (!shouldExcludeEdge(*BB, *Succ)) {
251           if (BBInfo.getEdgeCount(I) > 0) {
252             HasAWayOut = true;
253             Worklist.push_back(Succ);
254           }
255         }
256       }
257       if (!HasAWayOut)
258         return false;
259     }
260     return HitExit;
261   }
262 
263   bool allNonColdSelectsHaveProfile() const {
264     for (const auto &BB : F) {
265       if (getBBInfo(BB).getCount() > 0) {
266         for (const auto &I : BB) {
267           if (const auto *SI = dyn_cast<SelectInst>(&I)) {
268             if (const auto *Inst = CtxProfAnalysis::getSelectInstrumentation(
269                     *const_cast<SelectInst *>(SI))) {
270               auto Index = Inst->getIndex()->getZExtValue();
271               assert(Index < Counters.size());
272               if (Counters[Index] == 0)
273                 return false;
274             }
275           }
276         }
277       }
278     }
279     return true;
280   }
281 
282   // This is an adaptation of PGOUseFunc::populateCounters.
283   // FIXME(mtrofin): look into factoring the code to share one implementation.
284   void propagateCounterValues() {
285     bool KeepGoing = true;
286     while (KeepGoing) {
287       KeepGoing = false;
288       for (const auto &BB : F) {
289         auto &Info = getBBInfo(BB);
290         if (!Info.hasCount())
291           KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
292                        Info.tryTakeCountFromKnownInEdges(BB);
293         if (Info.hasCount()) {
294           KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
295           KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
296         }
297       }
298     }
299     assert(allCountersAreAssigned() &&
300            "[ctx-prof] Expected all counters have been assigned.");
301     assert(allTakenPathsExit() &&
302            "[ctx-prof] Encountered a BB with more than one successor, where "
303            "all outgoing edges have a 0 count. This occurs in non-exiting "
304            "functions (message pumps, usually) which are not supported in the "
305            "contextual profiling case");
306     assert(allNonColdSelectsHaveProfile() &&
307            "[ctx-prof] All non-cold select instructions were expected to have "
308            "a profile.");
309   }
310 
311 public:
312   ProfileAnnotatorImpl(const Function &F, ArrayRef<uint64_t> Counters)
313       : F(F), Counters(Counters) {
314     assert(!F.isDeclaration());
315     assert(!Counters.empty());
316     size_t NrEdges = 0;
317     for (const auto &BB : F) {
318       std::optional<uint64_t> Count;
319       if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
320               const_cast<BasicBlock &>(BB))) {
321         auto Index = Ins->getIndex()->getZExtValue();
322         assert(Index < Counters.size() &&
323                "The index must be inside the counters vector by construction - "
324                "tripping this assertion indicates a bug in how the contextual "
325                "profile is managed by IPO transforms");
326         (void)Index;
327         Count = Counters[Ins->getIndex()->getZExtValue()];
328       } else if (isa<UnreachableInst>(BB.getTerminator())) {
329         // The program presumably didn't crash.
330         Count = 0;
331       }
332       auto [It, Ins] =
333           BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}});
334       (void)Ins;
335       assert(Ins && "We iterate through the function's BBs, no reason to "
336                     "insert one more than once");
337       NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
338         return !shouldExcludeEdge(BB, *Succ);
339       });
340     }
341     // Pre-allocate the vector, we want references to its contents to be stable.
342     EdgeInfos.reserve(NrEdges);
343     for (const auto &BB : F) {
344       auto &Info = getBBInfo(BB);
345       for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
346         const auto *Succ = BB.getTerminator()->getSuccessor(I);
347         if (!shouldExcludeEdge(BB, *Succ)) {
348           auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ));
349           Info.addOutEdge(I, EI);
350           getBBInfo(*Succ).addInEdge(EI);
351         }
352       }
353     }
354     assert(EdgeInfos.capacity() == NrEdges &&
355            "The capacity of EdgeInfos should have stayed unchanged it was "
356            "populated, because we need pointers to its contents to be stable");
357     propagateCounterValues();
358   }
359 
360   uint64_t getBBCount(const BasicBlock &BB) { return getBBInfo(BB).getCount(); }
361 };
362 
363 } // namespace llvm
364 
// Eagerly builds the impl, which runs counter propagation in its constructor;
// queries afterwards are lookups only.
ProfileAnnotator::ProfileAnnotator(const Function &F,
                                   ArrayRef<uint64_t> RawCounters)
    : PImpl(std::make_unique<ProfileAnnotatorImpl>(F, RawCounters)) {}
368 
369 ProfileAnnotator::~ProfileAnnotator() = default;
370 
// Forward to the impl, which holds the propagated per-BB counts.
uint64_t ProfileAnnotator::getBBCount(const BasicBlock &BB) const {
  return PImpl->getBBCount(BB);
}
374 
375 bool ProfileAnnotator::getSelectInstrProfile(SelectInst &SI,
376                                              uint64_t &TrueCount,
377                                              uint64_t &FalseCount) const {
378   const auto &BBInfo = PImpl->getBBInfo(*SI.getParent());
379   TrueCount = FalseCount = 0;
380   if (BBInfo.getCount() == 0)
381     return false;
382 
383   auto *Step = CtxProfAnalysis::getSelectInstrumentation(SI);
384   if (!Step)
385     return false;
386   auto Index = Step->getIndex()->getZExtValue();
387   assert(Index < PImpl->Counters.size() &&
388          "The index of the step instruction must be inside the "
389          "counters vector by "
390          "construction - tripping this assertion indicates a bug in "
391          "how the contextual profile is managed by IPO transforms");
392   auto TotalCount = BBInfo.getCount();
393   TrueCount = PImpl->Counters[Index];
394   FalseCount = (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
395   return true;
396 }
397 
398 bool ProfileAnnotator::getOutgoingBranchWeights(
399     BasicBlock &BB, SmallVectorImpl<uint64_t> &Profile,
400     uint64_t &MaxCount) const {
401   Profile.clear();
402 
403   if (succ_size(&BB) < 2)
404     return false;
405 
406   auto *Term = BB.getTerminator();
407   Profile.resize(Term->getNumSuccessors());
408 
409   const auto &BBInfo = PImpl->getBBInfo(BB);
410   MaxCount = 0;
411   for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
412        ++SuccIdx) {
413     uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
414     if (EdgeCount > MaxCount)
415       MaxCount = EdgeCount;
416     Profile[SuccIdx] = EdgeCount;
417   }
418   return MaxCount > 0;
419 }
420 
421 PreservedAnalyses AssignGUIDPass::run(Module &M, ModuleAnalysisManager &MAM) {
422   for (auto &F : M.functions()) {
423     if (F.isDeclaration())
424       continue;
425     if (F.getMetadata(GUIDMetadataName))
426       continue;
427     const GlobalValue::GUID GUID = F.getGUID();
428     F.setMetadata(GUIDMetadataName,
429                   MDNode::get(M.getContext(),
430                               {ConstantAsMetadata::get(ConstantInt::get(
431                                   Type::getInt64Ty(M.getContext()), GUID))}));
432   }
433   return PreservedAnalyses::none();
434 }
435 
436 GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) {
437   if (F.isDeclaration()) {
438     assert(GlobalValue::isExternalLinkage(F.getLinkage()));
439     return F.getGUID();
440   }
441   auto *MD = F.getMetadata(GUIDMetadataName);
442   assert(MD && "guid not found for defined function");
443   return cast<ConstantInt>(cast<ConstantAsMetadata>(MD->getOperand(0))
444                                ->getValue()
445                                ->stripPointerCasts())
446       ->getZExtValue();
447 }
AnalysisKey CtxProfAnalysis::Key;

// An explicitly-passed profile path wins; otherwise fall back to the
// -use-ctx-profile command line flag if it was given; otherwise no profile.
CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile)
    : Profile([&]() -> std::optional<StringRef> {
        if (Profile)
          return *Profile;
        if (UseCtxProfile.getNumOccurrences())
          return UseCtxProfile;
        return std::nullopt;
      }()) {}
458 
// Load the contextual profile, trim it to the contexts rooted in this module
// (when the module is a root's "specialized" module), and record per-function
// counter/callsite allocation state. Returns an empty (invalid) result on any
// error, after emitting a diagnostic.
PGOContextualProfile CtxProfAnalysis::run(Module &M,
                                          ModuleAnalysisManager &MAM) {
  if (!Profile)
    return {};
  ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(*Profile);
  if (auto EC = MB.getError()) {
    M.getContext().emitError("could not open contextual profile file: " +
                             EC.message());
    return {};
  }
  PGOCtxProfileReader Reader(MB.get()->getBuffer());
  auto MaybeProfiles = Reader.loadProfiles();
  if (!MaybeProfiles) {
    M.getContext().emitError("contextual profile file is invalid: " +
                             toString(MaybeProfiles.takeError()));
    return {};
  }

  // FIXME: We should drive this from ThinLTO, but for the time being, use the
  // module name as indicator.
  // We want to *only* keep the contextual profiles in modules that capture
  // context trees. That allows us to compute specific PSIs, for example.
  auto DetermineRootsInModule = [&M]() -> const DenseSet<GlobalValue::GUID> {
    DenseSet<GlobalValue::GUID> ProfileRootsInModule;
    auto ModName = M.getName();
    auto Filename = sys::path::filename(ModName);
    // Drop the file extension.
    Filename = Filename.substr(0, Filename.find_last_of('.'));
    // See if it parses
    APInt Guid;
    // getAsInteger returns true if there are more chars to read other than the
    // integer. So the "false" test is what we want.
    if (!Filename.getAsInteger(0, Guid))
      ProfileRootsInModule.insert(Guid.getZExtValue());
    return ProfileRootsInModule;
  };
  const auto ProfileRootsInModule = DetermineRootsInModule();
  PGOContextualProfile Result;

  // the logic from here on allows for modules that contain - by design - more
  // than one root. We currently don't support that, because the determination
  // happens based on the module name matching the root guid, but the logic can
  // avoid assuming that.
  if (!ProfileRootsInModule.empty()) {
    Result.IsInSpecializedModule = true;
    // Trim first the roots that aren't in this module.
    for (auto &[RootGuid, _] :
         llvm::make_early_inc_range(MaybeProfiles->Contexts))
      if (!ProfileRootsInModule.contains(RootGuid))
        MaybeProfiles->Contexts.erase(RootGuid);
    // we can also drop the flat profiles
    MaybeProfiles->FlatProfiles.clear();
  }

  // Record, for each instrumented function, how many counters and callsites
  // its instrumentation declared, so later IPO transforms know where to
  // allocate new indices from.
  for (const auto &F : M) {
    if (F.isDeclaration())
      continue;
    auto GUID = AssignGUIDPass::getGUID(F);
    assert(GUID && "guid not found for defined function");
    // The increment intrinsic (carrying the total counter count) is expected
    // in the entry block.
    const auto &Entry = F.begin();
    uint32_t MaxCounters = 0; // we expect at least a counter.
    for (const auto &I : *Entry)
      if (auto *C = dyn_cast<InstrProfIncrementInst>(&I)) {
        MaxCounters =
            static_cast<uint32_t>(C->getNumCounters()->getZExtValue());
        break;
      }
    // Uninstrumented functions are skipped entirely.
    if (!MaxCounters)
      continue;
    uint32_t MaxCallsites = 0;
    // NOTE(review): this `break` only exits the inner loop, so later BBs can
    // re-assign MaxCallsites; presumably all callsite intrinsics in a
    // function carry the same total - confirm.
    for (const auto &BB : F)
      for (const auto &I : BB)
        if (auto *C = dyn_cast<InstrProfCallsite>(&I)) {
          MaxCallsites =
              static_cast<uint32_t>(C->getNumCounters()->getZExtValue());
          break;
        }
    auto [It, Ins] = Result.FuncInfo.insert(
        {GUID, PGOContextualProfile::FunctionInfo(F.getName())});
    (void)Ins;
    assert(Ins);
    It->second.NextCallsiteIndex = MaxCallsites;
    It->second.NextCounterIndex = MaxCounters;
  }
  // If we made it this far, the Result is valid - which we mark by setting
  // .Profiles.
  Result.Profiles = std::move(*MaybeProfiles);
  Result.initIndex();
  return Result;
}
549 
550 GlobalValue::GUID
551 PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
552   if (auto It = FuncInfo.find(AssignGUIDPass::getGUID(F)); It != FuncInfo.end())
553     return It->first;
554   return 0;
555 }
556 
// The print mode is taken from the -ctx-profile-printer-level flag.
CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
    : OS(OS), Mode(PrintLevel) {}
559 
560 PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
561                                                   ModuleAnalysisManager &MAM) {
562   CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(M);
563   if (C.contexts().empty()) {
564     OS << "No contextual profile was provided.\n";
565     return PreservedAnalyses::all();
566   }
567 
568   if (Mode == PrintMode::Everything) {
569     OS << "Function Info:\n";
570     for (const auto &[Guid, FuncInfo] : C.FuncInfo)
571       OS << Guid << " : " << FuncInfo.Name
572          << ". MaxCounterID: " << FuncInfo.NextCounterIndex
573          << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
574   }
575 
576   if (Mode == PrintMode::Everything)
577     OS << "\nCurrent Profile:\n";
578   convertCtxProfToYaml(OS, C.profiles());
579   OS << "\n";
580   if (Mode == PrintMode::YAML)
581     return PreservedAnalyses::all();
582 
583   OS << "\nFlat Profile:\n";
584   auto Flat = C.flatten();
585   for (const auto &[Guid, Counters] : Flat) {
586     OS << Guid << " : ";
587     for (auto V : Counters)
588       OS << V << " ";
589     OS << "\n";
590   }
591   return PreservedAnalyses::all();
592 }
593 
594 InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) {
595   if (!InstrProfCallsite::canInstrumentCallsite(CB))
596     return nullptr;
597   for (auto *Prev = CB.getPrevNode(); Prev; Prev = Prev->getPrevNode()) {
598     if (auto *IPC = dyn_cast<InstrProfCallsite>(Prev))
599       return IPC;
600     assert(!isa<CallBase>(Prev) &&
601            "didn't expect to find another call, that's not the callsite "
602            "instrumentation, before an instrumentable callsite");
603   }
604   return nullptr;
605 }
606 
607 InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
608   for (auto &I : BB)
609     if (auto *Incr = dyn_cast<InstrProfIncrementInst>(&I))
610       if (!isa<InstrProfIncrementInstStep>(&I))
611         return Incr;
612   return nullptr;
613 }
614 
615 InstrProfIncrementInstStep *
616 CtxProfAnalysis::getSelectInstrumentation(SelectInst &SI) {
617   Instruction *Prev = &SI;
618   while ((Prev = Prev->getPrevNode()))
619     if (auto *Step = dyn_cast<InstrProfIncrementInstStep>(Prev))
620       return Step;
621   return nullptr;
622 }
623 
624 template <class ProfTy>
625 static void preorderVisitOneRoot(ProfTy &Profile,
626                                  function_ref<void(ProfTy &)> Visitor) {
627   std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
628     Visitor(Ctx);
629     for (auto &[_, SubCtxSet] : Ctx.callsites())
630       for (auto &[__, Subctx] : SubCtxSet)
631         Traverser(Subctx);
632   };
633   Traverser(Profile);
634 }
635 
636 template <class ProfilesTy, class ProfTy>
637 static void preorderVisit(ProfilesTy &Profiles,
638                           function_ref<void(ProfTy &)> Visitor) {
639   for (auto &[_, P] : Profiles)
640     preorderVisitOneRoot<ProfTy>(P, Visitor);
641 }
642 
// Thread, per function, an intrusive linked list through all the contexts
// that belong to that function, in preorder, so update()/visit() can iterate
// a single function's contexts without re-walking the whole tree.
void PGOContextualProfile::initIndex() {
  // Initialize the head of the index list for each function. We don't need it
  // after this point.
  DenseMap<GlobalValue::GUID, PGOCtxProfContext *> InsertionPoints;
  for (auto &[Guid, FI] : FuncInfo)
    InsertionPoints[Guid] = &FI.Index;
  preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
      Profiles.Contexts, [&](PGOCtxProfContext &Ctx) {
        auto InsertIt = InsertionPoints.find(Ctx.guid());
        // Contexts of functions not defined in this module have no FuncInfo
        // entry (see CtxProfAnalysis::run) and aren't indexed.
        if (InsertIt == InsertionPoints.end())
          return;
        // Insert at the end of the list. Since we traverse in preorder, it
        // means that when we iterate the list from the beginning, we'd
        // encounter the contexts in the order we would have, should we have
        // performed a full preorder traversal.
        InsertIt->second->Next = &Ctx;
        Ctx.Previous = InsertIt->second;
        InsertIt->second = &Ctx;
      });
}
663 
664 bool PGOContextualProfile::isInSpecializedModule() const {
665   return ForceIsInSpecializedModule.getNumOccurrences() > 0
666              ? ForceIsInSpecializedModule
667              : IsInSpecializedModule;
668 }
669 
670 void PGOContextualProfile::update(Visitor V, const Function &F) {
671   assert(isFunctionKnown(F));
672   GlobalValue::GUID G = getDefinedFunctionGUID(F);
673   for (auto *Node = FuncInfo.find(G)->second.Index.Next; Node;
674        Node = Node->Next)
675     V(*reinterpret_cast<PGOCtxProfContext *>(Node));
676 }
677 
678 void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
679   if (!F)
680     return preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
681                          const PGOCtxProfContext>(Profiles.Contexts, V);
682   assert(isFunctionKnown(*F));
683   GlobalValue::GUID G = getDefinedFunctionGUID(*F);
684   for (const auto *Node = FuncInfo.find(G)->second.Index.Next; Node;
685        Node = Node->Next)
686     V(*reinterpret_cast<const PGOCtxProfContext *>(Node));
687 }
688 
689 const CtxProfFlatProfile PGOContextualProfile::flatten() const {
690   CtxProfFlatProfile Flat;
691   auto Accummulate = [](SmallVectorImpl<uint64_t> &Into,
692                         const SmallVectorImpl<uint64_t> &From,
693                         uint64_t SamplingRate) {
694     if (Into.empty())
695       Into.resize(From.size());
696     assert(Into.size() == From.size() &&
697            "All contexts corresponding to a function should have the exact "
698            "same number of counters.");
699     for (size_t I = 0, E = Into.size(); I < E; ++I)
700       Into[I] += From[I] * SamplingRate;
701   };
702 
703   for (const auto &[_, CtxRoot] : Profiles.Contexts) {
704     const uint64_t SamplingFactor = CtxRoot.getTotalRootEntryCount();
705     preorderVisitOneRoot<const PGOCtxProfContext>(
706         CtxRoot, [&](const PGOCtxProfContext &Ctx) {
707           Accummulate(Flat[Ctx.guid()], Ctx.counters(), SamplingFactor);
708         });
709 
710     for (const auto &[G, Unh] : CtxRoot.getUnhandled())
711       Accummulate(Flat[G], Unh, SamplingFactor);
712   }
713   // We don't sample "Flat" currently, so sampling rate is 1.
714   for (const auto &[G, FC] : Profiles.FlatProfiles)
715     Accummulate(Flat[G], FC, /*SamplingRate=*/1);
716   return Flat;
717 }
718 
719 const CtxProfFlatIndirectCallProfile
720 PGOContextualProfile::flattenVirtCalls() const {
721   CtxProfFlatIndirectCallProfile Ret;
722   for (const auto &[_, CtxRoot] : Profiles.Contexts) {
723     const uint64_t TotalRootEntryCount = CtxRoot.getTotalRootEntryCount();
724     preorderVisitOneRoot<const PGOCtxProfContext>(
725         CtxRoot, [&](const PGOCtxProfContext &Ctx) {
726           auto &Targets = Ret[Ctx.guid()];
727           for (const auto &[ID, SubctxSet] : Ctx.callsites())
728             for (const auto &Subctx : SubctxSet)
729               Targets[ID][Subctx.first] +=
730                   Subctx.second.getEntrycount() * TotalRootEntryCount;
731         });
732   }
733   return Ret;
734 }
735 
736 void CtxProfAnalysis::collectIndirectCallPromotionList(
737     CallBase &IC, Result &Profile,
738     SetVector<std::pair<CallBase *, Function *>> &Candidates) {
739   const auto *Instr = CtxProfAnalysis::getCallsiteInstrumentation(IC);
740   if (!Instr)
741     return;
742   Module &M = *IC.getParent()->getModule();
743   const uint32_t CallID = Instr->getIndex()->getZExtValue();
744   Profile.visit(
745       [&](const PGOCtxProfContext &Ctx) {
746         const auto &Targets = Ctx.callsites().find(CallID);
747         if (Targets == Ctx.callsites().end())
748           return;
749         for (const auto &[Guid, _] : Targets->second)
750           if (auto Name = Profile.getFunctionName(Guid); !Name.empty())
751             if (auto *Target = M.getFunction(Name))
752               if (Target->hasFnAttribute(Attribute::AlwaysInline))
753                 Candidates.insert({&IC, Target});
754       },
755       IC.getCaller());
756 }
757