1 //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass is used to ensure that functions have at most one return 10 // instruction in them. Additionally, it keeps track of which node is the new 11 // exit node of the CFG. If there are no exit nodes in the CFG, the getExitNode 12 // method will return a null pointer. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 17 #include "llvm/IR/BasicBlock.h" 18 #include "llvm/IR/Function.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/Type.h" 21 #include "llvm/Transforms/Utils.h" 22 using namespace llvm; 23 24 char UnifyFunctionExitNodes::ID = 0; 25 INITIALIZE_PASS(UnifyFunctionExitNodes, "mergereturn", 26 "Unify function exit nodes", false, false) 27 28 Pass *llvm::createUnifyFunctionExitNodesPass() { 29 return new UnifyFunctionExitNodes(); 30 } 31 32 void UnifyFunctionExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{ 33 // We preserve the non-critical-edgeness property 34 AU.addPreservedID(BreakCriticalEdgesID); 35 // This is a cluster of orthogonal Transforms 36 AU.addPreservedID(LowerSwitchID); 37 } 38 39 // UnifyAllExitNodes - Unify all exit nodes of the CFG by creating a new 40 // BasicBlock, and converting all returns to unconditional branches to this 41 // new basic block. The singular exit node is returned. 42 // 43 // If there are no return stmts in the Function, a null pointer is returned. 44 // 45 bool UnifyFunctionExitNodes::runOnFunction(Function &F) { 46 // Loop over all of the blocks in a function, tracking all of the blocks that 47 // return. 48 // 49 std::vector<BasicBlock*> ReturningBlocks; 50 std::vector<BasicBlock*> UnreachableBlocks; 51 for (BasicBlock &I : F) 52 if (isa<ReturnInst>(I.getTerminator())) 53 ReturningBlocks.push_back(&I); 54 else if (isa<UnreachableInst>(I.getTerminator())) 55 UnreachableBlocks.push_back(&I); 56 57 // Then unreachable blocks. 58 if (UnreachableBlocks.empty()) { 59 UnreachableBlock = nullptr; 60 } else if (UnreachableBlocks.size() == 1) { 61 UnreachableBlock = UnreachableBlocks.front(); 62 } else { 63 UnreachableBlock = BasicBlock::Create(F.getContext(), 64 "UnifiedUnreachableBlock", &F); 65 new UnreachableInst(F.getContext(), UnreachableBlock); 66 67 for (BasicBlock *BB : UnreachableBlocks) { 68 BB->getInstList().pop_back(); // Remove the unreachable inst. 69 BranchInst::Create(UnreachableBlock, BB); 70 } 71 } 72 73 // Now handle return blocks. 74 if (ReturningBlocks.empty()) { 75 ReturnBlock = nullptr; 76 return false; // No blocks return 77 } else if (ReturningBlocks.size() == 1) { 78 ReturnBlock = ReturningBlocks.front(); // Already has a single return block 79 return false; 80 } 81 82 // Otherwise, we need to insert a new basic block into the function, add a PHI 83 // nodes (if the function returns values), and convert all of the return 84 // instructions into unconditional branches. 85 // 86 BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), 87 "UnifiedReturnBlock", &F); 88 89 PHINode *PN = nullptr; 90 if (F.getReturnType()->isVoidTy()) { 91 ReturnInst::Create(F.getContext(), nullptr, NewRetBlock); 92 } else { 93 // If the function doesn't return void... add a PHI node to the block... 94 PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(), 95 "UnifiedRetVal"); 96 NewRetBlock->getInstList().push_back(PN); 97 ReturnInst::Create(F.getContext(), PN, NewRetBlock); 98 } 99 100 // Loop over all of the blocks, replacing the return instruction with an 101 // unconditional branch. 102 // 103 for (BasicBlock *BB : ReturningBlocks) { 104 // Add an incoming element to the PHI node for every return instruction that 105 // is merging into this new block... 106 if (PN) 107 PN->addIncoming(BB->getTerminator()->getOperand(0), BB); 108 109 BB->getInstList().pop_back(); // Remove the return insn 110 BranchInst::Create(NewRetBlock, BB); 111 } 112 ReturnBlock = NewRetBlock; 113 return true; 114 } 115