xref: /freebsd/contrib/llvm-project/llvm/lib/Support/BalancedPartitioning.cpp (revision fe75646a0234a261c0013bf1840fdac4acaf0cec)
1 //===- BalancedPartitioning.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements BalancedPartitioning, a recursive balanced graph
10 // partitioning algorithm.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/BalancedPartitioning.h"
15 #include "llvm/ADT/SetVector.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/ThreadPool.h"
20 
21 using namespace llvm;
22 #define DEBUG_TYPE "balanced-partitioning"
23 
24 void BPFunctionNode::dump(raw_ostream &OS) const {
25   OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,
26                 make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);
27 }
28 
29 template <typename Func>
30 void BalancedPartitioning::BPThreadPool::async(Func &&F) {
31 #if LLVM_ENABLE_THREADS
32   // This new thread could spawn more threads, so mark it as active
33   ++NumActiveThreads;
34   TheThreadPool.async([=]() {
35     // Run the task
36     F();
37 
38     // This thread will no longer spawn new threads, so mark it as inactive
39     if (--NumActiveThreads == 0) {
40       // There are no more active threads, so mark as finished and notify
41       {
42         std::unique_lock<std::mutex> lock(mtx);
43         assert(!IsFinishedSpawning);
44         IsFinishedSpawning = true;
45       }
46       cv.notify_one();
47     }
48   });
49 #else
50   llvm_unreachable("threads are disabled");
51 #endif
52 }
53 
54 void BalancedPartitioning::BPThreadPool::wait() {
55 #if LLVM_ENABLE_THREADS
56   // TODO: We could remove the mutex and condition variable and use
57   // std::atomic::wait() instead, but that isn't available until C++20
58   {
59     std::unique_lock<std::mutex> lock(mtx);
60     cv.wait(lock, [&]() { return IsFinishedSpawning; });
61     assert(IsFinishedSpawning && NumActiveThreads == 0);
62   }
63   // Now we can call ThreadPool::wait() since all tasks have been submitted
64   TheThreadPool.wait();
65 #else
66   llvm_unreachable("threads are disabled");
67 #endif
68 }
69 
70 BalancedPartitioning::BalancedPartitioning(
71     const BalancedPartitioningConfig &Config)
72     : Config(Config) {
73   // Pre-computing log2 values
74   Log2Cache[0] = 0.0;
75   for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
76     Log2Cache[I] = std::log2(I);
77 }
78 
79 void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
80   LLVM_DEBUG(
81       dbgs() << format(
82           "Partitioning %d nodes using depth %d and %d iterations per split\n",
83           Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
84   std::optional<BPThreadPool> TP;
85 #if LLVM_ENABLE_THREADS
86   ThreadPool TheThreadPool;
87   if (Config.TaskSplitDepth > 1)
88     TP.emplace(TheThreadPool);
89 #endif
90 
91   // Record the input order
92   for (unsigned I = 0; I < Nodes.size(); I++)
93     Nodes[I].InputOrderIndex = I;
94 
95   auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
96   auto BisectTask = [=, &TP]() {
97     bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
98   };
99   if (TP) {
100     TP->async(std::move(BisectTask));
101     TP->wait();
102   } else {
103     BisectTask();
104   }
105 
106   llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
107     return L.Bucket < R.Bucket;
108   });
109 
110   LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
111 }
112 
/// Recursively bisects \p Nodes and assigns every node its final Bucket.
///
/// Recursion ends when the range has at most one node or \p RecDepth reaches
/// Config.SplitDepth. At that point the range is restored to input order
/// (InputOrderIndex) and given consecutive Bucket values starting at
/// \p Offset.
///
/// Otherwise the range is:
///   1. split in half into buckets 2*RootBucket and 2*RootBucket+1;
///   2. refined by runIterations;
///   3. partitioned so that left-bucket nodes come first.
/// Each half then recurses with RecDepth + 1. The right half's offset is
/// Offset plus the size of the left half.
///
/// While a thread pool is available, RecDepth < Config.TaskSplitDepth and the
/// range has at least 4 nodes, both halves are submitted to \p TP as tasks.
/// Otherwise they run inline, left first.
void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
                                  unsigned RecDepth, unsigned RootBucket,
                                  unsigned Offset,
                                  std::optional<BPThreadPool> &TP) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
    // We've reached the lowest level of the recursion tree. Fall back to the
    // original order and assign to buckets.
    llvm::stable_sort(Nodes, [](const auto &L, const auto &R) {
      return L.InputOrderIndex < R.InputOrderIndex;
    });
    for (auto &N : Nodes)
      N.Bucket = Offset++;
    return;
  }

  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
                              NumNodes, RootBucket));

  // Each call owns an RNG seeded only by its RootBucket, so the random
  // choices made for this subproblem do not depend on task scheduling order.
  std::mt19937 RNG(RootBucket);

  unsigned LeftBucket = 2 * RootBucket;
  unsigned RightBucket = 2 * RootBucket + 1;

  // Split into two and assign to the left and right buckets
  split(Nodes, LeftBucket);

  runIterations(Nodes, RecDepth, LeftBucket, RightBucket, RNG);

  // Split nodes wrt the resulting buckets
  auto NodesMid =
      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
  // The first bucket number available to the right half.
  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);

  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());

  // The two halves are disjoint subranges, so the tasks below touch disjoint
  // nodes.
  auto LeftRecTask = [=, &TP]() {
    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
  };
  auto RightRecTask = [=, &TP]() {
    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
  };

  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
    TP->async(std::move(LeftRecTask));
    TP->async(std::move(RightRecTask));
  } else {
    LeftRecTask();
    RightRecTask();
  }
}
165 
166 void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
167                                          unsigned RecDepth, unsigned LeftBucket,
168                                          unsigned RightBucket,
169                                          std::mt19937 &RNG) const {
170   unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
171   DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeDegree;
172   for (auto &N : Nodes)
173     for (auto &UN : N.UtilityNodes)
174       ++UtilityNodeDegree[UN];
175   // Remove utility nodes if they have just one edge or are connected to all
176   // functions
177   for (auto &N : Nodes)
178     llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
179       return UtilityNodeDegree[UN] <= 1 || UtilityNodeDegree[UN] >= NumNodes;
180     });
181 
182   // Renumber utility nodes so they can be used to index into Signatures
183   DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;
184   for (auto &N : Nodes)
185     for (auto &UN : N.UtilityNodes)
186       if (!UtilityNodeIndex.count(UN))
187         UtilityNodeIndex[UN] = UtilityNodeIndex.size();
188   for (auto &N : Nodes)
189     for (auto &UN : N.UtilityNodes)
190       UN = UtilityNodeIndex[UN];
191 
192   // Initialize signatures
193   SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
194   for (auto &N : Nodes) {
195     for (auto &UN : N.UtilityNodes) {
196       assert(UN < Signatures.size());
197       if (N.Bucket == LeftBucket) {
198         Signatures[UN].LeftCount++;
199       } else {
200         Signatures[UN].RightCount++;
201       }
202     }
203   }
204 
205   for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
206     unsigned NumMovedNodes =
207         runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
208     if (NumMovedNodes == 0)
209       break;
210   }
211 }
212 
213 unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
214                                             unsigned LeftBucket,
215                                             unsigned RightBucket,
216                                             SignaturesT &Signatures,
217                                             std::mt19937 &RNG) const {
218   // Init signature cost caches
219   for (auto &Signature : Signatures) {
220     if (Signature.CachedGainIsValid)
221       continue;
222     unsigned L = Signature.LeftCount;
223     unsigned R = Signature.RightCount;
224     assert((L > 0 || R > 0) && "incorrect signature");
225     float Cost = logCost(L, R);
226     Signature.CachedGainLR = 0.f;
227     Signature.CachedGainRL = 0.f;
228     if (L > 0)
229       Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);
230     if (R > 0)
231       Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);
232     Signature.CachedGainIsValid = true;
233   }
234 
235   // Compute move gains
236   typedef std::pair<float, BPFunctionNode *> GainPair;
237   std::vector<GainPair> Gains;
238   for (auto &N : Nodes) {
239     bool FromLeftToRight = (N.Bucket == LeftBucket);
240     float Gain = moveGain(N, FromLeftToRight, Signatures);
241     Gains.push_back(std::make_pair(Gain, &N));
242   }
243 
244   // Collect left and right gains
245   auto LeftEnd = llvm::partition(
246       Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
247   auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
248   auto RightRange = llvm::make_range(LeftEnd, Gains.end());
249 
250   // Sort gains in descending order
251   auto LargerGain = [](const auto &L, const auto &R) {
252     return L.first > R.first;
253   };
254   llvm::stable_sort(LeftRange, LargerGain);
255   llvm::stable_sort(RightRange, LargerGain);
256 
257   unsigned NumMovedDataVertices = 0;
258   for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
259     auto &[LeftGain, LeftNode] = LeftPair;
260     auto &[RightGain, RightNode] = RightPair;
261     // Stop when the gain is no longer beneficial
262     if (LeftGain + RightGain <= 0.f)
263       break;
264     // Try to exchange the nodes between buckets
265     if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
266       ++NumMovedDataVertices;
267     if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
268       ++NumMovedDataVertices;
269   }
270   return NumMovedDataVertices;
271 }
272 
273 bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
274                                             unsigned LeftBucket,
275                                             unsigned RightBucket,
276                                             SignaturesT &Signatures,
277                                             std::mt19937 &RNG) const {
278   // Sometimes we skip the move. This helps to escape local optima
279   if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
280       Config.SkipProbability)
281     return false;
282 
283   bool FromLeftToRight = (N.Bucket == LeftBucket);
284   // Update the current bucket
285   N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);
286 
287   // Update signatures and invalidate gain cache
288   if (FromLeftToRight) {
289     for (auto &UN : N.UtilityNodes) {
290       auto &Signature = Signatures[UN];
291       Signature.LeftCount--;
292       Signature.RightCount++;
293       Signature.CachedGainIsValid = false;
294     }
295   } else {
296     for (auto &UN : N.UtilityNodes) {
297       auto &Signature = Signatures[UN];
298       Signature.LeftCount++;
299       Signature.RightCount--;
300       Signature.CachedGainIsValid = false;
301     }
302   }
303   return true;
304 }
305 
306 void BalancedPartitioning::split(const FunctionNodeRange Nodes,
307                                  unsigned StartBucket) const {
308   unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
309   auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;
310 
311   std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
312     return L.InputOrderIndex < R.InputOrderIndex;
313   });
314 
315   for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
316     N.Bucket = StartBucket;
317   for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
318     N.Bucket = StartBucket + 1;
319 }
320 
321 float BalancedPartitioning::moveGain(const BPFunctionNode &N,
322                                      bool FromLeftToRight,
323                                      const SignaturesT &Signatures) {
324   float Gain = 0.f;
325   for (auto &UN : N.UtilityNodes)
326     Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR
327                              : Signatures[UN].CachedGainRL);
328   return Gain;
329 }
330 
/// Cost of a utility node, where \p X and \p Y are its counts of function
/// nodes in the left and right buckets (runIteration passes LeftCount and
/// RightCount). Computes -(X * log2(X + 1) + Y * log2(Y + 1)) using the
/// cached log2.
///
/// NOTE: kept as a single expression on purpose. Splitting it into separate
/// statements may change whether the compiler contracts it into an FMA, which
/// could alter the float result and hence the final partition.
float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
  return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
}
334 
335 float BalancedPartitioning::log2Cached(unsigned i) const {
336   return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
337 }
338