10b57cec5SDimitry Andric //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
90b57cec5SDimitry Andric // The LowerSwitch transformation rewrites switch instructions with a sequence
100b57cec5SDimitry Andric // of branches, which allows targets to get away with not implementing the
110b57cec5SDimitry Andric // switch instruction until it is convenient.
120b57cec5SDimitry Andric //
130b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
140b57cec5SDimitry Andric
15e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/LowerSwitch.h"
160b57cec5SDimitry Andric #include "llvm/ADT/DenseMap.h"
170b57cec5SDimitry Andric #include "llvm/ADT/STLExtras.h"
180b57cec5SDimitry Andric #include "llvm/ADT/SmallPtrSet.h"
190b57cec5SDimitry Andric #include "llvm/ADT/SmallVector.h"
200b57cec5SDimitry Andric #include "llvm/Analysis/AssumptionCache.h"
210b57cec5SDimitry Andric #include "llvm/Analysis/LazyValueInfo.h"
220b57cec5SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
230b57cec5SDimitry Andric #include "llvm/IR/BasicBlock.h"
240b57cec5SDimitry Andric #include "llvm/IR/CFG.h"
250b57cec5SDimitry Andric #include "llvm/IR/ConstantRange.h"
260b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
270b57cec5SDimitry Andric #include "llvm/IR/Function.h"
280b57cec5SDimitry Andric #include "llvm/IR/InstrTypes.h"
290b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
30e8d8bef9SDimitry Andric #include "llvm/IR/PassManager.h"
310b57cec5SDimitry Andric #include "llvm/IR/Value.h"
32480093f4SDimitry Andric #include "llvm/InitializePasses.h"
330b57cec5SDimitry Andric #include "llvm/Pass.h"
340b57cec5SDimitry Andric #include "llvm/Support/Casting.h"
350b57cec5SDimitry Andric #include "llvm/Support/Compiler.h"
360b57cec5SDimitry Andric #include "llvm/Support/Debug.h"
370b57cec5SDimitry Andric #include "llvm/Support/KnownBits.h"
380b57cec5SDimitry Andric #include "llvm/Support/raw_ostream.h"
390b57cec5SDimitry Andric #include "llvm/Transforms/Utils.h"
400b57cec5SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
410b57cec5SDimitry Andric #include <algorithm>
420b57cec5SDimitry Andric #include <cassert>
430b57cec5SDimitry Andric #include <cstdint>
440b57cec5SDimitry Andric #include <iterator>
450b57cec5SDimitry Andric #include <vector>
460b57cec5SDimitry Andric
470b57cec5SDimitry Andric using namespace llvm;
480b57cec5SDimitry Andric
490b57cec5SDimitry Andric #define DEBUG_TYPE "lower-switch"
500b57cec5SDimitry Andric
510b57cec5SDimitry Andric namespace {
520b57cec5SDimitry Andric
530b57cec5SDimitry Andric struct IntRange {
54bdd1243dSDimitry Andric APInt Low, High;
550b57cec5SDimitry Andric };
560b57cec5SDimitry Andric
570b57cec5SDimitry Andric } // end anonymous namespace
580b57cec5SDimitry Andric
59e8d8bef9SDimitry Andric namespace {
600b57cec5SDimitry Andric // Return true iff R is covered by Ranges.
IsInRanges(const IntRange & R,const std::vector<IntRange> & Ranges)61e8d8bef9SDimitry Andric bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) {
620b57cec5SDimitry Andric // Note: Ranges must be sorted, non-overlapping and non-adjacent.
630b57cec5SDimitry Andric
640b57cec5SDimitry Andric // Find the first range whose High field is >= R.High,
650b57cec5SDimitry Andric // then check if the Low field is <= R.Low. If so, we
660b57cec5SDimitry Andric // have a Range that covers R.
670b57cec5SDimitry Andric auto I = llvm::lower_bound(
68bdd1243dSDimitry Andric Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); });
69bdd1243dSDimitry Andric return I != Ranges.end() && I->Low.sle(R.Low);
700b57cec5SDimitry Andric }
710b57cec5SDimitry Andric
720b57cec5SDimitry Andric struct CaseRange {
730b57cec5SDimitry Andric ConstantInt *Low;
740b57cec5SDimitry Andric ConstantInt *High;
750b57cec5SDimitry Andric BasicBlock *BB;
760b57cec5SDimitry Andric
CaseRange__anon9b3ab8e90211::CaseRange770b57cec5SDimitry Andric CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb)
780b57cec5SDimitry Andric : Low(low), High(high), BB(bb) {}
790b57cec5SDimitry Andric };
800b57cec5SDimitry Andric
810b57cec5SDimitry Andric using CaseVector = std::vector<CaseRange>;
820b57cec5SDimitry Andric using CaseItr = std::vector<CaseRange>::iterator;
830b57cec5SDimitry Andric
840b57cec5SDimitry Andric /// The comparison function for sorting the switch case values in the vector.
850b57cec5SDimitry Andric /// WARNING: Case ranges should be disjoint!
860b57cec5SDimitry Andric struct CaseCmp {
operator ()__anon9b3ab8e90211::CaseCmp87e8d8bef9SDimitry Andric bool operator()(const CaseRange &C1, const CaseRange &C2) {
880b57cec5SDimitry Andric const ConstantInt *CI1 = cast<const ConstantInt>(C1.Low);
890b57cec5SDimitry Andric const ConstantInt *CI2 = cast<const ConstantInt>(C2.High);
900b57cec5SDimitry Andric return CI1->getValue().slt(CI2->getValue());
910b57cec5SDimitry Andric }
920b57cec5SDimitry Andric };
930b57cec5SDimitry Andric
940b57cec5SDimitry Andric /// Used for debugging purposes.
950b57cec5SDimitry Andric LLVM_ATTRIBUTE_USED
operator <<(raw_ostream & O,const CaseVector & C)96e8d8bef9SDimitry Andric raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) {
970b57cec5SDimitry Andric O << "[";
980b57cec5SDimitry Andric
99e8d8bef9SDimitry Andric for (CaseVector::const_iterator B = C.begin(), E = C.end(); B != E;) {
1000b57cec5SDimitry Andric O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]";
1010b57cec5SDimitry Andric if (++B != E)
1020b57cec5SDimitry Andric O << ", ";
1030b57cec5SDimitry Andric }
1040b57cec5SDimitry Andric
1050b57cec5SDimitry Andric return O << "]";
1060b57cec5SDimitry Andric }
1070b57cec5SDimitry Andric
1080b57cec5SDimitry Andric /// Update the first occurrence of the "switch statement" BB in the PHI
1090b57cec5SDimitry Andric /// node with the "new" BB. The other occurrences will:
1100b57cec5SDimitry Andric ///
1110b57cec5SDimitry Andric /// 1) Be updated by subsequent calls to this function. Switch statements may
1120b57cec5SDimitry Andric /// have more than one outcoming edge into the same BB if they all have the same
1130b57cec5SDimitry Andric /// value. When the switch statement is converted these incoming edges are now
1140b57cec5SDimitry Andric /// coming from multiple BBs.
1150b57cec5SDimitry Andric /// 2) Removed if subsequent incoming values now share the same case, i.e.,
1160b57cec5SDimitry Andric /// multiple outcome edges are condensed into one. This is necessary to keep the
1170b57cec5SDimitry Andric /// number of phi values equal to the number of branches to SuccBB.
FixPhis(BasicBlock * SuccBB,BasicBlock * OrigBB,BasicBlock * NewBB,const APInt & NumMergedCases)118bdd1243dSDimitry Andric void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
119bdd1243dSDimitry Andric const APInt &NumMergedCases) {
12081ad6265SDimitry Andric for (auto &I : SuccBB->phis()) {
12181ad6265SDimitry Andric PHINode *PN = cast<PHINode>(&I);
1220b57cec5SDimitry Andric
12381ad6265SDimitry Andric // Only update the first occurrence if NewBB exists.
1240b57cec5SDimitry Andric unsigned Idx = 0, E = PN->getNumIncomingValues();
125bdd1243dSDimitry Andric APInt LocalNumMergedCases = NumMergedCases;
12681ad6265SDimitry Andric for (; Idx != E && NewBB; ++Idx) {
1270b57cec5SDimitry Andric if (PN->getIncomingBlock(Idx) == OrigBB) {
1280b57cec5SDimitry Andric PN->setIncomingBlock(Idx, NewBB);
1290b57cec5SDimitry Andric break;
1300b57cec5SDimitry Andric }
1310b57cec5SDimitry Andric }
1320b57cec5SDimitry Andric
13381ad6265SDimitry Andric // Skip the updated incoming block so that it will not be removed.
13481ad6265SDimitry Andric if (NewBB)
13581ad6265SDimitry Andric ++Idx;
13681ad6265SDimitry Andric
1370b57cec5SDimitry Andric // Remove additional occurrences coming from condensed cases and keep the
1380b57cec5SDimitry Andric // number of incoming values equal to the number of branches to SuccBB.
1390b57cec5SDimitry Andric SmallVector<unsigned, 8> Indices;
140bdd1243dSDimitry Andric for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx)
1410b57cec5SDimitry Andric if (PN->getIncomingBlock(Idx) == OrigBB) {
1420b57cec5SDimitry Andric Indices.push_back(Idx);
143bdd1243dSDimitry Andric LocalNumMergedCases -= 1;
1440b57cec5SDimitry Andric }
1450b57cec5SDimitry Andric // Remove incoming values in the reverse order to prevent invalidating
1460b57cec5SDimitry Andric // *successive* index.
1470b57cec5SDimitry Andric for (unsigned III : llvm::reverse(Indices))
1480b57cec5SDimitry Andric PN->removeIncomingValue(III);
1490b57cec5SDimitry Andric }
1500b57cec5SDimitry Andric }
1510b57cec5SDimitry Andric
152e8d8bef9SDimitry Andric /// Create a new leaf block for the binary lookup tree. It checks if the
153e8d8bef9SDimitry Andric /// switch's value == the case's value. If not, then it jumps to the default
154e8d8bef9SDimitry Andric /// branch. At this point in the tree, the value can't be another valid case
155e8d8bef9SDimitry Andric /// value, so the jump to the "default" branch is warranted.
NewLeafBlock(CaseRange & Leaf,Value * Val,ConstantInt * LowerBound,ConstantInt * UpperBound,BasicBlock * OrigBlock,BasicBlock * Default)156e8d8bef9SDimitry Andric BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound,
157e8d8bef9SDimitry Andric ConstantInt *UpperBound, BasicBlock *OrigBlock,
158e8d8bef9SDimitry Andric BasicBlock *Default) {
159e8d8bef9SDimitry Andric Function *F = OrigBlock->getParent();
160e8d8bef9SDimitry Andric BasicBlock *NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock");
161bdd1243dSDimitry Andric F->insert(++OrigBlock->getIterator(), NewLeaf);
162e8d8bef9SDimitry Andric
163e8d8bef9SDimitry Andric // Emit comparison
164e8d8bef9SDimitry Andric ICmpInst *Comp = nullptr;
165e8d8bef9SDimitry Andric if (Leaf.Low == Leaf.High) {
166e8d8bef9SDimitry Andric // Make the seteq instruction...
167e8d8bef9SDimitry Andric Comp =
168*0fca6ea1SDimitry Andric new ICmpInst(NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low, "SwitchLeaf");
169e8d8bef9SDimitry Andric } else {
170e8d8bef9SDimitry Andric // Make range comparison
171e8d8bef9SDimitry Andric if (Leaf.Low == LowerBound) {
172e8d8bef9SDimitry Andric // Val >= Min && Val <= Hi --> Val <= Hi
173*0fca6ea1SDimitry Andric Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High,
174e8d8bef9SDimitry Andric "SwitchLeaf");
175e8d8bef9SDimitry Andric } else if (Leaf.High == UpperBound) {
176e8d8bef9SDimitry Andric // Val <= Max && Val >= Lo --> Val >= Lo
177*0fca6ea1SDimitry Andric Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low,
178e8d8bef9SDimitry Andric "SwitchLeaf");
179e8d8bef9SDimitry Andric } else if (Leaf.Low->isZero()) {
180e8d8bef9SDimitry Andric // Val >= 0 && Val <= Hi --> Val <=u Hi
181*0fca6ea1SDimitry Andric Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High,
182e8d8bef9SDimitry Andric "SwitchLeaf");
183e8d8bef9SDimitry Andric } else {
184e8d8bef9SDimitry Andric // Emit V-Lo <=u Hi-Lo
185e8d8bef9SDimitry Andric Constant *NegLo = ConstantExpr::getNeg(Leaf.Low);
186e8d8bef9SDimitry Andric Instruction *Add = BinaryOperator::CreateAdd(
187e8d8bef9SDimitry Andric Val, NegLo, Val->getName() + ".off", NewLeaf);
188e8d8bef9SDimitry Andric Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High);
189*0fca6ea1SDimitry Andric Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound,
190e8d8bef9SDimitry Andric "SwitchLeaf");
191e8d8bef9SDimitry Andric }
192e8d8bef9SDimitry Andric }
193e8d8bef9SDimitry Andric
194e8d8bef9SDimitry Andric // Make the conditional branch...
195e8d8bef9SDimitry Andric BasicBlock *Succ = Leaf.BB;
196e8d8bef9SDimitry Andric BranchInst::Create(Succ, Default, Comp, NewLeaf);
197e8d8bef9SDimitry Andric
19881ad6265SDimitry Andric // Update the PHI incoming value/block for the default.
19981ad6265SDimitry Andric for (auto &I : Default->phis()) {
20081ad6265SDimitry Andric PHINode *PN = cast<PHINode>(&I);
20181ad6265SDimitry Andric auto *V = PN->getIncomingValueForBlock(OrigBlock);
20281ad6265SDimitry Andric PN->addIncoming(V, NewLeaf);
20381ad6265SDimitry Andric }
20481ad6265SDimitry Andric
205e8d8bef9SDimitry Andric // If there were any PHI nodes in this successor, rewrite one entry
206e8d8bef9SDimitry Andric // from OrigBlock to come from NewLeaf.
207e8d8bef9SDimitry Andric for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
208e8d8bef9SDimitry Andric PHINode *PN = cast<PHINode>(I);
209e8d8bef9SDimitry Andric // Remove all but one incoming entries from the cluster
210bdd1243dSDimitry Andric APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
211*0fca6ea1SDimitry Andric for (APInt j(Range.getBitWidth(), 0, false); j.ult(Range); ++j) {
212e8d8bef9SDimitry Andric PN->removeIncomingValue(OrigBlock);
213e8d8bef9SDimitry Andric }
214e8d8bef9SDimitry Andric
215e8d8bef9SDimitry Andric int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
216e8d8bef9SDimitry Andric assert(BlockIdx != -1 && "Switch didn't go to this successor??");
217e8d8bef9SDimitry Andric PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
218e8d8bef9SDimitry Andric }
219e8d8bef9SDimitry Andric
220e8d8bef9SDimitry Andric return NewLeaf;
221e8d8bef9SDimitry Andric }
222e8d8bef9SDimitry Andric
2230b57cec5SDimitry Andric /// Convert the switch statement into a binary lookup of the case values.
2240b57cec5SDimitry Andric /// The function recursively builds this tree. LowerBound and UpperBound are
2250b57cec5SDimitry Andric /// used to keep track of the bounds for Val that have already been checked by
2260b57cec5SDimitry Andric /// a block emitted by one of the previous calls to switchConvert in the call
2270b57cec5SDimitry Andric /// stack.
SwitchConvert(CaseItr Begin,CaseItr End,ConstantInt * LowerBound,ConstantInt * UpperBound,Value * Val,BasicBlock * Predecessor,BasicBlock * OrigBlock,BasicBlock * Default,const std::vector<IntRange> & UnreachableRanges)228e8d8bef9SDimitry Andric BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
2290b57cec5SDimitry Andric ConstantInt *UpperBound, Value *Val,
2300b57cec5SDimitry Andric BasicBlock *Predecessor, BasicBlock *OrigBlock,
2310b57cec5SDimitry Andric BasicBlock *Default,
2320b57cec5SDimitry Andric const std::vector<IntRange> &UnreachableRanges) {
2330b57cec5SDimitry Andric assert(LowerBound && UpperBound && "Bounds must be initialized");
2340b57cec5SDimitry Andric unsigned Size = End - Begin;
2350b57cec5SDimitry Andric
2360b57cec5SDimitry Andric if (Size == 1) {
2370b57cec5SDimitry Andric // Check if the Case Range is perfectly squeezed in between
2380b57cec5SDimitry Andric // already checked Upper and Lower bounds. If it is then we can avoid
2390b57cec5SDimitry Andric // emitting the code that checks if the value actually falls in the range
2400b57cec5SDimitry Andric // because the bounds already tell us so.
2410b57cec5SDimitry Andric if (Begin->Low == LowerBound && Begin->High == UpperBound) {
242bdd1243dSDimitry Andric APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue();
243e8d8bef9SDimitry Andric FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
2440b57cec5SDimitry Andric return Begin->BB;
2450b57cec5SDimitry Andric }
246e8d8bef9SDimitry Andric return NewLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock,
2470b57cec5SDimitry Andric Default);
2480b57cec5SDimitry Andric }
2490b57cec5SDimitry Andric
2500b57cec5SDimitry Andric unsigned Mid = Size / 2;
2510b57cec5SDimitry Andric std::vector<CaseRange> LHS(Begin, Begin + Mid);
2520b57cec5SDimitry Andric LLVM_DEBUG(dbgs() << "LHS: " << LHS << "\n");
2530b57cec5SDimitry Andric std::vector<CaseRange> RHS(Begin + Mid, End);
2540b57cec5SDimitry Andric LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n");
2550b57cec5SDimitry Andric
2560b57cec5SDimitry Andric CaseRange &Pivot = *(Begin + Mid);
2570b57cec5SDimitry Andric LLVM_DEBUG(dbgs() << "Pivot ==> [" << Pivot.Low->getValue() << ", "
2580b57cec5SDimitry Andric << Pivot.High->getValue() << "]\n");
2590b57cec5SDimitry Andric
2600b57cec5SDimitry Andric // NewLowerBound here should never be the integer minimal value.
2610b57cec5SDimitry Andric // This is because it is computed from a case range that is never
2620b57cec5SDimitry Andric // the smallest, so there is always a case range that has at least
2630b57cec5SDimitry Andric // a smaller value.
2640b57cec5SDimitry Andric ConstantInt *NewLowerBound = Pivot.Low;
2650b57cec5SDimitry Andric
2660b57cec5SDimitry Andric // Because NewLowerBound is never the smallest representable integer
2670b57cec5SDimitry Andric // it is safe here to subtract one.
2680b57cec5SDimitry Andric ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(),
2690b57cec5SDimitry Andric NewLowerBound->getValue() - 1);
2700b57cec5SDimitry Andric
2710b57cec5SDimitry Andric if (!UnreachableRanges.empty()) {
2720b57cec5SDimitry Andric // Check if the gap between LHS's highest and NewLowerBound is unreachable.
273bdd1243dSDimitry Andric APInt GapLow = LHS.back().High->getValue() + 1;
274bdd1243dSDimitry Andric APInt GapHigh = NewLowerBound->getValue() - 1;
2750b57cec5SDimitry Andric IntRange Gap = {GapLow, GapHigh};
276bdd1243dSDimitry Andric if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
2770b57cec5SDimitry Andric NewUpperBound = LHS.back().High;
2780b57cec5SDimitry Andric }
2790b57cec5SDimitry Andric
280bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", "
281bdd1243dSDimitry Andric << NewUpperBound->getValue() << "]\n"
282bdd1243dSDimitry Andric << "RHS Bounds ==> [" << NewLowerBound->getValue() << ", "
283bdd1243dSDimitry Andric << UpperBound->getValue() << "]\n");
2840b57cec5SDimitry Andric
2850b57cec5SDimitry Andric // Create a new node that checks if the value is < pivot. Go to the
2860b57cec5SDimitry Andric // left branch if it is and right branch if not.
2870b57cec5SDimitry Andric Function *F = OrigBlock->getParent();
2880b57cec5SDimitry Andric BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
2890b57cec5SDimitry Andric
290bdd1243dSDimitry Andric ICmpInst *Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot");
2910b57cec5SDimitry Andric
292e8d8bef9SDimitry Andric BasicBlock *LBranch =
293e8d8bef9SDimitry Andric SwitchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val,
294e8d8bef9SDimitry Andric NewNode, OrigBlock, Default, UnreachableRanges);
295e8d8bef9SDimitry Andric BasicBlock *RBranch =
296e8d8bef9SDimitry Andric SwitchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val,
297e8d8bef9SDimitry Andric NewNode, OrigBlock, Default, UnreachableRanges);
2980b57cec5SDimitry Andric
299bdd1243dSDimitry Andric F->insert(++OrigBlock->getIterator(), NewNode);
300bdd1243dSDimitry Andric Comp->insertInto(NewNode, NewNode->end());
3010b57cec5SDimitry Andric
3020b57cec5SDimitry Andric BranchInst::Create(LBranch, RBranch, Comp, NewNode);
3030b57cec5SDimitry Andric return NewNode;
3040b57cec5SDimitry Andric }
3050b57cec5SDimitry Andric
3060b57cec5SDimitry Andric /// Transform simple list of \p SI's cases into list of CaseRange's \p Cases.
3070b57cec5SDimitry Andric /// \post \p Cases wouldn't contain references to \p SI's default BB.
3080b57cec5SDimitry Andric /// \returns Number of \p SI's cases that do not reference \p SI's default BB.
Clusterify(CaseVector & Cases,SwitchInst * SI)309e8d8bef9SDimitry Andric unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) {
3100b57cec5SDimitry Andric unsigned NumSimpleCases = 0;
3110b57cec5SDimitry Andric
3120b57cec5SDimitry Andric // Start with "simple" cases
3130b57cec5SDimitry Andric for (auto Case : SI->cases()) {
3140b57cec5SDimitry Andric if (Case.getCaseSuccessor() == SI->getDefaultDest())
3150b57cec5SDimitry Andric continue;
3160b57cec5SDimitry Andric Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
3170b57cec5SDimitry Andric Case.getCaseSuccessor()));
3180b57cec5SDimitry Andric ++NumSimpleCases;
3190b57cec5SDimitry Andric }
3200b57cec5SDimitry Andric
3210b57cec5SDimitry Andric llvm::sort(Cases, CaseCmp());
3220b57cec5SDimitry Andric
3230b57cec5SDimitry Andric // Merge case into clusters
3240b57cec5SDimitry Andric if (Cases.size() >= 2) {
3250b57cec5SDimitry Andric CaseItr I = Cases.begin();
3260b57cec5SDimitry Andric for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
327bdd1243dSDimitry Andric const APInt &nextValue = J->Low->getValue();
328bdd1243dSDimitry Andric const APInt ¤tValue = I->High->getValue();
3290b57cec5SDimitry Andric BasicBlock *nextBB = J->BB;
3300b57cec5SDimitry Andric BasicBlock *currentBB = I->BB;
3310b57cec5SDimitry Andric
3320b57cec5SDimitry Andric // If the two neighboring cases go to the same destination, merge them
3330b57cec5SDimitry Andric // into a single case.
334bdd1243dSDimitry Andric assert(nextValue.sgt(currentValue) &&
335bdd1243dSDimitry Andric "Cases should be strictly ascending");
3360b57cec5SDimitry Andric if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
3370b57cec5SDimitry Andric I->High = J->High;
3380b57cec5SDimitry Andric // FIXME: Combine branch weights.
3390b57cec5SDimitry Andric } else if (++I != J) {
3400b57cec5SDimitry Andric *I = *J;
3410b57cec5SDimitry Andric }
3420b57cec5SDimitry Andric }
3430b57cec5SDimitry Andric Cases.erase(std::next(I), Cases.end());
3440b57cec5SDimitry Andric }
3450b57cec5SDimitry Andric
3460b57cec5SDimitry Andric return NumSimpleCases;
3470b57cec5SDimitry Andric }
3480b57cec5SDimitry Andric
3490b57cec5SDimitry Andric /// Replace the specified switch instruction with a sequence of chained if-then
3500b57cec5SDimitry Andric /// insts in a balanced binary search.
ProcessSwitchInst(SwitchInst * SI,SmallPtrSetImpl<BasicBlock * > & DeleteList,AssumptionCache * AC,LazyValueInfo * LVI)351e8d8bef9SDimitry Andric void ProcessSwitchInst(SwitchInst *SI,
3520b57cec5SDimitry Andric SmallPtrSetImpl<BasicBlock *> &DeleteList,
3530b57cec5SDimitry Andric AssumptionCache *AC, LazyValueInfo *LVI) {
3540b57cec5SDimitry Andric BasicBlock *OrigBlock = SI->getParent();
3550b57cec5SDimitry Andric Function *F = OrigBlock->getParent();
3560b57cec5SDimitry Andric Value *Val = SI->getCondition(); // The value we are switching on...
3570b57cec5SDimitry Andric BasicBlock *Default = SI->getDefaultDest();
3580b57cec5SDimitry Andric
3590b57cec5SDimitry Andric // Don't handle unreachable blocks. If there are successors with phis, this
3600b57cec5SDimitry Andric // would leave them behind with missing predecessors.
3610b57cec5SDimitry Andric if ((OrigBlock != &F->getEntryBlock() && pred_empty(OrigBlock)) ||
3620b57cec5SDimitry Andric OrigBlock->getSinglePredecessor() == OrigBlock) {
3630b57cec5SDimitry Andric DeleteList.insert(OrigBlock);
3640b57cec5SDimitry Andric return;
3650b57cec5SDimitry Andric }
3660b57cec5SDimitry Andric
3670b57cec5SDimitry Andric // Prepare cases vector.
3680b57cec5SDimitry Andric CaseVector Cases;
3690b57cec5SDimitry Andric const unsigned NumSimpleCases = Clusterify(Cases, SI);
370bdd1243dSDimitry Andric IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
371bdd1243dSDimitry Andric const unsigned BitWidth = IT->getBitWidth();
372*0fca6ea1SDimitry Andric // Explicitly use higher precision to prevent unsigned overflow where
373bdd1243dSDimitry Andric // `UnsignedMax - 0 + 1 == 0`
374bdd1243dSDimitry Andric APInt UnsignedZero(BitWidth + 1, 0);
375bdd1243dSDimitry Andric APInt UnsignedMax = APInt::getMaxValue(BitWidth);
3760b57cec5SDimitry Andric LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
3770b57cec5SDimitry Andric << ". Total non-default cases: " << NumSimpleCases
3780b57cec5SDimitry Andric << "\nCase clusters: " << Cases << "\n");
3790b57cec5SDimitry Andric
3800b57cec5SDimitry Andric // If there is only the default destination, just branch.
3810b57cec5SDimitry Andric if (Cases.empty()) {
3820b57cec5SDimitry Andric BranchInst::Create(Default, OrigBlock);
3830b57cec5SDimitry Andric // Remove all the references from Default's PHIs to OrigBlock, but one.
384bdd1243dSDimitry Andric FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax);
3850b57cec5SDimitry Andric SI->eraseFromParent();
3860b57cec5SDimitry Andric return;
3870b57cec5SDimitry Andric }
3880b57cec5SDimitry Andric
3890b57cec5SDimitry Andric ConstantInt *LowerBound = nullptr;
3900b57cec5SDimitry Andric ConstantInt *UpperBound = nullptr;
3910b57cec5SDimitry Andric bool DefaultIsUnreachableFromSwitch = false;
3920b57cec5SDimitry Andric
3930b57cec5SDimitry Andric if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) {
3940b57cec5SDimitry Andric // Make the bounds tightly fitted around the case value range, because we
3950b57cec5SDimitry Andric // know that the value passed to the switch must be exactly one of the case
3960b57cec5SDimitry Andric // values.
3970b57cec5SDimitry Andric LowerBound = Cases.front().Low;
3980b57cec5SDimitry Andric UpperBound = Cases.back().High;
3990b57cec5SDimitry Andric DefaultIsUnreachableFromSwitch = true;
4000b57cec5SDimitry Andric } else {
4010b57cec5SDimitry Andric // Constraining the range of the value being switched over helps eliminating
4020b57cec5SDimitry Andric // unreachable BBs and minimizing the number of `add` instructions
4030b57cec5SDimitry Andric // newLeafBlock ends up emitting. Running CorrelatedValuePropagation after
4040b57cec5SDimitry Andric // LowerSwitch isn't as good, and also much more expensive in terms of
4050b57cec5SDimitry Andric // compile time for the following reasons:
4060b57cec5SDimitry Andric // 1. it processes many kinds of instructions, not just switches;
4070b57cec5SDimitry Andric // 2. even if limited to icmp instructions only, it will have to process
4080b57cec5SDimitry Andric // roughly C icmp's per switch, where C is the number of cases in the
4090b57cec5SDimitry Andric // switch, while LowerSwitch only needs to call LVI once per switch.
410*0fca6ea1SDimitry Andric const DataLayout &DL = F->getDataLayout();
4110b57cec5SDimitry Andric KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI);
4120b57cec5SDimitry Andric // TODO Shouldn't this create a signed range?
4130b57cec5SDimitry Andric ConstantRange KnownBitsRange =
4140b57cec5SDimitry Andric ConstantRange::fromKnownBits(Known, /*IsSigned=*/false);
4155f757f3fSDimitry Andric const ConstantRange LVIRange =
4165f757f3fSDimitry Andric LVI->getConstantRange(Val, SI, /*UndefAllowed*/ false);
4170b57cec5SDimitry Andric ConstantRange ValRange = KnownBitsRange.intersectWith(LVIRange);
4180b57cec5SDimitry Andric // We delegate removal of unreachable non-default cases to other passes. In
4190b57cec5SDimitry Andric // the unlikely event that some of them survived, we just conservatively
4200b57cec5SDimitry Andric // maintain the invariant that all the cases lie between the bounds. This
4210b57cec5SDimitry Andric // may, however, still render the default case effectively unreachable.
422bdd1243dSDimitry Andric const APInt &Low = Cases.front().Low->getValue();
423bdd1243dSDimitry Andric const APInt &High = Cases.back().High->getValue();
4240b57cec5SDimitry Andric APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low);
4250b57cec5SDimitry Andric APInt Max = APIntOps::smax(ValRange.getSignedMax(), High);
4260b57cec5SDimitry Andric
4270b57cec5SDimitry Andric LowerBound = ConstantInt::get(SI->getContext(), Min);
4280b57cec5SDimitry Andric UpperBound = ConstantInt::get(SI->getContext(), Max);
4290b57cec5SDimitry Andric DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max);
4300b57cec5SDimitry Andric }
4310b57cec5SDimitry Andric
4320b57cec5SDimitry Andric std::vector<IntRange> UnreachableRanges;
4330b57cec5SDimitry Andric
4340b57cec5SDimitry Andric if (DefaultIsUnreachableFromSwitch) {
435bdd1243dSDimitry Andric DenseMap<BasicBlock *, APInt> Popularity;
436bdd1243dSDimitry Andric APInt MaxPop(UnsignedZero);
4370b57cec5SDimitry Andric BasicBlock *PopSucc = nullptr;
4380b57cec5SDimitry Andric
439bdd1243dSDimitry Andric APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
440bdd1243dSDimitry Andric APInt SignedMin = APInt::getSignedMinValue(BitWidth);
441bdd1243dSDimitry Andric IntRange R = {SignedMin, SignedMax};
4420b57cec5SDimitry Andric UnreachableRanges.push_back(R);
4430b57cec5SDimitry Andric for (const auto &I : Cases) {
444bdd1243dSDimitry Andric const APInt &Low = I.Low->getValue();
445bdd1243dSDimitry Andric const APInt &High = I.High->getValue();
4460b57cec5SDimitry Andric
4470b57cec5SDimitry Andric IntRange &LastRange = UnreachableRanges.back();
448bdd1243dSDimitry Andric if (LastRange.Low.eq(Low)) {
4490b57cec5SDimitry Andric // There is nothing left of the previous range.
4500b57cec5SDimitry Andric UnreachableRanges.pop_back();
4510b57cec5SDimitry Andric } else {
4520b57cec5SDimitry Andric // Terminate the previous range.
453bdd1243dSDimitry Andric assert(Low.sgt(LastRange.Low));
4540b57cec5SDimitry Andric LastRange.High = Low - 1;
4550b57cec5SDimitry Andric }
456bdd1243dSDimitry Andric if (High.ne(SignedMax)) {
457bdd1243dSDimitry Andric IntRange R = {High + 1, SignedMax};
4580b57cec5SDimitry Andric UnreachableRanges.push_back(R);
4590b57cec5SDimitry Andric }
4600b57cec5SDimitry Andric
4610b57cec5SDimitry Andric // Count popularity.
462bdd1243dSDimitry Andric assert(High.sge(Low) && "Popularity shouldn't be negative.");
463bdd1243dSDimitry Andric APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1;
      // Explicit insert to make sure the bit widths of the APInts match.
465bdd1243dSDimitry Andric APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second;
466bdd1243dSDimitry Andric if ((Pop += N).ugt(MaxPop)) {
4670b57cec5SDimitry Andric MaxPop = Pop;
4680b57cec5SDimitry Andric PopSucc = I.BB;
4690b57cec5SDimitry Andric }
4700b57cec5SDimitry Andric }
4710b57cec5SDimitry Andric #ifndef NDEBUG
4720b57cec5SDimitry Andric /* UnreachableRanges should be sorted and the ranges non-adjacent. */
4730b57cec5SDimitry Andric for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
4740b57cec5SDimitry Andric I != E; ++I) {
475bdd1243dSDimitry Andric assert(I->Low.sle(I->High));
4760b57cec5SDimitry Andric auto Next = I + 1;
4770b57cec5SDimitry Andric if (Next != E) {
478bdd1243dSDimitry Andric assert(Next->Low.sgt(I->High));
4790b57cec5SDimitry Andric }
4800b57cec5SDimitry Andric }
4810b57cec5SDimitry Andric #endif
4820b57cec5SDimitry Andric
4830b57cec5SDimitry Andric // As the default block in the switch is unreachable, update the PHI nodes
4840b57cec5SDimitry Andric // (remove all of the references to the default block) to reflect this.
4850b57cec5SDimitry Andric const unsigned NumDefaultEdges = SI->getNumCases() + 1 - NumSimpleCases;
4860b57cec5SDimitry Andric for (unsigned I = 0; I < NumDefaultEdges; ++I)
4870b57cec5SDimitry Andric Default->removePredecessor(OrigBlock);
4880b57cec5SDimitry Andric
4890b57cec5SDimitry Andric // Use the most popular block as the new default, reducing the number of
4900b57cec5SDimitry Andric // cases.
4910b57cec5SDimitry Andric Default = PopSucc;
492e8d8bef9SDimitry Andric llvm::erase_if(Cases,
493e8d8bef9SDimitry Andric [PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
4940b57cec5SDimitry Andric
4950b57cec5SDimitry Andric // If there are no cases left, just branch.
4960b57cec5SDimitry Andric if (Cases.empty()) {
4970b57cec5SDimitry Andric BranchInst::Create(Default, OrigBlock);
4980b57cec5SDimitry Andric SI->eraseFromParent();
4990b57cec5SDimitry Andric // As all the cases have been replaced with a single branch, only keep
5000b57cec5SDimitry Andric // one entry in the PHI nodes.
501bdd1243dSDimitry Andric if (!MaxPop.isZero())
502bdd1243dSDimitry Andric for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I)
5030b57cec5SDimitry Andric PopSucc->removePredecessor(OrigBlock);
5040b57cec5SDimitry Andric return;
5050b57cec5SDimitry Andric }
5060b57cec5SDimitry Andric
5070b57cec5SDimitry Andric // If the condition was a PHI node with the switch block as a predecessor
5080b57cec5SDimitry Andric // removing predecessors may have caused the condition to be erased.
5090b57cec5SDimitry Andric // Getting the condition value again here protects against that.
5100b57cec5SDimitry Andric Val = SI->getCondition();
5110b57cec5SDimitry Andric }
5120b57cec5SDimitry Andric
5130b57cec5SDimitry Andric BasicBlock *SwitchBlock =
514e8d8bef9SDimitry Andric SwitchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
51581ad6265SDimitry Andric OrigBlock, OrigBlock, Default, UnreachableRanges);
5160b57cec5SDimitry Andric
51781ad6265SDimitry Andric // We have added incoming values for newly-created predecessors in
51881ad6265SDimitry Andric // NewLeafBlock(). The only meaningful work we offload to FixPhis() is to
51981ad6265SDimitry Andric // remove the incoming values from OrigBlock. There might be a special case
52081ad6265SDimitry Andric // that SwitchBlock is the same as Default, under which the PHIs in Default
52181ad6265SDimitry Andric // are fixed inside SwitchConvert().
52281ad6265SDimitry Andric if (SwitchBlock != Default)
523bdd1243dSDimitry Andric FixPhis(Default, OrigBlock, nullptr, UnsignedMax);
5240b57cec5SDimitry Andric
5250b57cec5SDimitry Andric // Branch to our shiny new if-then stuff...
5260b57cec5SDimitry Andric BranchInst::Create(SwitchBlock, OrigBlock);
5270b57cec5SDimitry Andric
5280b57cec5SDimitry Andric // We are now done with the switch instruction, delete it.
5290b57cec5SDimitry Andric BasicBlock *OldDefault = SI->getDefaultDest();
530bdd1243dSDimitry Andric SI->eraseFromParent();
5310b57cec5SDimitry Andric
5320b57cec5SDimitry Andric // If the Default block has no more predecessors just add it to DeleteList.
533e8d8bef9SDimitry Andric if (pred_empty(OldDefault))
5340b57cec5SDimitry Andric DeleteList.insert(OldDefault);
5350b57cec5SDimitry Andric }
536e8d8bef9SDimitry Andric
LowerSwitch(Function & F,LazyValueInfo * LVI,AssumptionCache * AC)537e8d8bef9SDimitry Andric bool LowerSwitch(Function &F, LazyValueInfo *LVI, AssumptionCache *AC) {
538e8d8bef9SDimitry Andric bool Changed = false;
539e8d8bef9SDimitry Andric SmallPtrSet<BasicBlock *, 8> DeleteList;
540e8d8bef9SDimitry Andric
541349cc55cSDimitry Andric // We use make_early_inc_range here so that we don't traverse new blocks.
542349cc55cSDimitry Andric for (BasicBlock &Cur : llvm::make_early_inc_range(F)) {
543e8d8bef9SDimitry Andric // If the block is a dead Default block that will be deleted later, don't
544e8d8bef9SDimitry Andric // waste time processing it.
545349cc55cSDimitry Andric if (DeleteList.count(&Cur))
546e8d8bef9SDimitry Andric continue;
547e8d8bef9SDimitry Andric
548349cc55cSDimitry Andric if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur.getTerminator())) {
549e8d8bef9SDimitry Andric Changed = true;
550e8d8bef9SDimitry Andric ProcessSwitchInst(SI, DeleteList, AC, LVI);
551e8d8bef9SDimitry Andric }
552e8d8bef9SDimitry Andric }
553e8d8bef9SDimitry Andric
554e8d8bef9SDimitry Andric for (BasicBlock *BB : DeleteList) {
555e8d8bef9SDimitry Andric LVI->eraseBlock(BB);
556e8d8bef9SDimitry Andric DeleteDeadBlock(BB);
557e8d8bef9SDimitry Andric }
558e8d8bef9SDimitry Andric
559e8d8bef9SDimitry Andric return Changed;
560e8d8bef9SDimitry Andric }
561e8d8bef9SDimitry Andric
562e8d8bef9SDimitry Andric /// Replace all SwitchInst instructions with chained branch instructions.
563e8d8bef9SDimitry Andric class LowerSwitchLegacyPass : public FunctionPass {
564e8d8bef9SDimitry Andric public:
565e8d8bef9SDimitry Andric // Pass identification, replacement for typeid
566e8d8bef9SDimitry Andric static char ID;
567e8d8bef9SDimitry Andric
LowerSwitchLegacyPass()568e8d8bef9SDimitry Andric LowerSwitchLegacyPass() : FunctionPass(ID) {
569e8d8bef9SDimitry Andric initializeLowerSwitchLegacyPassPass(*PassRegistry::getPassRegistry());
570e8d8bef9SDimitry Andric }
571e8d8bef9SDimitry Andric
572e8d8bef9SDimitry Andric bool runOnFunction(Function &F) override;
573e8d8bef9SDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const574e8d8bef9SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
575e8d8bef9SDimitry Andric AU.addRequired<LazyValueInfoWrapperPass>();
576e8d8bef9SDimitry Andric }
577e8d8bef9SDimitry Andric };
578e8d8bef9SDimitry Andric
579e8d8bef9SDimitry Andric } // end anonymous namespace
580e8d8bef9SDimitry Andric
581e8d8bef9SDimitry Andric char LowerSwitchLegacyPass::ID = 0;
582e8d8bef9SDimitry Andric
583e8d8bef9SDimitry Andric // Publicly exposed interface to pass...
584e8d8bef9SDimitry Andric char &llvm::LowerSwitchID = LowerSwitchLegacyPass::ID;
585e8d8bef9SDimitry Andric
586e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(LowerSwitchLegacyPass, "lowerswitch",
587e8d8bef9SDimitry Andric "Lower SwitchInst's to branches", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)588e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
589e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
590e8d8bef9SDimitry Andric INITIALIZE_PASS_END(LowerSwitchLegacyPass, "lowerswitch",
591e8d8bef9SDimitry Andric "Lower SwitchInst's to branches", false, false)
592e8d8bef9SDimitry Andric
593e8d8bef9SDimitry Andric // createLowerSwitchPass - Interface to this file...
594e8d8bef9SDimitry Andric FunctionPass *llvm::createLowerSwitchPass() {
595e8d8bef9SDimitry Andric return new LowerSwitchLegacyPass();
596e8d8bef9SDimitry Andric }
597e8d8bef9SDimitry Andric
runOnFunction(Function & F)598e8d8bef9SDimitry Andric bool LowerSwitchLegacyPass::runOnFunction(Function &F) {
599e8d8bef9SDimitry Andric LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
600e8d8bef9SDimitry Andric auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>();
601e8d8bef9SDimitry Andric AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr;
602e8d8bef9SDimitry Andric return LowerSwitch(F, LVI, AC);
603e8d8bef9SDimitry Andric }
604e8d8bef9SDimitry Andric
run(Function & F,FunctionAnalysisManager & AM)605e8d8bef9SDimitry Andric PreservedAnalyses LowerSwitchPass::run(Function &F,
606e8d8bef9SDimitry Andric FunctionAnalysisManager &AM) {
607e8d8bef9SDimitry Andric LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
608e8d8bef9SDimitry Andric AssumptionCache *AC = AM.getCachedResult<AssumptionAnalysis>(F);
609e8d8bef9SDimitry Andric return LowerSwitch(F, LVI, AC) ? PreservedAnalyses::none()
610e8d8bef9SDimitry Andric : PreservedAnalyses::all();
611e8d8bef9SDimitry Andric }
612