1 //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 //===----------------------------------------------------------------------===// 9 10 #include "llvm/Transforms/Scalar/LoopTermFold.h" 11 #include "llvm/ADT/Statistic.h" 12 #include "llvm/Analysis/LoopAnalysisManager.h" 13 #include "llvm/Analysis/LoopInfo.h" 14 #include "llvm/Analysis/LoopPass.h" 15 #include "llvm/Analysis/MemorySSA.h" 16 #include "llvm/Analysis/MemorySSAUpdater.h" 17 #include "llvm/Analysis/ScalarEvolution.h" 18 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 19 #include "llvm/Analysis/TargetLibraryInfo.h" 20 #include "llvm/Analysis/TargetTransformInfo.h" 21 #include "llvm/Analysis/ValueTracking.h" 22 #include "llvm/Config/llvm-config.h" 23 #include "llvm/IR/BasicBlock.h" 24 #include "llvm/IR/Dominators.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/InstrTypes.h" 27 #include "llvm/IR/Instruction.h" 28 #include "llvm/IR/Instructions.h" 29 #include "llvm/IR/Type.h" 30 #include "llvm/IR/Value.h" 31 #include "llvm/InitializePasses.h" 32 #include "llvm/Pass.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/raw_ostream.h" 35 #include "llvm/Transforms/Scalar.h" 36 #include "llvm/Transforms/Utils.h" 37 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 38 #include "llvm/Transforms/Utils/Local.h" 39 #include "llvm/Transforms/Utils/LoopUtils.h" 40 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 41 #include <cassert> 42 #include <optional> 43 44 using namespace llvm; 45 46 #define DEBUG_TYPE "loop-term-fold" 47 48 STATISTIC(NumTermFold, 49 "Number of terminating condition fold recognized and performed"); 50 51 static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> 52 canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, 53 const LoopInfo &LI, const TargetTransformInfo &TTI) { 54 if (!L->isInnermost()) { 55 LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); 56 return std::nullopt; 57 } 58 // Only inspect on simple loop structure 59 if (!L->isLoopSimplifyForm()) { 60 LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); 61 return std::nullopt; 62 } 63 64 if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { 65 LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); 66 return std::nullopt; 67 } 68 69 BasicBlock *LoopLatch = L->getLoopLatch(); 70 BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); 71 if (!BI || BI->isUnconditional()) 72 return std::nullopt; 73 auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition()); 74 if (!TermCond) { 75 LLVM_DEBUG( 76 dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); 77 return std::nullopt; 78 } 79 if (!TermCond->hasOneUse()) { 80 LLVM_DEBUG( 81 dbgs() 82 << "Cannot replace terminating condition with more than one use\n"); 83 return std::nullopt; 84 } 85 86 BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0)); 87 Value *RHS = TermCond->getOperand(1); 88 if (!LHS || !L->isLoopInvariant(RHS)) 89 // We could pattern match the inverse form of the icmp, but that is 90 // non-canonical, and this pass is running *very* late in the pipeline. 91 return std::nullopt; 92 93 // Find the IV used by the current exit condition. 94 PHINode *ToFold; 95 Value *ToFoldStart, *ToFoldStep; 96 if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) 97 return std::nullopt; 98 99 // Ensure the simple recurrence is a part of the current loop. 100 if (ToFold->getParent() != L->getHeader()) 101 return std::nullopt; 102 103 // If that IV isn't dead after we rewrite the exit condition in terms of 104 // another IV, there's no point in doing the transform. 105 if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) 106 return std::nullopt; 107 108 // Inserting instructions in the preheader has a runtime cost, scale 109 // the allowed cost with the loops trip count as best we can. 110 const unsigned ExpansionBudget = [&]() { 111 unsigned Budget = 2 * SCEVCheapExpansionBudget; 112 if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) 113 return std::min(Budget, SmallTC); 114 if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L)) 115 return std::min(Budget, *SmallTC); 116 // Unknown trip count, assume long running by default. 117 return Budget; 118 }(); 119 120 const SCEV *BECount = SE.getBackedgeTakenCount(L); 121 const DataLayout &DL = L->getHeader()->getDataLayout(); 122 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); 123 124 PHINode *ToHelpFold = nullptr; 125 const SCEV *TermValueS = nullptr; 126 bool MustDropPoison = false; 127 auto InsertPt = L->getLoopPreheader()->getTerminator(); 128 for (PHINode &PN : L->getHeader()->phis()) { 129 if (ToFold == &PN) 130 continue; 131 132 if (!SE.isSCEVable(PN.getType())) { 133 LLVM_DEBUG(dbgs() << "IV of phi '" << PN 134 << "' is not SCEV-able, not qualified for the " 135 "terminating condition folding.\n"); 136 continue; 137 } 138 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); 139 // Only speculate on affine AddRec 140 if (!AddRec || !AddRec->isAffine()) { 141 LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN 142 << "' is not an affine add recursion, not qualified " 143 "for the terminating condition folding.\n"); 144 continue; 145 } 146 147 // Check that we can compute the value of AddRec on the exiting iteration 148 // without soundness problems. evaluateAtIteration internally needs 149 // to multiply the stride of the iteration number - which may wrap around. 150 // The issue here is subtle because computing the result accounting for 151 // wrap is insufficient. In order to use the result in an exit test, we 152 // must also know that AddRec doesn't take the same value on any previous 153 // iteration. The simplest case to consider is a candidate IV which is 154 // narrower than the trip count (and thus original IV), but this can 155 // also happen due to non-unit strides on the candidate IVs. 156 if (!AddRec->hasNoSelfWrap() || 157 !SE.isKnownNonZero(AddRec->getStepRecurrence(SE))) 158 continue; 159 160 const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); 161 const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); 162 if (!Expander.isSafeToExpand(TermValueSLocal)) { 163 LLVM_DEBUG( 164 dbgs() << "Is not safe to expand terminating value for phi node" << PN 165 << "\n"); 166 continue; 167 } 168 169 if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI, 170 InsertPt)) { 171 LLVM_DEBUG( 172 dbgs() << "Is too expensive to expand terminating value for phi node" 173 << PN << "\n"); 174 continue; 175 } 176 177 // The candidate IV may have been otherwise dead and poison from the 178 // very first iteration. If we can't disprove that, we can't use the IV. 179 if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { 180 LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n"); 181 continue; 182 } 183 184 // The candidate IV may become poison on the last iteration. If this 185 // value is not branched on, this is a well defined program. We're 186 // about to add a new use to this IV, and we have to ensure we don't 187 // insert UB which didn't previously exist. 188 bool MustDropPoisonLocal = false; 189 Instruction *PostIncV = 190 cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch)); 191 if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), 192 &DT)) { 193 LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN 194 << "\n"); 195 196 // If this is a complex recurrance with multiple instructions computing 197 // the backedge value, we might need to strip poison flags from all of 198 // them. 199 if (PostIncV->getOperand(0) != &PN) 200 continue; 201 202 // In order to perform the transform, we need to drop the poison 203 // generating flags on this instruction (if any). 204 MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); 205 } 206 207 // We pick the last legal alternate IV. We could expore choosing an optimal 208 // alternate IV if we had a decent heuristic to do so. 209 ToHelpFold = &PN; 210 TermValueS = TermValueSLocal; 211 MustDropPoison = MustDropPoisonLocal; 212 } 213 214 LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() 215 << "Cannot find other AddRec IV to help folding\n";); 216 217 LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() 218 << "\nFound loop that can fold terminating condition\n" 219 << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" 220 << " TermCond: " << *TermCond << "\n" 221 << " BrandInst: " << *BI << "\n" 222 << " ToFold: " << *ToFold << "\n" 223 << " ToHelpFold: " << *ToHelpFold << "\n"); 224 225 if (!ToFold || !ToHelpFold) 226 return std::nullopt; 227 return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); 228 } 229 230 static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, 231 LoopInfo &LI, const TargetTransformInfo &TTI, 232 TargetLibraryInfo &TLI, MemorySSA *MSSA) { 233 std::unique_ptr<MemorySSAUpdater> MSSAU; 234 if (MSSA) 235 MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); 236 237 auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI); 238 if (!Opt) 239 return false; 240 241 auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; 242 243 NumTermFold++; 244 245 BasicBlock *LoopPreheader = L->getLoopPreheader(); 246 BasicBlock *LoopLatch = L->getLoopLatch(); 247 248 (void)ToFold; 249 LLVM_DEBUG(dbgs() << "To fold phi-node:\n" 250 << *ToFold << "\n" 251 << "New term-cond phi-node:\n" 252 << *ToHelpFold << "\n"); 253 254 Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); 255 (void)StartValue; 256 Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); 257 258 // See comment in canFoldTermCondOfLoop on why this is sufficient. 259 if (MustDrop) 260 cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags(); 261 262 // SCEVExpander for both use in preheader and latch 263 const DataLayout &DL = L->getHeader()->getDataLayout(); 264 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); 265 266 assert(Expander.isSafeToExpand(TermValueS) && 267 "Terminating value was checked safe in canFoldTerminatingCondition"); 268 269 // Create new terminating value at loop preheader 270 Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), 271 LoopPreheader->getTerminator()); 272 273 LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" 274 << *StartValue << "\n" 275 << "Terminating value of new term-cond phi-node:\n" 276 << *TermValue << "\n"); 277 278 // Create new terminating condition at loop latch 279 BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); 280 ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); 281 IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); 282 Value *NewTermCond = 283 LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, 284 "lsr_fold_term_cond.replaced_term_cond"); 285 // Swap successors to exit loop body if IV equals to new TermValue 286 if (BI->getSuccessor(0) == L->getHeader()) 287 BI->swapSuccessors(); 288 289 LLVM_DEBUG(dbgs() << "Old term-cond:\n" 290 << *OldTermCond << "\n" 291 << "New term-cond:\n" 292 << *NewTermCond << "\n"); 293 294 BI->setCondition(NewTermCond); 295 296 Expander.clear(); 297 OldTermCond->eraseFromParent(); 298 DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); 299 return true; 300 } 301 302 namespace { 303 304 class LoopTermFold : public LoopPass { 305 public: 306 static char ID; // Pass ID, replacement for typeid 307 308 LoopTermFold(); 309 310 private: 311 bool runOnLoop(Loop *L, LPPassManager &LPM) override; 312 void getAnalysisUsage(AnalysisUsage &AU) const override; 313 }; 314 315 } // end anonymous namespace 316 317 LoopTermFold::LoopTermFold() : LoopPass(ID) { 318 initializeLoopTermFoldPass(*PassRegistry::getPassRegistry()); 319 } 320 321 void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const { 322 AU.addRequired<LoopInfoWrapperPass>(); 323 AU.addPreserved<LoopInfoWrapperPass>(); 324 AU.addPreservedID(LoopSimplifyID); 325 AU.addRequiredID(LoopSimplifyID); 326 AU.addRequired<DominatorTreeWrapperPass>(); 327 AU.addPreserved<DominatorTreeWrapperPass>(); 328 AU.addRequired<ScalarEvolutionWrapperPass>(); 329 AU.addPreserved<ScalarEvolutionWrapperPass>(); 330 AU.addRequired<TargetLibraryInfoWrapperPass>(); 331 AU.addRequired<TargetTransformInfoWrapperPass>(); 332 AU.addPreserved<MemorySSAWrapperPass>(); 333 } 334 335 bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { 336 if (skipLoop(L)) 337 return false; 338 339 auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 340 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 341 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 342 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( 343 *L->getHeader()->getParent()); 344 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( 345 *L->getHeader()->getParent()); 346 auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); 347 MemorySSA *MSSA = nullptr; 348 if (MSSAAnalysis) 349 MSSA = &MSSAAnalysis->getMSSA(); 350 return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA); 351 } 352 353 PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM, 354 LoopStandardAnalysisResults &AR, 355 LPMUpdater &) { 356 if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA)) 357 return PreservedAnalyses::all(); 358 359 auto PA = getLoopPassPreservedAnalyses(); 360 if (AR.MSSA) 361 PA.preserve<MemorySSAAnalysis>(); 362 return PA; 363 } 364 365 char LoopTermFold::ID = 0; 366 367 INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", 368 false, false) 369 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 370 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 371 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 372 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 373 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 374 INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", 375 false, false) 376 377 Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); } 378