1 //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains switch inst lowering optimizations and utilities for 10 // codegen, so that it can be used for both SelectionDAG and GlobalISel. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/SwitchLoweringUtils.h" 15 #include "llvm/CodeGen/FunctionLoweringInfo.h" 16 #include "llvm/CodeGen/MachineJumpTableInfo.h" 17 #include "llvm/CodeGen/TargetLowering.h" 18 #include "llvm/Target/TargetMachine.h" 19 20 using namespace llvm; 21 using namespace SwitchCG; 22 23 uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters, 24 unsigned First, unsigned Last) { 25 assert(Last >= First); 26 const APInt &LowCase = Clusters[First].Low->getValue(); 27 const APInt &HighCase = Clusters[Last].High->getValue(); 28 assert(LowCase.getBitWidth() == HighCase.getBitWidth()); 29 30 // FIXME: A range of consecutive cases has 100% density, but only requires one 31 // comparison to lower. We should discriminate against such consecutive ranges 32 // in jump tables. 33 return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1; 34 } 35 36 uint64_t 37 SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases, 38 unsigned First, unsigned Last) { 39 assert(Last >= First); 40 assert(TotalCases[Last] >= TotalCases[First]); 41 uint64_t NumCases = 42 TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]); 43 return NumCases; 44 } 45 46 void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters, 47 const SwitchInst *SI, 48 MachineBasicBlock *DefaultMBB, 49 ProfileSummaryInfo *PSI, 50 BlockFrequencyInfo *BFI) { 51 #ifndef NDEBUG 52 // Clusters must be non-empty, sorted, and only contain Range clusters. 53 assert(!Clusters.empty()); 54 for (CaseCluster &C : Clusters) 55 assert(C.Kind == CC_Range); 56 for (unsigned i = 1, e = Clusters.size(); i < e; ++i) 57 assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue())); 58 #endif 59 60 assert(TLI && "TLI not set!"); 61 if (!TLI->areJTsAllowed(SI->getParent()->getParent())) 62 return; 63 64 const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries(); 65 const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2; 66 67 // Bail if not enough cases. 68 const int64_t N = Clusters.size(); 69 if (N < 2 || N < MinJumpTableEntries) 70 return; 71 72 // Accumulated number of cases in each cluster and those prior to it. 73 SmallVector<unsigned, 8> TotalCases(N); 74 for (unsigned i = 0; i < N; ++i) { 75 const APInt &Hi = Clusters[i].High->getValue(); 76 const APInt &Lo = Clusters[i].Low->getValue(); 77 TotalCases[i] = (Hi - Lo).getLimitedValue() + 1; 78 if (i != 0) 79 TotalCases[i] += TotalCases[i - 1]; 80 } 81 82 uint64_t Range = getJumpTableRange(Clusters,0, N - 1); 83 uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1); 84 assert(NumCases < UINT64_MAX / 100); 85 assert(Range >= NumCases); 86 87 // Cheap case: the whole range may be suitable for jump table. 88 if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) { 89 CaseCluster JTCluster; 90 if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) { 91 Clusters[0] = JTCluster; 92 Clusters.resize(1); 93 return; 94 } 95 } 96 97 // The algorithm below is not suitable for -O0. 98 if (TM->getOptLevel() == CodeGenOpt::None) 99 return; 100 101 // Split Clusters into minimum number of dense partitions. The algorithm uses 102 // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code 103 // for the Case Statement'" (1994), but builds the MinPartitions array in 104 // reverse order to make it easier to reconstruct the partitions in ascending 105 // order. In the choice between two optimal partitionings, it picks the one 106 // which yields more jump tables. 107 108 // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1]. 109 SmallVector<unsigned, 8> MinPartitions(N); 110 // LastElement[i] is the last element of the partition starting at i. 111 SmallVector<unsigned, 8> LastElement(N); 112 // PartitionsScore[i] is used to break ties when choosing between two 113 // partitionings resulting in the same number of partitions. 114 SmallVector<unsigned, 8> PartitionsScore(N); 115 // For PartitionsScore, a small number of comparisons is considered as good as 116 // a jump table and a single comparison is considered better than a jump 117 // table. 118 enum PartitionScores : unsigned { 119 NoTable = 0, 120 Table = 1, 121 FewCases = 1, 122 SingleCase = 2 123 }; 124 125 // Base case: There is only one way to partition Clusters[N-1]. 126 MinPartitions[N - 1] = 1; 127 LastElement[N - 1] = N - 1; 128 PartitionsScore[N - 1] = PartitionScores::SingleCase; 129 130 // Note: loop indexes are signed to avoid underflow. 131 for (int64_t i = N - 2; i >= 0; i--) { 132 // Find optimal partitioning of Clusters[i..N-1]. 133 // Baseline: Put Clusters[i] into a partition on its own. 134 MinPartitions[i] = MinPartitions[i + 1] + 1; 135 LastElement[i] = i; 136 PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase; 137 138 // Search for a solution that results in fewer partitions. 139 for (int64_t j = N - 1; j > i; j--) { 140 // Try building a partition from Clusters[i..j]. 141 Range = getJumpTableRange(Clusters, i, j); 142 NumCases = getJumpTableNumCases(TotalCases, i, j); 143 assert(NumCases < UINT64_MAX / 100); 144 assert(Range >= NumCases); 145 146 if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) { 147 unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]); 148 unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1]; 149 int64_t NumEntries = j - i + 1; 150 151 if (NumEntries == 1) 152 Score += PartitionScores::SingleCase; 153 else if (NumEntries <= SmallNumberOfEntries) 154 Score += PartitionScores::FewCases; 155 else if (NumEntries >= MinJumpTableEntries) 156 Score += PartitionScores::Table; 157 158 // If this leads to fewer partitions, or to the same number of 159 // partitions with better score, it is a better partitioning. 160 if (NumPartitions < MinPartitions[i] || 161 (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) { 162 MinPartitions[i] = NumPartitions; 163 LastElement[i] = j; 164 PartitionsScore[i] = Score; 165 } 166 } 167 } 168 } 169 170 // Iterate over the partitions, replacing some with jump tables in-place. 171 unsigned DstIndex = 0; 172 for (unsigned First = 0, Last; First < N; First = Last + 1) { 173 Last = LastElement[First]; 174 assert(Last >= First); 175 assert(DstIndex <= First); 176 unsigned NumClusters = Last - First + 1; 177 178 CaseCluster JTCluster; 179 if (NumClusters >= MinJumpTableEntries && 180 buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) { 181 Clusters[DstIndex++] = JTCluster; 182 } else { 183 for (unsigned I = First; I <= Last; ++I) 184 std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I])); 185 } 186 } 187 Clusters.resize(DstIndex); 188 } 189 190 bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters, 191 unsigned First, unsigned Last, 192 const SwitchInst *SI, 193 MachineBasicBlock *DefaultMBB, 194 CaseCluster &JTCluster) { 195 assert(First <= Last); 196 197 auto Prob = BranchProbability::getZero(); 198 unsigned NumCmps = 0; 199 std::vector<MachineBasicBlock*> Table; 200 DenseMap<MachineBasicBlock*, BranchProbability> JTProbs; 201 202 // Initialize probabilities in JTProbs. 203 for (unsigned I = First; I <= Last; ++I) 204 JTProbs[Clusters[I].MBB] = BranchProbability::getZero(); 205 206 for (unsigned I = First; I <= Last; ++I) { 207 assert(Clusters[I].Kind == CC_Range); 208 Prob += Clusters[I].Prob; 209 const APInt &Low = Clusters[I].Low->getValue(); 210 const APInt &High = Clusters[I].High->getValue(); 211 NumCmps += (Low == High) ? 1 : 2; 212 if (I != First) { 213 // Fill the gap between this and the previous cluster. 214 const APInt &PreviousHigh = Clusters[I - 1].High->getValue(); 215 assert(PreviousHigh.slt(Low)); 216 uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1; 217 for (uint64_t J = 0; J < Gap; J++) 218 Table.push_back(DefaultMBB); 219 } 220 uint64_t ClusterSize = (High - Low).getLimitedValue() + 1; 221 for (uint64_t J = 0; J < ClusterSize; ++J) 222 Table.push_back(Clusters[I].MBB); 223 JTProbs[Clusters[I].MBB] += Clusters[I].Prob; 224 } 225 226 unsigned NumDests = JTProbs.size(); 227 if (TLI->isSuitableForBitTests(NumDests, NumCmps, 228 Clusters[First].Low->getValue(), 229 Clusters[Last].High->getValue(), *DL)) { 230 // Clusters[First..Last] should be lowered as bit tests instead. 231 return false; 232 } 233 234 // Create the MBB that will load from and jump through the table. 235 // Note: We create it here, but it's not inserted into the function yet. 236 MachineFunction *CurMF = FuncInfo.MF; 237 MachineBasicBlock *JumpTableMBB = 238 CurMF->CreateMachineBasicBlock(SI->getParent()); 239 240 // Add successors. Note: use table order for determinism. 241 SmallPtrSet<MachineBasicBlock *, 8> Done; 242 for (MachineBasicBlock *Succ : Table) { 243 if (Done.count(Succ)) 244 continue; 245 addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]); 246 Done.insert(Succ); 247 } 248 JumpTableMBB->normalizeSuccProbs(); 249 250 unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding()) 251 ->createJumpTableIndex(Table); 252 253 // Set up the jump table info. 254 JumpTable JT(-1U, JTI, JumpTableMBB, nullptr); 255 JumpTableHeader JTH(Clusters[First].Low->getValue(), 256 Clusters[Last].High->getValue(), SI->getCondition(), 257 nullptr, false); 258 JTCases.emplace_back(std::move(JTH), std::move(JT)); 259 260 JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High, 261 JTCases.size() - 1, Prob); 262 return true; 263 } 264 265 void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters, 266 const SwitchInst *SI) { 267 // Partition Clusters into as few subsets as possible, where each subset has a 268 // range that fits in a machine word and has <= 3 unique destinations. 269 270 #ifndef NDEBUG 271 // Clusters must be sorted and contain Range or JumpTable clusters. 272 assert(!Clusters.empty()); 273 assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable); 274 for (const CaseCluster &C : Clusters) 275 assert(C.Kind == CC_Range || C.Kind == CC_JumpTable); 276 for (unsigned i = 1; i < Clusters.size(); ++i) 277 assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue())); 278 #endif 279 280 // The algorithm below is not suitable for -O0. 281 if (TM->getOptLevel() == CodeGenOpt::None) 282 return; 283 284 // If target does not have legal shift left, do not emit bit tests at all. 285 EVT PTy = TLI->getPointerTy(*DL); 286 if (!TLI->isOperationLegal(ISD::SHL, PTy)) 287 return; 288 289 int BitWidth = PTy.getSizeInBits(); 290 const int64_t N = Clusters.size(); 291 292 // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1]. 293 SmallVector<unsigned, 8> MinPartitions(N); 294 // LastElement[i] is the last element of the partition starting at i. 295 SmallVector<unsigned, 8> LastElement(N); 296 297 // FIXME: This might not be the best algorithm for finding bit test clusters. 298 299 // Base case: There is only one way to partition Clusters[N-1]. 300 MinPartitions[N - 1] = 1; 301 LastElement[N - 1] = N - 1; 302 303 // Note: loop indexes are signed to avoid underflow. 304 for (int64_t i = N - 2; i >= 0; --i) { 305 // Find optimal partitioning of Clusters[i..N-1]. 306 // Baseline: Put Clusters[i] into a partition on its own. 307 MinPartitions[i] = MinPartitions[i + 1] + 1; 308 LastElement[i] = i; 309 310 // Search for a solution that results in fewer partitions. 311 // Note: the search is limited by BitWidth, reducing time complexity. 312 for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) { 313 // Try building a partition from Clusters[i..j]. 314 315 // Check the range. 316 if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(), 317 Clusters[j].High->getValue(), *DL)) 318 continue; 319 320 // Check nbr of destinations and cluster types. 321 // FIXME: This works, but doesn't seem very efficient. 322 bool RangesOnly = true; 323 BitVector Dests(FuncInfo.MF->getNumBlockIDs()); 324 for (int64_t k = i; k <= j; k++) { 325 if (Clusters[k].Kind != CC_Range) { 326 RangesOnly = false; 327 break; 328 } 329 Dests.set(Clusters[k].MBB->getNumber()); 330 } 331 if (!RangesOnly || Dests.count() > 3) 332 break; 333 334 // Check if it's a better partition. 335 unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]); 336 if (NumPartitions < MinPartitions[i]) { 337 // Found a better partition. 338 MinPartitions[i] = NumPartitions; 339 LastElement[i] = j; 340 } 341 } 342 } 343 344 // Iterate over the partitions, replacing with bit-test clusters in-place. 345 unsigned DstIndex = 0; 346 for (unsigned First = 0, Last; First < N; First = Last + 1) { 347 Last = LastElement[First]; 348 assert(First <= Last); 349 assert(DstIndex <= First); 350 351 CaseCluster BitTestCluster; 352 if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) { 353 Clusters[DstIndex++] = BitTestCluster; 354 } else { 355 size_t NumClusters = Last - First + 1; 356 std::memmove(&Clusters[DstIndex], &Clusters[First], 357 sizeof(Clusters[0]) * NumClusters); 358 DstIndex += NumClusters; 359 } 360 } 361 Clusters.resize(DstIndex); 362 } 363 364 bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters, 365 unsigned First, unsigned Last, 366 const SwitchInst *SI, 367 CaseCluster &BTCluster) { 368 assert(First <= Last); 369 if (First == Last) 370 return false; 371 372 BitVector Dests(FuncInfo.MF->getNumBlockIDs()); 373 unsigned NumCmps = 0; 374 for (int64_t I = First; I <= Last; ++I) { 375 assert(Clusters[I].Kind == CC_Range); 376 Dests.set(Clusters[I].MBB->getNumber()); 377 NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2; 378 } 379 unsigned NumDests = Dests.count(); 380 381 APInt Low = Clusters[First].Low->getValue(); 382 APInt High = Clusters[Last].High->getValue(); 383 assert(Low.slt(High)); 384 385 if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL)) 386 return false; 387 388 APInt LowBound; 389 APInt CmpRange; 390 391 const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits(); 392 assert(TLI->rangeFitsInWord(Low, High, *DL) && 393 "Case range must fit in bit mask!"); 394 395 // Check if the clusters cover a contiguous range such that no value in the 396 // range will jump to the default statement. 397 bool ContiguousRange = true; 398 for (int64_t I = First + 1; I <= Last; ++I) { 399 if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) { 400 ContiguousRange = false; 401 break; 402 } 403 } 404 405 if (Low.isStrictlyPositive() && High.slt(BitWidth)) { 406 // Optimize the case where all the case values fit in a word without having 407 // to subtract minValue. In this case, we can optimize away the subtraction. 408 LowBound = APInt::getNullValue(Low.getBitWidth()); 409 CmpRange = High; 410 ContiguousRange = false; 411 } else { 412 LowBound = Low; 413 CmpRange = High - Low; 414 } 415 416 CaseBitsVector CBV; 417 auto TotalProb = BranchProbability::getZero(); 418 for (unsigned i = First; i <= Last; ++i) { 419 // Find the CaseBits for this destination. 420 unsigned j; 421 for (j = 0; j < CBV.size(); ++j) 422 if (CBV[j].BB == Clusters[i].MBB) 423 break; 424 if (j == CBV.size()) 425 CBV.push_back( 426 CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero())); 427 CaseBits *CB = &CBV[j]; 428 429 // Update Mask, Bits and ExtraProb. 430 uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue(); 431 uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue(); 432 assert(Hi >= Lo && Hi < 64 && "Invalid bit case!"); 433 CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo; 434 CB->Bits += Hi - Lo + 1; 435 CB->ExtraProb += Clusters[i].Prob; 436 TotalProb += Clusters[i].Prob; 437 } 438 439 BitTestInfo BTI; 440 llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) { 441 // Sort by probability first, number of bits second, bit mask third. 442 if (a.ExtraProb != b.ExtraProb) 443 return a.ExtraProb > b.ExtraProb; 444 if (a.Bits != b.Bits) 445 return a.Bits > b.Bits; 446 return a.Mask < b.Mask; 447 }); 448 449 for (auto &CB : CBV) { 450 MachineBasicBlock *BitTestBB = 451 FuncInfo.MF->CreateMachineBasicBlock(SI->getParent()); 452 BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb)); 453 } 454 BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange), 455 SI->getCondition(), -1U, MVT::Other, false, 456 ContiguousRange, nullptr, nullptr, std::move(BTI), 457 TotalProb); 458 459 BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High, 460 BitTestCases.size() - 1, TotalProb); 461 return true; 462 } 463 464 void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) { 465 #ifndef NDEBUG 466 for (const CaseCluster &CC : Clusters) 467 assert(CC.Low == CC.High && "Input clusters must be single-case"); 468 #endif 469 470 llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) { 471 return a.Low->getValue().slt(b.Low->getValue()); 472 }); 473 474 // Merge adjacent clusters with the same destination. 475 const unsigned N = Clusters.size(); 476 unsigned DstIndex = 0; 477 for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) { 478 CaseCluster &CC = Clusters[SrcIndex]; 479 const ConstantInt *CaseVal = CC.Low; 480 MachineBasicBlock *Succ = CC.MBB; 481 482 if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ && 483 (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) { 484 // If this case has the same successor and is a neighbour, merge it into 485 // the previous cluster. 486 Clusters[DstIndex - 1].High = CaseVal; 487 Clusters[DstIndex - 1].Prob += CC.Prob; 488 } else { 489 std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex], 490 sizeof(Clusters[SrcIndex])); 491 } 492 } 493 Clusters.resize(DstIndex); 494 } 495