//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace SwitchCG;

uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
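  // (The partitioning below tries every Clusters[i..j] sub-range, so it is
  // quadratic in the number of clusters; at -O0 we keep the simple lowering.)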
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good
  // as a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
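  // Each partition chosen above either collapses into a single CC_JumpTable
  // cluster or keeps its original clusters; survivors are compacted towards
  // the front of the vector. DstIndex never overtakes First (see the assert
  // below), so the moves are safe.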
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
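  // The condition register (-1U) and the header block (nullptr) are
  // placeholders here; they are expected to be filled in later, when the
  // jump table header is actually lowered.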
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check number of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
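  // As in findJumpTables above, partitions that cannot become a single
  // bit-test cluster are compacted towards the front in bulk with memmove,
  // which relies on CaseCluster being trivially copyable.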
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
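    // Testing the most likely destination first minimizes the expected number
    // of bit tests executed at run time; the mask comparison is a final
    // tie-breaker that keeps the order deterministic.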
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}