//===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass identifies and eliminates redundant TLS loads when the related
// option (or function attribute) is set. Uses of a thread-local global
// variable inside a function are redirected to a single no-op bitcast of that
// variable, inserted at a point that dominates those uses and lies outside any
// loop containing them.
// For an example, please refer to the comment at the head of
// TLSVariableHoist.h.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/TLSVariableHoist.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <utility>

using namespace llvm;
using namespace tlshoist;

#define DEBUG_TYPE "tlshoist"

// Enables the transformation for every function. It is off by default;
// runImpl additionally enables it for individual functions that carry the
// "tls-load-hoist" function attribute.
static cl::opt<bool> TLSLoadHoist(
    "tls-load-hoist", cl::init(false), cl::Hidden,
    cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
             "TLS address calculation."));

namespace {

/// The TLS Variable hoist pass.
/// Legacy pass manager wrapper: it owns a TLSVariableHoistPass (Impl) and
/// forwards each function, together with its DominatorTree and LoopInfo, to
/// Impl.runImpl.
class TLSVariableHoistLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  TLSVariableHoistLegacyPass() : FunctionPass(ID) {
    initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &Fn) override;

  StringRef getPassName() const override { return "TLS Variable Hoist"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass only inserts an instruction and rewrites operands; it never
    // adds, removes or rewires basic blocks.
    AU.setPreservesCFG();
    // Dominance and loop structure drive the choice of insertion point.
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

private:
  // Pass-manager-independent implementation shared with the new-PM pass.
  TLSVariableHoistPass Impl;
};

} // end anonymous namespace

char TLSVariableHoistLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)

/// Factory used by the legacy pass pipeline; the caller takes ownership.
FunctionPass *llvm::createTLSVariableHoistPass() {
  return new TLSVariableHoistLegacyPass();
}

