xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp (revision 05427f4639bcf2703329a9be9d25ec09bb782742)
1 //===- JumpTableToSwitch.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/ConstantFolding.h"
12 #include "llvm/Analysis/DomTreeUpdater.h"
13 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
14 #include "llvm/Analysis/PostDominators.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
18 
19 using namespace llvm;
20 
21 static cl::opt<unsigned>
22     JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
23                            cl::desc("Only split jump tables with size less or "
24                                     "equal than JumpTableSizeThreshold."),
25                            cl::init(10));
26 
27 // TODO: Consider adding a cost model for profitability analysis of this
28 // transformation. Currently we replace a jump table with a switch if all the
29 // functions in the jump table are smaller than the provided threshold.
30 static cl::opt<unsigned> FunctionSizeThreshold(
31     "jump-table-to-switch-function-size-threshold", cl::Hidden,
32     cl::desc("Only split jump tables containing functions whose sizes are less "
33              "or equal than this threshold."),
34     cl::init(50));
35 
36 #define DEBUG_TYPE "jump-table-to-switch"
37 
38 namespace {
39 struct JumpTableTy {
40   Value *Index;
41   SmallVector<Function *, 10> Funcs;
42 };
43 } // anonymous namespace
44 
45 static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
46                                                  PointerType *PtrTy) {
47   Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
48   if (!Ptr)
49     return std::nullopt;
50 
51   GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
52   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
53     return std::nullopt;
54 
55   Function &F = *GEP->getParent()->getParent();
56   const DataLayout &DL = F.getDataLayout();
57   const unsigned BitWidth =
58       DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
59   MapVector<Value *, APInt> VariableOffsets;
60   APInt ConstantOffset(BitWidth, 0);
61   if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
62     return std::nullopt;
63   if (VariableOffsets.size() != 1)
64     return std::nullopt;
65   // TODO: consider supporting more general patterns
66   if (!ConstantOffset.isZero())
67     return std::nullopt;
68   APInt StrideBytes = VariableOffsets.front().second;
69   const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
70   if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
71     return std::nullopt;
72   const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
73   if (N > JumpTableSizeThreshold)
74     return std::nullopt;
75 
76   JumpTableTy JumpTable;
77   JumpTable.Index = VariableOffsets.front().first;
78   JumpTable.Funcs.reserve(N);
79   for (uint64_t Index = 0; Index < N; ++Index) {
80     // ConstantOffset is zero.
81     APInt Offset = Index * StrideBytes;
82     Constant *C =
83         ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
84     auto *Func = dyn_cast_or_null<Function>(C);
85     if (!Func || Func->isDeclaration() ||
86         Func->getInstructionCount() > FunctionSizeThreshold)
87       return std::nullopt;
88     JumpTable.Funcs.push_back(Func);
89   }
90   return JumpTable;
91 }
92 
93 static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94                                   DomTreeUpdater &DTU,
95                                   OptimizationRemarkEmitter &ORE) {
96   const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97 
98   SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
99   BasicBlock *BB = CB->getParent();
100   BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
101                                 BB->getName() + Twine(".tail"));
102   DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
103   BB->getTerminator()->eraseFromParent();
104 
105   Function &F = *BB->getParent();
106   BasicBlock *BBUnreachable = BasicBlock::Create(
107       F.getContext(), "default.switch.case.unreachable", &F, Tail);
108   IRBuilder<> BuilderUnreachable(BBUnreachable);
109   BuilderUnreachable.CreateUnreachable();
110 
111   IRBuilder<> Builder(BB);
112   SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
113   DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
114 
115   IRBuilder<> BuilderTail(CB);
116   PHINode *PHI =
117       IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118 
119   for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120     BasicBlock *B = BasicBlock::Create(Func->getContext(),
121                                        "call." + Twine(Index), &F, Tail);
122     DTUpdates.push_back({DominatorTree::Insert, BB, B});
123     DTUpdates.push_back({DominatorTree::Insert, B, Tail});
124 
125     CallBase *Call = cast<CallBase>(CB->clone());
126     Call->setCalledFunction(Func);
127     Call->insertInto(B, B->end());
128     Switch->addCase(
129         cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
130     BranchInst::Create(Tail, B);
131     if (PHI)
132       PHI->addIncoming(Call, B);
133   }
134   DTU.applyUpdates(DTUpdates);
135   ORE.emit([&]() {
136     return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137            << "expanded indirect call into switch";
138   });
139   if (PHI)
140     CB->replaceAllUsesWith(PHI);
141   CB->eraseFromParent();
142   return Tail;
143 }
144 
145 PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
146                                              FunctionAnalysisManager &AM) {
147   OptimizationRemarkEmitter &ORE =
148       AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
149   DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
150   PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152   bool Changed = false;
153   for (BasicBlock &BB : make_early_inc_range(F)) {
154     BasicBlock *CurrentBB = &BB;
155     while (CurrentBB) {
156       BasicBlock *SplittedOutTail = nullptr;
157       for (Instruction &I : make_early_inc_range(*CurrentBB)) {
158         auto *Call = dyn_cast<CallInst>(&I);
159         if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
160           continue;
161         auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
162         // Skip atomic or volatile loads.
163         if (!L || !L->isSimple())
164           continue;
165         auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
166         if (!GEP)
167           continue;
168         auto *PtrTy = dyn_cast<PointerType>(L->getType());
169         assert(PtrTy && "call operand must be a pointer");
170         std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171         if (!JumpTable)
172           continue;
173         SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
174         Changed = true;
175         break;
176       }
177       CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
178     }
179   }
180 
181   if (!Changed)
182     return PreservedAnalyses::all();
183 
184   PreservedAnalyses PA;
185   if (DT)
186     PA.preserve<DominatorTreeAnalysis>();
187   if (PDT)
188     PA.preserve<PostDominatorTreeAnalysis>();
189   return PA;
190 }
191