//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace SwitchCG;

uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.
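  // Note (illustrative): the clamp at (UINT64_MAX - 1) / 100 is an overflow
  // guard, keeping the returned range small enough to be scaled by 100 later
  // (e.g. for percent-density checks; cf. the NumCases < UINT64_MAX / 100
  // asserts in findJumpTables below).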
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
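  // TotalCases is a prefix-sum array: TotalCases[i] is the total number of
  // individual case values covered by Clusters[0..i], so subtracting the
  // prefix before First yields the count in [First, Last]. Worked example
  // (illustrative): per-cluster counts {2, 3, 1} give TotalCases = {2, 5, 6},
  // and the number of cases in clusters [1, 2] is 6 - 2 = 4.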
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              std::optional<SDLoc> SL,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for a jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, SL, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOptLevel::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables. The algorithm is described in
  // https://arxiv.org/pdf/1910.02351v2

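  // Illustrative sketch of the recurrence computed below, with the convention
  // MinPartitions[N] == 0:
  //   MinPartitions[i] = min over j in [i, N-1] of (1 + MinPartitions[j + 1])
  // where j > i is only considered when Clusters[i..j] is suitable for a jump
  // table; j == i (a single-cluster partition) is always allowed.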
  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
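  // Tie-breaking example (illustrative): given two partitionings that both
  // need three partitions, one covering a span with a jump table (+Table) and
  // the other splitting the same span into mid-sized ranges (+NoTable each),
  // the scores differ and the jump-table variant is preferred.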

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
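  // The partition starting at First ends at LastElement[First], and the next
  // one begins right after it, so walking LastElement recovers the
  // partitioning chosen by the dynamic program above.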
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, SL, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              const std::optional<SDLoc> &SL,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
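      // Worked example (illustrative): clusters [1-2] and [5-6] leave
      // Gap = (5 - 2) - 1 = 2, so values 3 and 4 get DefaultMBB entries.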
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr, SL);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.
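  // Illustrative example: on a 64-bit target, cases {2, 9, 41} spanning two
  // destinations fit in one word, so they can be matched with per-destination
  // bit masks instead of a chain of comparisons.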

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOptLevel::None)
    return;

  // If the target does not have a legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check the number of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // All the case values fit in a word without subtracting the minimum value,
    // so the subtraction can be optimized away.
    LowBound = APInt::getZero(Low.getBitWidth());
    CmpRange = High;
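    // With LowBound at zero, values in [0, Low) now fall inside the compared
    // range but still branch to the default, so the covered range is no
    // longer contiguous.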
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
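    // Worked example (illustrative): Lo = 2, Hi = 4 gives
    // -1ULL >> (63 - 2) == 0b111, shifted left by 2 == 0b11100,
    // i.e. bits 2..4 are set in the mask.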
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
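  // Illustrative example: single cases 1->BB0, 2->BB0 and 4->BB0 become the
  // clusters [1-2]->BB0 and [4-4]->BB0.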
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}

unsigned SwitchCG::SwitchLowering::caseClusterRank(const CaseCluster &CC,
                                                   CaseClusterIt First,
                                                   CaseClusterIt Last) {
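  // The rank of CC is the number of clusters in [First, Last] that would be
  // ordered before it: higher probability first, smaller case value on ties.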
  return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
    if (X.Prob != CC.Prob)
      return X.Prob > CC.Prob;

    // Ties are broken by comparing the case value.
    return X.Low->getValue().slt(CC.Low->getValue());
  });
}

llvm::SwitchCG::SwitchLowering::SplitWorkItemInfo
SwitchCG::SwitchLowering::computeSplitWorkItemInfo(
    const SwitchWorkListItem &W) {
  CaseClusterIt LastLeft = W.FirstCluster;
  CaseClusterIt FirstRight = W.LastCluster;
  auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
  auto RightProb = FirstRight->Prob + W.DefaultProb / 2;

  // Move LastLeft and FirstRight towards each other from opposite directions to
  // find a partitioning of the clusters which balances the probability on both
  // sides. If LeftProb and RightProb are equal, alternate which side is
  // taken to ensure 0-probability nodes are distributed evenly.
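  // Worked example (illustrative): probabilities {0.1, 0.4, 0.2, 0.3} with a
  // zero default probability split as {0.1, 0.4} | {0.2, 0.3}, balancing both
  // sides at 0.5.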
  unsigned I = 0;
  while (LastLeft + 1 < FirstRight) {
    if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
      LeftProb += (++LastLeft)->Prob;
    else
      RightProb += (--FirstRight)->Prob;
    I++;
  }

  while (true) {
    // Our binary search tree differs from a typical BST in that ours can have
    // up to three values in each leaf. The pivot selection above doesn't take
    // that into account, which means the tree might require more nodes and be
    // less efficient. We compensate for this here.

    unsigned NumLeft = LastLeft - W.FirstCluster + 1;
    unsigned NumRight = W.LastCluster - FirstRight + 1;

    if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
      // If one side has fewer than 3 clusters, and the other has more than 3,
      // consider taking a cluster from the other side.

      if (NumLeft < NumRight) {
        // Consider moving the first cluster on the right to the left side.
        CaseCluster &CC = *FirstRight;
        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
        if (LeftSideRank <= RightSideRank) {
          // Moving the cluster to the left does not demote it.
          ++LastLeft;
          ++FirstRight;
          continue;
        }
      } else {
        assert(NumRight < NumLeft);
        // Consider moving the last element on the left to the right side.
        CaseCluster &CC = *LastLeft;
        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
        if (RightSideRank <= LeftSideRank) {
          // Moving the cluster to the right does not demote it.
          --LastLeft;
          --FirstRight;
          continue;
        }
      }
    }
    break;
  }

  assert(LastLeft + 1 == FirstRight);
  assert(LastLeft >= W.FirstCluster);
  assert(FirstRight <= W.LastCluster);

  return SplitWorkItemInfo{LastLeft, FirstRight, LeftProb, RightProb};
}
579