/// Perform the TLS Variable Hoist optimization for the given function.
89 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) { 90 if (skipFunction(Fn)) 91 return false; 92 93 LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n"); 94 LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); 95 96 bool MadeChange = 97 Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), 98 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()); 99 100 if (MadeChange) { 101 LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: " 102 << Fn.getName() << '\n'); 103 LLVM_DEBUG(dbgs() << Fn); 104 } 105 LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n"); 106 107 return MadeChange; 108 } 109 110 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) { 111 // Skip all cast instructions. They are visited indirectly later on. 112 if (Inst->isCast()) 113 return; 114 115 // Scan all operands. 116 for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { 117 auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx)); 118 if (!GV || !GV->isThreadLocal()) 119 continue; 120 121 // Add Candidate to TLSCandMap (GV --> Candidate). 122 TLSCandMap[GV].addUser(Inst, Idx); 123 } 124 } 125 126 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) { 127 // First, quickly check if there is TLS Variable. 128 Module *M = Fn.getParent(); 129 130 bool HasTLS = llvm::any_of( 131 M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); }); 132 133 // If non, directly return. 134 if (!HasTLS) 135 return; 136 137 TLSCandMap.clear(); 138 139 // Then, collect TLS Variable info. 140 for (BasicBlock &BB : Fn) { 141 // Ignore unreachable basic blocks. 
142 if (!DT->isReachableFromEntry(&BB)) 143 continue; 144 145 for (Instruction &Inst : BB) 146 collectTLSCandidate(&Inst); 147 } 148 } 149 150 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) { 151 if (Cand.Users.size() != 1) 152 return false; 153 154 BasicBlock *BB = Cand.Users[0].Inst->getParent(); 155 if (LI->getLoopFor(BB)) 156 return false; 157 158 return true; 159 } 160 161 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB, 162 Loop *L) { 163 assert(L && "Unexcepted Loop status!"); 164 165 // Get the outermost loop. 166 while (Loop *Parent = L->getParentLoop()) 167 L = Parent; 168 169 BasicBlock *PreHeader = L->getLoopPreheader(); 170 171 // There is unique predecessor outside the loop. 172 if (PreHeader) 173 return PreHeader->getTerminator(); 174 175 BasicBlock *Header = L->getHeader(); 176 BasicBlock *Dom = Header; 177 for (BasicBlock *PredBB : predecessors(Header)) 178 Dom = DT->findNearestCommonDominator(Dom, PredBB); 179 180 assert(Dom && "Not find dominator BB!"); 181 Instruction *Term = Dom->getTerminator(); 182 183 return Term; 184 } 185 186 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1, 187 Instruction *I2) { 188 if (!I1) 189 return I2; 190 if (DT->dominates(I1, I2)) 191 return I1; 192 if (DT->dominates(I2, I1)) 193 return I2; 194 195 // If there is no dominance relation, use common dominator. 196 BasicBlock *DomBB = 197 DT->findNearestCommonDominator(I1->getParent(), I2->getParent()); 198 199 Instruction *Dom = DomBB->getTerminator(); 200 assert(Dom && "Common dominator not found!"); 201 202 return Dom; 203 } 204 205 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn, 206 GlobalVariable *GV, 207 BasicBlock *&PosBB) { 208 tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; 209 210 // We should hoist the TLS use out of loop, so choose its nearest instruction 211 // which dominate the loop and the outside loops (if exist). 
212 Instruction *LastPos = nullptr; 213 for (auto &User : Cand.Users) { 214 BasicBlock *BB = User.Inst->getParent(); 215 Instruction *Pos = User.Inst; 216 if (Loop *L = LI->getLoopFor(BB)) { 217 Pos = getNearestLoopDomInst(BB, L); 218 assert(Pos && "Not find insert position out of loop!"); 219 } 220 Pos = getDomInst(LastPos, Pos); 221 LastPos = Pos; 222 } 223 224 assert(LastPos && "Unexpected insert position!"); 225 BasicBlock *Parent = LastPos->getParent(); 226 PosBB = Parent; 227 return LastPos->getIterator(); 228 } 229 230 // Generate a bitcast (no type change) to replace the uses of TLS Candidate. 231 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn, 232 GlobalVariable *GV) { 233 BasicBlock *PosBB = &Fn.getEntryBlock(); 234 BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB); 235 Type *Ty = GV->getType(); 236 auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast"); 237 PosBB->getInstList().insert(Iter, CastInst); 238 return CastInst; 239 } 240 241 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn, 242 GlobalVariable *GV) { 243 244 tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; 245 246 // If only used 1 time and not in loops, we no need to replace it. 247 if (oneUseOutsideLoop(Cand, LI)) 248 return false; 249 250 // Generate a bitcast (no type change) 251 auto *CastInst = genBitCastInst(Fn, GV); 252 253 // to replace the uses of TLS Candidate 254 for (auto &User : Cand.Users) 255 User.Inst->setOperand(User.OpndIdx, CastInst); 256 257 return true; 258 } 259 260 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) { 261 if (TLSCandMap.empty()) 262 return false; 263 264 bool Replaced = false; 265 for (auto &GV2Cand : TLSCandMap) { 266 GlobalVariable *GV = GV2Cand.first; 267 Replaced |= tryReplaceTLSCandidate(Fn, GV); 268 } 269 270 return Replaced; 271 } 272 273 /// Optimize expensive TLS variables in the given function. 
274 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT, 275 LoopInfo &LI) { 276 if (Fn.hasOptNone()) 277 return false; 278 279 if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist")) 280 return false; 281 282 this->LI = &LI; 283 this->DT = &DT; 284 assert(this->LI && this->DT && "Unexcepted requirement!"); 285 286 // Collect all TLS variable candidates. 287 collectTLSCandidates(Fn); 288 289 bool MadeChange = tryReplaceTLSCandidates(Fn); 290 291 return MadeChange; 292 } 293 294 PreservedAnalyses TLSVariableHoistPass::run(Function &F, 295 FunctionAnalysisManager &AM) { 296 297 auto &LI = AM.getResult<LoopAnalysis>(F); 298 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 299 300 if (!runImpl(F, DT, LI)) 301 return PreservedAnalyses::all(); 302 303 PreservedAnalyses PA; 304 PA.preserveSet<CFGAnalyses>(); 305 return PA; 306 } 307