1 //===- JumpTableToSwitch.cpp ----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/JumpTableToSwitch.h" 10 #include "llvm/ADT/SmallVector.h" 11 #include "llvm/Analysis/ConstantFolding.h" 12 #include "llvm/Analysis/DomTreeUpdater.h" 13 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 14 #include "llvm/Analysis/PostDominators.h" 15 #include "llvm/IR/IRBuilder.h" 16 #include "llvm/Support/CommandLine.h" 17 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 18 19 using namespace llvm; 20 21 static cl::opt<unsigned> 22 JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, 23 cl::desc("Only split jump tables with size less or " 24 "equal than JumpTableSizeThreshold."), 25 cl::init(10)); 26 27 // TODO: Consider adding a cost model for profitability analysis of this 28 // transformation. Currently we replace a jump table with a switch if all the 29 // functions in the jump table are smaller than the provided threshold. 30 static cl::opt<unsigned> FunctionSizeThreshold( 31 "jump-table-to-switch-function-size-threshold", cl::Hidden, 32 cl::desc("Only split jump tables containing functions whose sizes are less " 33 "or equal than this threshold."), 34 cl::init(50)); 35 36 #define DEBUG_TYPE "jump-table-to-switch" 37 38 namespace { 39 struct JumpTableTy { 40 Value *Index; 41 SmallVector<Function *, 10> Funcs; 42 }; 43 } // anonymous namespace 44 45 static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP, 46 PointerType *PtrTy) { 47 Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand()); 48 if (!Ptr) 49 return std::nullopt; 50 51 GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr); 52 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) 53 return std::nullopt; 54 55 Function &F = *GEP->getParent()->getParent(); 56 const DataLayout &DL = F.getDataLayout(); 57 const unsigned BitWidth = 58 DL.getIndexSizeInBits(GEP->getPointerAddressSpace()); 59 MapVector<Value *, APInt> VariableOffsets; 60 APInt ConstantOffset(BitWidth, 0); 61 if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) 62 return std::nullopt; 63 if (VariableOffsets.size() != 1) 64 return std::nullopt; 65 // TODO: consider supporting more general patterns 66 if (!ConstantOffset.isZero()) 67 return std::nullopt; 68 APInt StrideBytes = VariableOffsets.front().second; 69 const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType()); 70 if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0) 71 return std::nullopt; 72 const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue(); 73 if (N > JumpTableSizeThreshold) 74 return std::nullopt; 75 76 JumpTableTy JumpTable; 77 JumpTable.Index = VariableOffsets.front().first; 78 JumpTable.Funcs.reserve(N); 79 for (uint64_t Index = 0; Index < N; ++Index) { 80 // ConstantOffset is zero. 81 APInt Offset = Index * StrideBytes; 82 Constant *C = 83 ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL); 84 auto *Func = dyn_cast_or_null<Function>(C); 85 if (!Func || Func->isDeclaration() || 86 Func->getInstructionCount() > FunctionSizeThreshold) 87 return std::nullopt; 88 JumpTable.Funcs.push_back(Func); 89 } 90 return JumpTable; 91 } 92 93 static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, 94 DomTreeUpdater &DTU, 95 OptimizationRemarkEmitter &ORE) { 96 const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext()); 97 98 SmallVector<DominatorTree::UpdateType, 8> DTUpdates; 99 BasicBlock *BB = CB->getParent(); 100 BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr, 101 BB->getName() + Twine(".tail")); 102 DTUpdates.push_back({DominatorTree::Delete, BB, Tail}); 103 BB->getTerminator()->eraseFromParent(); 104 105 Function &F = *BB->getParent(); 106 BasicBlock *BBUnreachable = BasicBlock::Create( 107 F.getContext(), "default.switch.case.unreachable", &F, Tail); 108 IRBuilder<> BuilderUnreachable(BBUnreachable); 109 BuilderUnreachable.CreateUnreachable(); 110 111 IRBuilder<> Builder(BB); 112 SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable); 113 DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable}); 114 115 IRBuilder<> BuilderTail(CB); 116 PHINode *PHI = 117 IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size()); 118 119 for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) { 120 BasicBlock *B = BasicBlock::Create(Func->getContext(), 121 "call." + Twine(Index), &F, Tail); 122 DTUpdates.push_back({DominatorTree::Insert, BB, B}); 123 DTUpdates.push_back({DominatorTree::Insert, B, Tail}); 124 125 CallBase *Call = cast<CallBase>(CB->clone()); 126 Call->setCalledFunction(Func); 127 Call->insertInto(B, B->end()); 128 Switch->addCase( 129 cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B); 130 BranchInst::Create(Tail, B); 131 if (PHI) 132 PHI->addIncoming(Call, B); 133 } 134 DTU.applyUpdates(DTUpdates); 135 ORE.emit([&]() { 136 return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB) 137 << "expanded indirect call into switch"; 138 }); 139 if (PHI) 140 CB->replaceAllUsesWith(PHI); 141 CB->eraseFromParent(); 142 return Tail; 143 } 144 145 PreservedAnalyses JumpTableToSwitchPass::run(Function &F, 146 FunctionAnalysisManager &AM) { 147 OptimizationRemarkEmitter &ORE = 148 AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 149 DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); 150 PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F); 151 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); 152 bool Changed = false; 153 for (BasicBlock &BB : make_early_inc_range(F)) { 154 BasicBlock *CurrentBB = &BB; 155 while (CurrentBB) { 156 BasicBlock *SplittedOutTail = nullptr; 157 for (Instruction &I : make_early_inc_range(*CurrentBB)) { 158 auto *Call = dyn_cast<CallInst>(&I); 159 if (!Call || Call->getCalledFunction() || Call->isMustTailCall()) 160 continue; 161 auto *L = dyn_cast<LoadInst>(Call->getCalledOperand()); 162 // Skip atomic or volatile loads. 163 if (!L || !L->isSimple()) 164 continue; 165 auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand()); 166 if (!GEP) 167 continue; 168 auto *PtrTy = dyn_cast<PointerType>(L->getType()); 169 assert(PtrTy && "call operand must be a pointer"); 170 std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy); 171 if (!JumpTable) 172 continue; 173 SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE); 174 Changed = true; 175 break; 176 } 177 CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr; 178 } 179 } 180 181 if (!Changed) 182 return PreservedAnalyses::all(); 183 184 PreservedAnalyses PA; 185 if (DT) 186 PA.preserve<DominatorTreeAnalysis>(); 187 if (PDT) 188 PA.preserve<PostDominatorTreeAnalysis>(); 189 return PA; 190 } 191