//===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PTX does not have a notion of `unreachable`, which results in emitted basic
// blocks having an edge to the next block:
//
//   block1:
//     call @does_not_return();
//     // unreachable
//   block2:
//     // ptxas will create a CFG edge from block1 to block2
//
// This may result in significant changes to the control flow graph, e.g., when
// LLVM moves unreachable blocks to the end of the function. That's a problem
// in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
// divergently.
//
// For example, `bar.sync` is not allowed to be executed divergently on Pascal
// or earlier. If we start with the following:
//
//   entry:
//     // start of divergent region
//     @%p0 bra cont;
//     @%p1 bra unlikely;
//     ...
//     bra.uni cont;
//   unlikely:
//     ...
//     // unreachable
//   cont:
//     // end of divergent region
//     bar.sync 0;
//     bra.uni exit;
//   exit:
//     ret;
//
// it is transformed by the branch-folder and block-placement passes to:
//
//   entry:
//     // start of divergent region
//     @%p0 bra cont;
//     @%p1 bra unlikely;
//     ...
//     bra.uni cont;
//   cont:
//     bar.sync 0;
//     bra.uni exit;
//   unlikely:
//     ...
//     // unreachable
//   exit:
//     // end of divergent region
//     ret;
//
// After moving the `unlikely` block to the end of the function, it has an edge
// to the `exit` block, which widens the divergent region and makes the
// `bar.sync` instruction happen divergently.
64 // 65 // To work around this, we add an `exit` instruction before every `unreachable`, 66 // as `ptxas` understands that exit terminates the CFG. We do only do this if 67 // `unreachable` is not lowered to `trap`, which has the same effect (although 68 // with current versions of `ptxas` only because it is emited as `trap; exit;`). 69 // 70 //===----------------------------------------------------------------------===// 71 72 #include "NVPTX.h" 73 #include "llvm/IR/Function.h" 74 #include "llvm/IR/InlineAsm.h" 75 #include "llvm/IR/Instructions.h" 76 #include "llvm/IR/Type.h" 77 #include "llvm/Pass.h" 78 79 using namespace llvm; 80 81 namespace llvm { 82 void initializeNVPTXLowerUnreachablePass(PassRegistry &); 83 } 84 85 namespace { 86 class NVPTXLowerUnreachable : public FunctionPass { 87 StringRef getPassName() const override; 88 bool runOnFunction(Function &F) override; 89 bool isLoweredToTrap(const UnreachableInst &I) const; 90 91 public: 92 static char ID; // Pass identification, replacement for typeid 93 NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn) 94 : FunctionPass(ID), TrapUnreachable(TrapUnreachable), 95 NoTrapAfterNoreturn(NoTrapAfterNoreturn) {} 96 97 private: 98 bool TrapUnreachable; 99 bool NoTrapAfterNoreturn; 100 }; 101 } // namespace 102 103 char NVPTXLowerUnreachable::ID = 1; 104 105 INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable", 106 "Lower Unreachable", false, false) 107 108 StringRef NVPTXLowerUnreachable::getPassName() const { 109 return "add an exit instruction before every unreachable"; 110 } 111 112 // ============================================================================= 113 // Returns whether a `trap` intrinsic should be emitted before I. 114 // 115 // This is a copy of the logic in SelectionDAGBuilder::visitUnreachable(). 
116 // ============================================================================= 117 bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const { 118 if (!TrapUnreachable) 119 return false; 120 if (!NoTrapAfterNoreturn) 121 return true; 122 const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode()); 123 return Call && Call->doesNotReturn(); 124 } 125 126 // ============================================================================= 127 // Main function for this pass. 128 // ============================================================================= 129 bool NVPTXLowerUnreachable::runOnFunction(Function &F) { 130 if (skipFunction(F)) 131 return false; 132 // Early out iff isLoweredToTrap() always returns true. 133 if (TrapUnreachable && !NoTrapAfterNoreturn) 134 return false; 135 136 LLVMContext &C = F.getContext(); 137 FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false); 138 InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true); 139 140 bool Changed = false; 141 for (auto &BB : F) 142 for (auto &I : BB) { 143 if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) { 144 if (isLoweredToTrap(*unreachableInst)) 145 continue; // trap is emitted as `trap; exit;`. 146 CallInst::Create(ExitFTy, Exit, "", unreachableInst->getIterator()); 147 Changed = true; 148 } 149 } 150 return Changed; 151 } 152 153 FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable, 154 bool NoTrapAfterNoreturn) { 155 return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn); 156 } 157