xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp (revision a91a246563dffa876a52f53a98de4af9fa364c52)
1  //===- CodeLayout.cpp - Implementation of code layout algorithms ----------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // The file implements "cache-aware" layout algorithms of basic blocks and
10  // functions in a binary.
11  //
12  // The algorithm tries to find a layout of nodes (basic blocks) of a given CFG
13  // optimizing jump locality and thus processor I-cache utilization. This is
14  // achieved via increasing the number of fall-through jumps and co-locating
15  // frequently executed nodes together. The name follows the underlying
16  // optimization problem, Extended-TSP, which is a generalization of classical
17  // (maximum) Traveling Salesman Problem.
18  //
19  // The algorithm is a greedy heuristic that works with chains (ordered lists)
20  // of basic blocks. Initially all chains are isolated basic blocks. On every
21  // iteration, we pick a pair of chains whose merging yields the biggest increase
22  // in the ExtTSP score, which models how i-cache "friendly" a specific chain is.
23  // A pair of chains giving the maximum gain is merged into a new chain. The
24  // procedure stops when there is only one chain left, or when merging does not
25  // increase ExtTSP. In the latter case, the remaining chains are sorted by
26  // density in the decreasing order.
27  //
28  // An important aspect is the way two chains are merged. Unlike earlier
29  // algorithms (e.g., based on the approach of Pettis-Hansen), two
30  // chains, X and Y, are first split into three, X1, X2, and Y. Then we
31  // consider all possible ways of gluing the three chains (e.g., X1YX2, X1X2Y,
32  // X2X1Y, X2YX1, YX1X2, YX2X1) and choose the one producing the largest score.
33  // This improves the quality of the final result (the search space is larger)
34  // while keeping the implementation sufficiently fast.
35  //
36  // Reference:
37  //   * A. Newell and S. Pupyrev, Improved Basic Block Reordering,
38  //     IEEE Transactions on Computers, 2020
39  //     https://arxiv.org/abs/1809.04676
40  //
41  //===----------------------------------------------------------------------===//
42  
43  #include "llvm/Transforms/Utils/CodeLayout.h"
44  #include "llvm/Support/CommandLine.h"
45  #include "llvm/Support/Debug.h"
46  
47  #include <cmath>
48  
49  using namespace llvm;
50  #define DEBUG_TYPE "code-layout"
51  
52  namespace llvm {
53  cl::opt<bool> EnableExtTspBlockPlacement(
54      "enable-ext-tsp-block-placement", cl::Hidden, cl::init(false),
55      cl::desc("Enable machine block placement based on the ext-tsp model, "
56               "optimizing I-cache utilization."));
57  
58  cl::opt<bool> ApplyExtTspWithoutProfile(
59      "ext-tsp-apply-without-profile",
60      cl::desc("Whether to apply ext-tsp placement for instances w/o profile"),
61      cl::init(true), cl::Hidden);
62  } // namespace llvm
63  
64  // Algorithm-specific params. The values are tuned for the best performance
65  // of large-scale front-end bound binaries.
66  static cl::opt<double> ForwardWeightCond(
67      "ext-tsp-forward-weight-cond", cl::ReallyHidden, cl::init(0.1),
68      cl::desc("The weight of conditional forward jumps for ExtTSP value"));
69  
70  static cl::opt<double> ForwardWeightUncond(
71      "ext-tsp-forward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
72      cl::desc("The weight of unconditional forward jumps for ExtTSP value"));
73  
74  static cl::opt<double> BackwardWeightCond(
75      "ext-tsp-backward-weight-cond", cl::ReallyHidden, cl::init(0.1),
76      cl::desc("The weight of conditional backward jumps for ExtTSP value"));
77  
78  static cl::opt<double> BackwardWeightUncond(
79      "ext-tsp-backward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
80      cl::desc("The weight of unconditional backward jumps for ExtTSP value"));
81  
82  static cl::opt<double> FallthroughWeightCond(
83      "ext-tsp-fallthrough-weight-cond", cl::ReallyHidden, cl::init(1.0),
84      cl::desc("The weight of conditional fallthrough jumps for ExtTSP value"));
85  
86  static cl::opt<double> FallthroughWeightUncond(
87      "ext-tsp-fallthrough-weight-uncond", cl::ReallyHidden, cl::init(1.05),
88      cl::desc("The weight of unconditional fallthrough jumps for ExtTSP value"));
89  
90  static cl::opt<unsigned> ForwardDistance(
91      "ext-tsp-forward-distance", cl::ReallyHidden, cl::init(1024),
92      cl::desc("The maximum distance (in bytes) of a forward jump for ExtTSP"));
93  
94  static cl::opt<unsigned> BackwardDistance(
95      "ext-tsp-backward-distance", cl::ReallyHidden, cl::init(640),
96      cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP"));
97  
98  // The maximum size of a chain created by the algorithm. The size is bounded
99  // so that the algorithm can efficiently process extremely large instance.
100  static cl::opt<unsigned>
101      MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(4096),
102                   cl::desc("The maximum size of a chain to create."));
103  
104  // The maximum size of a chain for splitting. Larger values of the threshold
105  // may yield better quality at the cost of worsen run-time.
106  static cl::opt<unsigned> ChainSplitThreshold(
107      "ext-tsp-chain-split-threshold", cl::ReallyHidden, cl::init(128),
108      cl::desc("The maximum size of a chain to apply splitting"));
109  
110  // The option enables splitting (large) chains along in-coming and out-going
111  // jumps. This typically results in a better quality.
112  static cl::opt<bool> EnableChainSplitAlongJumps(
113      "ext-tsp-enable-chain-split-along-jumps", cl::ReallyHidden, cl::init(true),
114      cl::desc("The maximum size of a chain to apply splitting"));
115  
116  namespace {
117  
// Tolerance used whenever two floating-point scores/gains are compared; gains
// that do not exceed it are treated as "no improvement".
constexpr double EPS{1e-8};
120  
// Compute the Ext-TSP score for a given jump.
//
// \param JumpDist    length of the jump (in bytes).
// \param JumpMaxDist longest jump that still contributes to the score.
// \param Count       execution count of the jump.
// \param Weight      weight of this kind of jump.
// \returns Weight * Count scaled linearly from 1 (zero-length jump) down to 0
//          (jump of length JumpMaxDist); 0 for longer jumps.
double jumpExtTSPScore(uint64_t JumpDist, uint64_t JumpMaxDist, uint64_t Count,
                       double Weight) {
  // A zero limit means no jump of this kind can contribute. Handling it here
  // avoids evaluating 0/0 (NaN) below when JumpDist is also zero.
  if (JumpMaxDist == 0 || JumpDist > JumpMaxDist)
    return 0;
  double Prob = 1.0 - static_cast<double>(JumpDist) / JumpMaxDist;
  return Weight * Prob * Count;
}
129  
130  // Compute the Ext-TSP score for a jump between a given pair of blocks,
131  // using their sizes, (estimated) addresses and the jump execution count.
132  double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr,
133                     uint64_t Count, bool IsConditional) {
134    // Fallthrough
135    if (SrcAddr + SrcSize == DstAddr) {
136      return jumpExtTSPScore(0, 1, Count,
137                             IsConditional ? FallthroughWeightCond
138                                           : FallthroughWeightUncond);
139    }
140    // Forward
141    if (SrcAddr + SrcSize < DstAddr) {
142      const uint64_t Dist = DstAddr - (SrcAddr + SrcSize);
143      return jumpExtTSPScore(Dist, ForwardDistance, Count,
144                             IsConditional ? ForwardWeightCond
145                                           : ForwardWeightUncond);
146    }
147    // Backward
148    const uint64_t Dist = SrcAddr + SrcSize - DstAddr;
149    return jumpExtTSPScore(Dist, BackwardDistance, Count,
150                           IsConditional ? BackwardWeightCond
151                                         : BackwardWeightUncond);
152  }
153  
/// The ways of combining two chains, X and Y. Chain X may be cut into a prefix
/// X1 and a suffix X2; the enumerator name spells the resulting order.
enum class MergeTypeT : int {
  X_Y = 0,     // whole X followed by Y
  Y_X = 1,     // Y followed by whole X
  X1_Y_X2 = 2, // Y inserted between the two parts of X
  Y_X2_X1 = 3, // Y, then the suffix of X, then the prefix of X
  X2_X1_Y = 4  // the suffix of X, then the prefix of X, then Y
};
157  
158  /// The gain of merging two chains, that is, the Ext-TSP score of the merge
159  /// together with the corresponding merge 'type' and 'offset'.
160  struct MergeGainT {
161    explicit MergeGainT() = default;
162    explicit MergeGainT(double Score, size_t MergeOffset, MergeTypeT MergeType)
163        : Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {}
164  
165    double score() const { return Score; }
166  
167    size_t mergeOffset() const { return MergeOffset; }
168  
169    MergeTypeT mergeType() const { return MergeType; }
170  
171    void setMergeType(MergeTypeT Ty) { MergeType = Ty; }
172  
173    // Returns 'true' iff Other is preferred over this.
174    bool operator<(const MergeGainT &Other) const {
175      return (Other.Score > EPS && Other.Score > Score + EPS);
176    }
177  
178    // Update the current gain if Other is preferred over this.
179    void updateIfLessThan(const MergeGainT &Other) {
180      if (*this < Other)
181        *this = Other;
182    }
183  
184  private:
185    double Score{-1.0};
186    size_t MergeOffset{0};
187    MergeTypeT MergeType{MergeTypeT::X_Y};
188  };
189  
190  struct JumpT;
191  struct ChainT;
192  struct ChainEdge;
193  
194  /// A node in the graph, typically corresponding to a basic block in the CFG or
195  /// a function in the call graph.
196  struct NodeT {
197    NodeT(const NodeT &) = delete;
198    NodeT(NodeT &&) = default;
199    NodeT &operator=(const NodeT &) = delete;
200    NodeT &operator=(NodeT &&) = default;
201  
202    explicit NodeT(size_t Index, uint64_t Size, uint64_t EC)
203        : Index(Index), Size(Size), ExecutionCount(EC) {}
204  
205    bool isEntry() const { return Index == 0; }
206  
207    // The total execution count of outgoing jumps.
208    uint64_t outCount() const;
209  
210    // The total execution count of incoming jumps.
211    uint64_t inCount() const;
212  
213    // The original index of the node in graph.
214    size_t Index{0};
215    // The index of the node in the current chain.
216    size_t CurIndex{0};
217    // The size of the node in the binary.
218    uint64_t Size{0};
219    // The execution count of the node in the profile data.
220    uint64_t ExecutionCount{0};
221    // The current chain of the node.
222    ChainT *CurChain{nullptr};
223    // The offset of the node in the current chain.
224    mutable uint64_t EstimatedAddr{0};
225    // Forced successor of the node in the graph.
226    NodeT *ForcedSucc{nullptr};
227    // Forced predecessor of the node in the graph.
228    NodeT *ForcedPred{nullptr};
229    // Outgoing jumps from the node.
230    std::vector<JumpT *> OutJumps;
231    // Incoming jumps to the node.
232    std::vector<JumpT *> InJumps;
233  };
234  
235  /// An arc in the graph, typically corresponding to a jump between two nodes.
236  struct JumpT {
237    JumpT(const JumpT &) = delete;
238    JumpT(JumpT &&) = default;
239    JumpT &operator=(const JumpT &) = delete;
240    JumpT &operator=(JumpT &&) = default;
241  
242    explicit JumpT(NodeT *Source, NodeT *Target, uint64_t ExecutionCount)
243        : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {}
244  
245    // Source node of the jump.
246    NodeT *Source;
247    // Target node of the jump.
248    NodeT *Target;
249    // Execution count of the arc in the profile data.
250    uint64_t ExecutionCount{0};
251    // Whether the jump corresponds to a conditional branch.
252    bool IsConditional{false};
253    // The offset of the jump from the source node.
254    uint64_t Offset{0};
255  };
256  
257  /// A chain (ordered sequence) of nodes in the graph.
258  struct ChainT {
259    ChainT(const ChainT &) = delete;
260    ChainT(ChainT &&) = default;
261    ChainT &operator=(const ChainT &) = delete;
262    ChainT &operator=(ChainT &&) = default;
263  
264    explicit ChainT(uint64_t Id, NodeT *Node)
265        : Id(Id), ExecutionCount(Node->ExecutionCount), Size(Node->Size),
266          Nodes(1, Node) {}
267  
268    size_t numBlocks() const { return Nodes.size(); }
269  
270    double density() const { return static_cast<double>(ExecutionCount) / Size; }
271  
272    bool isEntry() const { return Nodes[0]->Index == 0; }
273  
274    bool isCold() const {
275      for (NodeT *Node : Nodes) {
276        if (Node->ExecutionCount > 0)
277          return false;
278      }
279      return true;
280    }
281  
282    ChainEdge *getEdge(ChainT *Other) const {
283      for (auto It : Edges) {
284        if (It.first == Other)
285          return It.second;
286      }
287      return nullptr;
288    }
289  
290    void removeEdge(ChainT *Other) {
291      auto It = Edges.begin();
292      while (It != Edges.end()) {
293        if (It->first == Other) {
294          Edges.erase(It);
295          return;
296        }
297        It++;
298      }
299    }
300  
301    void addEdge(ChainT *Other, ChainEdge *Edge) {
302      Edges.push_back(std::make_pair(Other, Edge));
303    }
304  
305    void merge(ChainT *Other, const std::vector<NodeT *> &MergedBlocks) {
306      Nodes = MergedBlocks;
307      // Update the chain's data
308      ExecutionCount += Other->ExecutionCount;
309      Size += Other->Size;
310      Id = Nodes[0]->Index;
311      // Update the node's data
312      for (size_t Idx = 0; Idx < Nodes.size(); Idx++) {
313        Nodes[Idx]->CurChain = this;
314        Nodes[Idx]->CurIndex = Idx;
315      }
316    }
317  
318    void mergeEdges(ChainT *Other);
319  
320    void clear() {
321      Nodes.clear();
322      Nodes.shrink_to_fit();
323      Edges.clear();
324      Edges.shrink_to_fit();
325    }
326  
327    // Unique chain identifier.
328    uint64_t Id;
329    // Cached ext-tsp score for the chain.
330    double Score{0};
331    // The total execution count of the chain.
332    uint64_t ExecutionCount{0};
333    // The total size of the chain.
334    uint64_t Size{0};
335    // Nodes of the chain.
336    std::vector<NodeT *> Nodes;
337    // Adjacent chains and corresponding edges (lists of jumps).
338    std::vector<std::pair<ChainT *, ChainEdge *>> Edges;
339  };
340  
341  /// An edge in the graph representing jumps between two chains.
342  /// When nodes are merged into chains, the edges are combined too so that
343  /// there is always at most one edge between a pair of chains
344  struct ChainEdge {
345    ChainEdge(const ChainEdge &) = delete;
346    ChainEdge(ChainEdge &&) = default;
347    ChainEdge &operator=(const ChainEdge &) = delete;
348    ChainEdge &operator=(ChainEdge &&) = delete;
349  
350    explicit ChainEdge(JumpT *Jump)
351        : SrcChain(Jump->Source->CurChain), DstChain(Jump->Target->CurChain),
352          Jumps(1, Jump) {}
353  
354    ChainT *srcChain() const { return SrcChain; }
355  
356    ChainT *dstChain() const { return DstChain; }
357  
358    bool isSelfEdge() const { return SrcChain == DstChain; }
359  
360    const std::vector<JumpT *> &jumps() const { return Jumps; }
361  
362    void appendJump(JumpT *Jump) { Jumps.push_back(Jump); }
363  
364    void moveJumps(ChainEdge *Other) {
365      Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end());
366      Other->Jumps.clear();
367      Other->Jumps.shrink_to_fit();
368    }
369  
370    void changeEndpoint(ChainT *From, ChainT *To) {
371      if (From == SrcChain)
372        SrcChain = To;
373      if (From == DstChain)
374        DstChain = To;
375    }
376  
377    bool hasCachedMergeGain(ChainT *Src, ChainT *Dst) const {
378      return Src == SrcChain ? CacheValidForward : CacheValidBackward;
379    }
380  
381    MergeGainT getCachedMergeGain(ChainT *Src, ChainT *Dst) const {
382      return Src == SrcChain ? CachedGainForward : CachedGainBackward;
383    }
384  
385    void setCachedMergeGain(ChainT *Src, ChainT *Dst, MergeGainT MergeGain) {
386      if (Src == SrcChain) {
387        CachedGainForward = MergeGain;
388        CacheValidForward = true;
389      } else {
390        CachedGainBackward = MergeGain;
391        CacheValidBackward = true;
392      }
393    }
394  
395    void invalidateCache() {
396      CacheValidForward = false;
397      CacheValidBackward = false;
398    }
399  
400    void setMergeGain(MergeGainT Gain) { CachedGain = Gain; }
401  
402    MergeGainT getMergeGain() const { return CachedGain; }
403  
404    double gain() const { return CachedGain.score(); }
405  
406  private:
407    // Source chain.
408    ChainT *SrcChain{nullptr};
409    // Destination chain.
410    ChainT *DstChain{nullptr};
411    // Original jumps in the binary with corresponding execution counts.
412    std::vector<JumpT *> Jumps;
413    // Cached gain value for merging the pair of chains.
414    MergeGainT CachedGain;
415  
416    // Cached gain values for merging the pair of chains. Since the gain of
417    // merging (Src, Dst) and (Dst, Src) might be different, we store both values
418    // here and a flag indicating which of the options results in a higher gain.
419    // Cached gain values.
420    MergeGainT CachedGainForward;
421    MergeGainT CachedGainBackward;
422    // Whether the cached value must be recomputed.
423    bool CacheValidForward{false};
424    bool CacheValidBackward{false};
425  };
426  
427  uint64_t NodeT::outCount() const {
428    uint64_t Count = 0;
429    for (JumpT *Jump : OutJumps) {
430      Count += Jump->ExecutionCount;
431    }
432    return Count;
433  }
434  
435  uint64_t NodeT::inCount() const {
436    uint64_t Count = 0;
437    for (JumpT *Jump : InJumps) {
438      Count += Jump->ExecutionCount;
439    }
440    return Count;
441  }
442  
443  void ChainT::mergeEdges(ChainT *Other) {
444    // Update edges adjacent to chain Other
445    for (auto EdgeIt : Other->Edges) {
446      ChainT *DstChain = EdgeIt.first;
447      ChainEdge *DstEdge = EdgeIt.second;
448      ChainT *TargetChain = DstChain == Other ? this : DstChain;
449      ChainEdge *CurEdge = getEdge(TargetChain);
450      if (CurEdge == nullptr) {
451        DstEdge->changeEndpoint(Other, this);
452        this->addEdge(TargetChain, DstEdge);
453        if (DstChain != this && DstChain != Other) {
454          DstChain->addEdge(this, DstEdge);
455        }
456      } else {
457        CurEdge->moveJumps(DstEdge);
458      }
459      // Cleanup leftover edge
460      if (DstChain != Other) {
461        DstChain->removeEdge(Other);
462      }
463    }
464  }
465  
466  using NodeIter = std::vector<NodeT *>::const_iterator;
467  
468  /// A wrapper around three chains of nodes; it is used to avoid extra
469  /// instantiation of the vectors.
470  struct MergedChain {
471    MergedChain(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
472                NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
473                NodeIter End3 = NodeIter())
474        : Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
475          End3(End3) {}
476  
477    template <typename F> void forEach(const F &Func) const {
478      for (auto It = Begin1; It != End1; It++)
479        Func(*It);
480      for (auto It = Begin2; It != End2; It++)
481        Func(*It);
482      for (auto It = Begin3; It != End3; It++)
483        Func(*It);
484    }
485  
486    std::vector<NodeT *> getNodes() const {
487      std::vector<NodeT *> Result;
488      Result.reserve(std::distance(Begin1, End1) + std::distance(Begin2, End2) +
489                     std::distance(Begin3, End3));
490      Result.insert(Result.end(), Begin1, End1);
491      Result.insert(Result.end(), Begin2, End2);
492      Result.insert(Result.end(), Begin3, End3);
493      return Result;
494    }
495  
496    const NodeT *getFirstNode() const { return *Begin1; }
497  
498  private:
499    NodeIter Begin1;
500    NodeIter End1;
501    NodeIter Begin2;
502    NodeIter End2;
503    NodeIter Begin3;
504    NodeIter End3;
505  };
506  
507  /// Merge two chains of nodes respecting a given 'type' and 'offset'.
508  ///
509  /// If MergeType == 0, then the result is a concatenation of two chains.
510  /// Otherwise, the first chain is cut into two sub-chains at the offset,
511  /// and merged using all possible ways of concatenating three chains.
512  MergedChain mergeNodes(const std::vector<NodeT *> &X,
513                         const std::vector<NodeT *> &Y, size_t MergeOffset,
514                         MergeTypeT MergeType) {
515    // Split the first chain, X, into X1 and X2
516    NodeIter BeginX1 = X.begin();
517    NodeIter EndX1 = X.begin() + MergeOffset;
518    NodeIter BeginX2 = X.begin() + MergeOffset;
519    NodeIter EndX2 = X.end();
520    NodeIter BeginY = Y.begin();
521    NodeIter EndY = Y.end();
522  
523    // Construct a new chain from the three existing ones
524    switch (MergeType) {
525    case MergeTypeT::X_Y:
526      return MergedChain(BeginX1, EndX2, BeginY, EndY);
527    case MergeTypeT::Y_X:
528      return MergedChain(BeginY, EndY, BeginX1, EndX2);
529    case MergeTypeT::X1_Y_X2:
530      return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
531    case MergeTypeT::Y_X2_X1:
532      return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
533    case MergeTypeT::X2_X1_Y:
534      return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
535    }
536    llvm_unreachable("unexpected chain merge type");
537  }
538  
539  /// The implementation of the ExtTSP algorithm.
540  class ExtTSPImpl {
541  public:
542    ExtTSPImpl(const std::vector<uint64_t> &NodeSizes,
543               const std::vector<uint64_t> &NodeCounts,
544               const std::vector<EdgeCountT> &EdgeCounts)
545        : NumNodes(NodeSizes.size()) {
546      initialize(NodeSizes, NodeCounts, EdgeCounts);
547    }
548  
549    /// Run the algorithm and return an optimized ordering of nodes.
550    void run(std::vector<uint64_t> &Result) {
551      // Pass 1: Merge nodes with their mutually forced successors
552      mergeForcedPairs();
553  
554      // Pass 2: Merge pairs of chains while improving the ExtTSP objective
555      mergeChainPairs();
556  
557      // Pass 3: Merge cold nodes to reduce code size
558      mergeColdChains();
559  
560      // Collect nodes from all chains
561      concatChains(Result);
562    }
563  
564  private:
565    /// Initialize the algorithm's data structures.
566    void initialize(const std::vector<uint64_t> &NodeSizes,
567                    const std::vector<uint64_t> &NodeCounts,
568                    const std::vector<EdgeCountT> &EdgeCounts) {
569      // Initialize nodes
570      AllNodes.reserve(NumNodes);
571      for (uint64_t Idx = 0; Idx < NumNodes; Idx++) {
572        uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL);
573        uint64_t ExecutionCount = NodeCounts[Idx];
574        // The execution count of the entry node is set to at least one
575        if (Idx == 0 && ExecutionCount == 0)
576          ExecutionCount = 1;
577        AllNodes.emplace_back(Idx, Size, ExecutionCount);
578      }
579  
580      // Initialize jumps between nodes
581      SuccNodes.resize(NumNodes);
582      PredNodes.resize(NumNodes);
583      std::vector<uint64_t> OutDegree(NumNodes, 0);
584      AllJumps.reserve(EdgeCounts.size());
585      for (auto It : EdgeCounts) {
586        uint64_t Pred = It.first.first;
587        uint64_t Succ = It.first.second;
588        OutDegree[Pred]++;
589        // Ignore self-edges
590        if (Pred == Succ)
591          continue;
592  
593        SuccNodes[Pred].push_back(Succ);
594        PredNodes[Succ].push_back(Pred);
595        uint64_t ExecutionCount = It.second;
596        if (ExecutionCount > 0) {
597          NodeT &PredNode = AllNodes[Pred];
598          NodeT &SuccNode = AllNodes[Succ];
599          AllJumps.emplace_back(&PredNode, &SuccNode, ExecutionCount);
600          SuccNode.InJumps.push_back(&AllJumps.back());
601          PredNode.OutJumps.push_back(&AllJumps.back());
602        }
603      }
604      for (JumpT &Jump : AllJumps) {
605        assert(OutDegree[Jump.Source->Index] > 0);
606        Jump.IsConditional = OutDegree[Jump.Source->Index] > 1;
607      }
608  
609      // Initialize chains
610      AllChains.reserve(NumNodes);
611      HotChains.reserve(NumNodes);
612      for (NodeT &Node : AllNodes) {
613        AllChains.emplace_back(Node.Index, &Node);
614        Node.CurChain = &AllChains.back();
615        if (Node.ExecutionCount > 0) {
616          HotChains.push_back(&AllChains.back());
617        }
618      }
619  
620      // Initialize chain edges
621      AllEdges.reserve(AllJumps.size());
622      for (NodeT &PredNode : AllNodes) {
623        for (JumpT *Jump : PredNode.OutJumps) {
624          NodeT *SuccNode = Jump->Target;
625          ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
626          // this edge is already present in the graph
627          if (CurEdge != nullptr) {
628            assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
629            CurEdge->appendJump(Jump);
630            continue;
631          }
632          // this is a new edge
633          AllEdges.emplace_back(Jump);
634          PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
635          SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
636        }
637      }
638    }
639  
640    /// For a pair of nodes, A and B, node B is the forced successor of A,
641    /// if (i) all jumps (based on profile) from A goes to B and (ii) all jumps
642    /// to B are from A. Such nodes should be adjacent in the optimal ordering;
643    /// the method finds and merges such pairs of nodes.
644    void mergeForcedPairs() {
645      // Find fallthroughs based on edge weights
646      for (NodeT &Node : AllNodes) {
647        if (SuccNodes[Node.Index].size() == 1 &&
648            PredNodes[SuccNodes[Node.Index][0]].size() == 1 &&
649            SuccNodes[Node.Index][0] != 0) {
650          size_t SuccIndex = SuccNodes[Node.Index][0];
651          Node.ForcedSucc = &AllNodes[SuccIndex];
652          AllNodes[SuccIndex].ForcedPred = &Node;
653        }
654      }
655  
656      // There might be 'cycles' in the forced dependencies, since profile
657      // data isn't 100% accurate. Typically this is observed in loops, when the
658      // loop edges are the hottest successors for the basic blocks of the loop.
659      // Break the cycles by choosing the node with the smallest index as the
660      // head. This helps to keep the original order of the loops, which likely
661      // have already been rotated in the optimized manner.
662      for (NodeT &Node : AllNodes) {
663        if (Node.ForcedSucc == nullptr || Node.ForcedPred == nullptr)
664          continue;
665  
666        NodeT *SuccNode = Node.ForcedSucc;
667        while (SuccNode != nullptr && SuccNode != &Node) {
668          SuccNode = SuccNode->ForcedSucc;
669        }
670        if (SuccNode == nullptr)
671          continue;
672        // Break the cycle
673        AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr;
674        Node.ForcedPred = nullptr;
675      }
676  
677      // Merge nodes with their fallthrough successors
678      for (NodeT &Node : AllNodes) {
679        if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) {
680          const NodeT *CurBlock = &Node;
681          while (CurBlock->ForcedSucc != nullptr) {
682            const NodeT *NextBlock = CurBlock->ForcedSucc;
683            mergeChains(Node.CurChain, NextBlock->CurChain, 0, MergeTypeT::X_Y);
684            CurBlock = NextBlock;
685          }
686        }
687      }
688    }
689  
690    /// Merge pairs of chains while improving the ExtTSP objective.
691    void mergeChainPairs() {
692      /// Deterministically compare pairs of chains
693      auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
694                                  const ChainT *A2, const ChainT *B2) {
695        if (A1 != A2)
696          return A1->Id < A2->Id;
697        return B1->Id < B2->Id;
698      };
699  
700      while (HotChains.size() > 1) {
701        ChainT *BestChainPred = nullptr;
702        ChainT *BestChainSucc = nullptr;
703        MergeGainT BestGain;
704        // Iterate over all pairs of chains
705        for (ChainT *ChainPred : HotChains) {
706          // Get candidates for merging with the current chain
707          for (auto EdgeIt : ChainPred->Edges) {
708            ChainT *ChainSucc = EdgeIt.first;
709            ChainEdge *Edge = EdgeIt.second;
710            // Ignore loop edges
711            if (ChainPred == ChainSucc)
712              continue;
713  
714            // Stop early if the combined chain violates the maximum allowed size
715            if (ChainPred->numBlocks() + ChainSucc->numBlocks() >= MaxChainSize)
716              continue;
717  
718            // Compute the gain of merging the two chains
719            MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge);
720            if (CurGain.score() <= EPS)
721              continue;
722  
723            if (BestGain < CurGain ||
724                (std::abs(CurGain.score() - BestGain.score()) < EPS &&
725                 compareChainPairs(ChainPred, ChainSucc, BestChainPred,
726                                   BestChainSucc))) {
727              BestGain = CurGain;
728              BestChainPred = ChainPred;
729              BestChainSucc = ChainSucc;
730            }
731          }
732        }
733  
734        // Stop merging when there is no improvement
735        if (BestGain.score() <= EPS)
736          break;
737  
738        // Merge the best pair of chains
739        mergeChains(BestChainPred, BestChainSucc, BestGain.mergeOffset(),
740                    BestGain.mergeType());
741      }
742    }
743  
744    /// Merge remaining nodes into chains w/o taking jump counts into
745    /// consideration. This allows to maintain the original node order in the
746    /// absence of profile data
747    void mergeColdChains() {
748      for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) {
749        // Iterating in reverse order to make sure original fallthrough jumps are
750        // merged first; this might be beneficial for code size.
751        size_t NumSuccs = SuccNodes[SrcBB].size();
752        for (size_t Idx = 0; Idx < NumSuccs; Idx++) {
753          size_t DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1];
754          ChainT *SrcChain = AllNodes[SrcBB].CurChain;
755          ChainT *DstChain = AllNodes[DstBB].CurChain;
756          if (SrcChain != DstChain && !DstChain->isEntry() &&
757              SrcChain->Nodes.back()->Index == SrcBB &&
758              DstChain->Nodes.front()->Index == DstBB &&
759              SrcChain->isCold() == DstChain->isCold()) {
760            mergeChains(SrcChain, DstChain, 0, MergeTypeT::X_Y);
761          }
762        }
763      }
764    }
765  
766    /// Compute the Ext-TSP score for a given node order and a list of jumps.
767    double extTSPScore(const MergedChain &MergedBlocks,
768                       const std::vector<JumpT *> &Jumps) const {
769      if (Jumps.empty())
770        return 0.0;
771      uint64_t CurAddr = 0;
772      MergedBlocks.forEach([&](const NodeT *Node) {
773        Node->EstimatedAddr = CurAddr;
774        CurAddr += Node->Size;
775      });
776  
777      double Score = 0;
778      for (JumpT *Jump : Jumps) {
779        const NodeT *SrcBlock = Jump->Source;
780        const NodeT *DstBlock = Jump->Target;
781        Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
782                               DstBlock->EstimatedAddr, Jump->ExecutionCount,
783                               Jump->IsConditional);
784      }
785      return Score;
786    }
787  
788    /// Compute the gain of merging two chains.
789    ///
790    /// The function considers all possible ways of merging two chains and
791    /// computes the one having the largest increase in ExtTSP objective. The
792    /// result is a pair with the first element being the gain and the second
793    /// element being the corresponding merging type.
794    MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
795                                ChainEdge *Edge) const {
796      if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) {
797        return Edge->getCachedMergeGain(ChainPred, ChainSucc);
798      }
799  
800      // Precompute jumps between ChainPred and ChainSucc
801      auto Jumps = Edge->jumps();
802      ChainEdge *EdgePP = ChainPred->getEdge(ChainPred);
803      if (EdgePP != nullptr) {
804        Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end());
805      }
806      assert(!Jumps.empty() && "trying to merge chains w/o jumps");
807  
808      // The object holds the best currently chosen gain of merging the two chains
809      MergeGainT Gain = MergeGainT();
810  
811      /// Given a merge offset and a list of merge types, try to merge two chains
812      /// and update Gain with a better alternative
813      auto tryChainMerging = [&](size_t Offset,
814                                 const std::vector<MergeTypeT> &MergeTypes) {
815        // Skip merging corresponding to concatenation w/o splitting
816        if (Offset == 0 || Offset == ChainPred->Nodes.size())
817          return;
818        // Skip merging if it breaks Forced successors
819        NodeT *Node = ChainPred->Nodes[Offset - 1];
820        if (Node->ForcedSucc != nullptr)
821          return;
822        // Apply the merge, compute the corresponding gain, and update the best
823        // value, if the merge is beneficial
824        for (const MergeTypeT &MergeType : MergeTypes) {
825          Gain.updateIfLessThan(
826              computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType));
827        }
828      };
829  
830      // Try to concatenate two chains w/o splitting
831      Gain.updateIfLessThan(
832          computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y));
833  
834      if (EnableChainSplitAlongJumps) {
835        // Attach (a part of) ChainPred before the first node of ChainSucc
836        for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) {
837          const NodeT *SrcBlock = Jump->Source;
838          if (SrcBlock->CurChain != ChainPred)
839            continue;
840          size_t Offset = SrcBlock->CurIndex + 1;
841          tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y});
842        }
843  
844        // Attach (a part of) ChainPred after the last node of ChainSucc
845        for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) {
846          const NodeT *DstBlock = Jump->Source;
847          if (DstBlock->CurChain != ChainPred)
848            continue;
849          size_t Offset = DstBlock->CurIndex;
850          tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1});
851        }
852      }
853  
854      // Try to break ChainPred in various ways and concatenate with ChainSucc
855      if (ChainPred->Nodes.size() <= ChainSplitThreshold) {
856        for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) {
857          // Try to split the chain in different ways. In practice, applying
858          // X2_Y_X1 merging is almost never provides benefits; thus, we exclude
859          // it from consideration to reduce the search space
860          tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1,
861                                   MergeTypeT::X2_X1_Y});
862        }
863      }
864      Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain);
865      return Gain;
866    }
867  
868    /// Compute the score gain of merging two chains, respecting a given
869    /// merge 'type' and 'offset'.
870    ///
871    /// The two chains are not modified in the method.
872    MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc,
873                                const std::vector<JumpT *> &Jumps,
874                                size_t MergeOffset, MergeTypeT MergeType) const {
875      auto MergedBlocks =
876          mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
877  
878      // Do not allow a merge that does not preserve the original entry point
879      if ((ChainPred->isEntry() || ChainSucc->isEntry()) &&
880          !MergedBlocks.getFirstNode()->isEntry())
881        return MergeGainT();
882  
883      // The gain for the new chain
884      auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->Score;
885      return MergeGainT(NewGainScore, MergeOffset, MergeType);
886    }
887  
888    /// Merge chain From into chain Into, update the list of active chains,
889    /// adjacency information, and the corresponding cached values.
890    void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset,
891                     MergeTypeT MergeType) {
892      assert(Into != From && "a chain cannot be merged with itself");
893  
894      // Merge the nodes
895      MergedChain MergedNodes =
896          mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
897      Into->merge(From, MergedNodes.getNodes());
898  
899      // Merge the edges
900      Into->mergeEdges(From);
901      From->clear();
902  
903      // Update cached ext-tsp score for the new chain
904      ChainEdge *SelfEdge = Into->getEdge(Into);
905      if (SelfEdge != nullptr) {
906        MergedNodes = MergedChain(Into->Nodes.begin(), Into->Nodes.end());
907        Into->Score = extTSPScore(MergedNodes, SelfEdge->jumps());
908      }
909  
910      // Remove the chain from the list of active chains
911      llvm::erase_value(HotChains, From);
912  
913      // Invalidate caches
914      for (auto EdgeIt : Into->Edges)
915        EdgeIt.second->invalidateCache();
916    }
917  
918    /// Concatenate all chains into the final order.
919    void concatChains(std::vector<uint64_t> &Order) {
920      // Collect chains and calculate density stats for their sorting
921      std::vector<const ChainT *> SortedChains;
922      DenseMap<const ChainT *, double> ChainDensity;
923      for (ChainT &Chain : AllChains) {
924        if (!Chain.Nodes.empty()) {
925          SortedChains.push_back(&Chain);
926          // Using doubles to avoid overflow of ExecutionCounts
927          double Size = 0;
928          double ExecutionCount = 0;
929          for (NodeT *Node : Chain.Nodes) {
930            Size += static_cast<double>(Node->Size);
931            ExecutionCount += static_cast<double>(Node->ExecutionCount);
932          }
933          assert(Size > 0 && "a chain of zero size");
934          ChainDensity[&Chain] = ExecutionCount / Size;
935        }
936      }
937  
938      // Sorting chains by density in the decreasing order
939      std::stable_sort(SortedChains.begin(), SortedChains.end(),
940                       [&](const ChainT *L, const ChainT *R) {
941                         // Make sure the original entry point is at the
942                         // beginning of the order
943                         if (L->isEntry() != R->isEntry())
944                           return L->isEntry();
945  
946                         const double DL = ChainDensity[L];
947                         const double DR = ChainDensity[R];
948                         // Compare by density and break ties by chain identifiers
949                         return (DL != DR) ? (DL > DR) : (L->Id < R->Id);
950                       });
951  
952      // Collect the nodes in the order specified by their chains
953      Order.reserve(NumNodes);
954      for (const ChainT *Chain : SortedChains) {
955        for (NodeT *Node : Chain->Nodes) {
956          Order.push_back(Node->Index);
957        }
958      }
959    }
960  
961  private:
962    /// The number of nodes in the graph.
963    const size_t NumNodes;
964  
965    /// Successors of each node.
966    std::vector<std::vector<uint64_t>> SuccNodes;
967  
968    /// Predecessors of each node.
969    std::vector<std::vector<uint64_t>> PredNodes;
970  
971    /// All nodes (basic blocks) in the graph.
972    std::vector<NodeT> AllNodes;
973  
974    /// All jumps between the nodes.
975    std::vector<JumpT> AllJumps;
976  
977    /// All chains of nodes.
978    std::vector<ChainT> AllChains;
979  
980    /// All edges between the chains.
981    std::vector<ChainEdge> AllEdges;
982  
983    /// Active chains. The vector gets updated at runtime when chains are merged.
984    std::vector<ChainT *> HotChains;
985  };
986  
987  } // end of anonymous namespace
988  
989  std::vector<uint64_t>
990  llvm::applyExtTspLayout(const std::vector<uint64_t> &NodeSizes,
991                          const std::vector<uint64_t> &NodeCounts,
992                          const std::vector<EdgeCountT> &EdgeCounts) {
993    // Verify correctness of the input data
994    assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input");
995    assert(NodeSizes.size() > 2 && "Incorrect input");
996  
997    // Apply the reordering algorithm
998    ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts);
999    std::vector<uint64_t> Result;
1000    Alg.run(Result);
1001  
1002    // Verify correctness of the output
1003    assert(Result.front() == 0 && "Original entry point is not preserved");
1004    assert(Result.size() == NodeSizes.size() && "Incorrect size of layout");
1005    return Result;
1006  }
1007  
1008  double llvm::calcExtTspScore(const std::vector<uint64_t> &Order,
1009                               const std::vector<uint64_t> &NodeSizes,
1010                               const std::vector<uint64_t> &NodeCounts,
1011                               const std::vector<EdgeCountT> &EdgeCounts) {
1012    // Estimate addresses of the blocks in memory
1013    std::vector<uint64_t> Addr(NodeSizes.size(), 0);
1014    for (size_t Idx = 1; Idx < Order.size(); Idx++) {
1015      Addr[Order[Idx]] = Addr[Order[Idx - 1]] + NodeSizes[Order[Idx - 1]];
1016    }
1017    std::vector<uint64_t> OutDegree(NodeSizes.size(), 0);
1018    for (auto It : EdgeCounts) {
1019      uint64_t Pred = It.first.first;
1020      OutDegree[Pred]++;
1021    }
1022  
1023    // Increase the score for each jump
1024    double Score = 0;
1025    for (auto It : EdgeCounts) {
1026      uint64_t Pred = It.first.first;
1027      uint64_t Succ = It.first.second;
1028      uint64_t Count = It.second;
1029      bool IsConditional = OutDegree[Pred] > 1;
1030      Score += ::extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count,
1031                             IsConditional);
1032    }
1033    return Score;
1034  }
1035  
1036  double llvm::calcExtTspScore(const std::vector<uint64_t> &NodeSizes,
1037                               const std::vector<uint64_t> &NodeCounts,
1038                               const std::vector<EdgeCountT> &EdgeCounts) {
1039    std::vector<uint64_t> Order(NodeSizes.size());
1040    for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) {
1041      Order[Idx] = Idx;
1042    }
1043    return calcExtTspScore(Order, NodeSizes, NodeCounts, EdgeCounts);
1044  }
1045