1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h" 10 #include "llvm/ADT/Sequence.h" 11 #include "llvm/Analysis/LoopAccessAnalysis.h" 12 #include "llvm/Analysis/LoopAnalysisManager.h" 13 #include "llvm/Analysis/LoopInfo.h" 14 #include "llvm/Analysis/LoopIterator.h" 15 #include "llvm/Analysis/LoopPass.h" 16 #include "llvm/Analysis/MemorySSA.h" 17 #include "llvm/Analysis/MemorySSAUpdater.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 20 #include "llvm/IR/PatternMatch.h" 21 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 22 #include "llvm/Transforms/Utils/Cloning.h" 23 #include "llvm/Transforms/Utils/LoopSimplify.h" 24 #include "llvm/Transforms/Utils/LoopUtils.h" 25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 26 27 #define DEBUG_TYPE "loop-bound-split" 28 29 namespace llvm { 30 31 using namespace PatternMatch; 32 33 namespace { 34 struct ConditionInfo { 35 /// Branch instruction with this condition 36 BranchInst *BI; 37 /// ICmp instruction with this condition 38 ICmpInst *ICmp; 39 /// Preciate info 40 ICmpInst::Predicate Pred; 41 /// AddRec llvm value 42 Value *AddRecValue; 43 /// Non PHI AddRec llvm value 44 Value *NonPHIAddRecValue; 45 /// Bound llvm value 46 Value *BoundValue; 47 /// AddRec SCEV 48 const SCEVAddRecExpr *AddRecSCEV; 49 /// Bound SCEV 50 const SCEV *BoundSCEV; 51 52 ConditionInfo() 53 : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE), 54 AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr), 55 BoundSCEV(nullptr) {} 56 }; 57 } // namespace 58 59 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, 60 ConditionInfo &Cond, const Loop &L) { 61 Cond.ICmp = ICmp; 62 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), 63 m_Value(Cond.BoundValue)))) { 64 const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue); 65 const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue); 66 const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 67 const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV); 68 // Locate AddRec in LHSSCEV and Bound in RHSSCEV. 69 if (!LHSAddRecSCEV && RHSAddRecSCEV) { 70 std::swap(Cond.AddRecValue, Cond.BoundValue); 71 std::swap(AddRecSCEV, BoundSCEV); 72 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); 73 } 74 75 Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 76 Cond.BoundSCEV = BoundSCEV; 77 Cond.NonPHIAddRecValue = Cond.AddRecValue; 78 79 // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with 80 // value from backedge. 81 if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) { 82 PHINode *PN = cast<PHINode>(Cond.AddRecValue); 83 Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch()); 84 } 85 } 86 } 87 88 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, 89 ConditionInfo &Cond, bool IsExitCond) { 90 if (IsExitCond) { 91 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent()); 92 if (isa<SCEVCouldNotCompute>(ExitCount)) 93 return false; 94 95 Cond.BoundSCEV = ExitCount; 96 return true; 97 } 98 99 // For non-exit condtion, if pred is LT, keep existing bound. 100 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT) 101 return true; 102 103 // For non-exit condition, if pre is LE, try to convert it to LT. 104 // Range Range 105 // AddRec <= Bound --> AddRec < Bound + 1 106 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE) 107 return false; 108 109 if (IntegerType *BoundSCEVIntType = 110 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) { 111 unsigned BitWidth = BoundSCEVIntType->getBitWidth(); 112 APInt Max = ICmpInst::isSigned(Cond.Pred) 113 ? APInt::getSignedMaxValue(BitWidth) 114 : APInt::getMaxValue(BitWidth); 115 const SCEV *MaxSCEV = SE.getConstant(Max); 116 // Check Bound < INT_MAX 117 ICmpInst::Predicate Pred = 118 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; 119 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) { 120 const SCEV *BoundPlusOneSCEV = 121 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType)); 122 Cond.BoundSCEV = BoundPlusOneSCEV; 123 Cond.Pred = Pred; 124 return true; 125 } 126 } 127 128 // ToDo: Support ICMP_NE/EQ. 129 130 return false; 131 } 132 133 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, 134 ICmpInst *ICmp, ConditionInfo &Cond, 135 bool IsExitCond) { 136 analyzeICmp(SE, ICmp, Cond, L); 137 138 // The BoundSCEV should be evaluated at loop entry. 139 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) 140 return false; 141 142 // Allowed AddRec as induction variable. 143 if (!Cond.AddRecSCEV) 144 return false; 145 146 if (!Cond.AddRecSCEV->isAffine()) 147 return false; 148 149 const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE); 150 // Allowed constant step. 151 if (!isa<SCEVConstant>(StepRecSCEV)) 152 return false; 153 154 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue(); 155 // Allowed positive step for now. 156 // TODO: Support negative step. 157 if (StepCI->isNegative() || StepCI->isZero()) 158 return false; 159 160 // Calculate upper bound. 161 if (!calculateUpperBound(L, SE, Cond, IsExitCond)) 162 return false; 163 164 return true; 165 } 166 167 static bool isProcessableCondBI(const ScalarEvolution &SE, 168 const BranchInst *BI) { 169 BasicBlock *TrueSucc = nullptr; 170 BasicBlock *FalseSucc = nullptr; 171 ICmpInst::Predicate Pred; 172 Value *LHS, *RHS; 173 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), 174 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) 175 return false; 176 177 if (!SE.isSCEVable(LHS->getType())) 178 return false; 179 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable"); 180 181 if (TrueSucc == FalseSucc) 182 return false; 183 184 return true; 185 } 186 187 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT, 188 ScalarEvolution &SE, ConditionInfo &Cond) { 189 // Skip function with optsize. 190 if (L.getHeader()->getParent()->hasOptSize()) 191 return false; 192 193 // Split only innermost loop. 194 if (!L.isInnermost()) 195 return false; 196 197 // Check loop is in simplified form. 198 if (!L.isLoopSimplifyForm()) 199 return false; 200 201 // Check loop is in LCSSA form. 202 if (!L.isLCSSAForm(DT)) 203 return false; 204 205 // Skip loop that cannot be cloned. 206 if (!L.isSafeToClone()) 207 return false; 208 209 BasicBlock *ExitingBB = L.getExitingBlock(); 210 // Assumed only one exiting block. 211 if (!ExitingBB) 212 return false; 213 214 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); 215 if (!ExitingBI) 216 return false; 217 218 // Allowed only conditional branch with ICmp. 219 if (!isProcessableCondBI(SE, ExitingBI)) 220 return false; 221 222 // Check the condition is processable. 223 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition()); 224 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true)) 225 return false; 226 227 Cond.BI = ExitingBI; 228 return true; 229 } 230 231 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) { 232 // If the conditional branch splits a loop into two halves, we could 233 // generally say it is profitable. 234 // 235 // ToDo: Add more profitable cases here. 236 237 // Check this branch causes diamond CFG. 238 BasicBlock *Succ0 = BI->getSuccessor(0); 239 BasicBlock *Succ1 = BI->getSuccessor(1); 240 241 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); 242 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); 243 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) 244 return false; 245 246 // ToDo: Calculate each successor's instruction cost. 247 248 return true; 249 } 250 251 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, 252 ConditionInfo &ExitingCond, 253 ConditionInfo &SplitCandidateCond) { 254 for (auto *BB : L.blocks()) { 255 // Skip condition of backedge. 256 if (L.getLoopLatch() == BB) 257 continue; 258 259 auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); 260 if (!BI) 261 continue; 262 263 // Check conditional branch with ICmp. 264 if (!isProcessableCondBI(SE, BI)) 265 continue; 266 267 // Skip loop invariant condition. 268 if (L.isLoopInvariant(BI->getCondition())) 269 continue; 270 271 // Check the condition is processable. 272 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition()); 273 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond, 274 /*IsExitCond*/ false)) 275 continue; 276 277 if (ExitingCond.BoundSCEV->getType() != 278 SplitCandidateCond.BoundSCEV->getType()) 279 continue; 280 281 // After transformation, we assume the split condition of the pre-loop is 282 // always true. In order to guarantee it, we need to check the start value 283 // of the split cond AddRec satisfies the split condition. 284 if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred, 285 SplitCandidateCond.AddRecSCEV->getStart(), 286 SplitCandidateCond.BoundSCEV)) 287 continue; 288 289 SplitCandidateCond.BI = BI; 290 return BI; 291 } 292 293 return nullptr; 294 } 295 296 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, 297 ScalarEvolution &SE, LPMUpdater &U) { 298 ConditionInfo SplitCandidateCond; 299 ConditionInfo ExitingCond; 300 301 // Check we can split this loop's bound. 302 if (!canSplitLoopBound(L, DT, SE, ExitingCond)) 303 return false; 304 305 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond)) 306 return false; 307 308 if (!isProfitableToTransform(L, SplitCandidateCond.BI)) 309 return false; 310 311 // Now, we have a split candidate. Let's build a form as below. 312 // +--------------------+ 313 // | preheader | 314 // | set up newbound | 315 // +--------------------+ 316 // | /----------------\ 317 // +--------v----v------+ | 318 // | header |---\ | 319 // | with true condition| | | 320 // +--------------------+ | | 321 // | | | 322 // +--------v-----------+ | | 323 // | if.then.BB | | | 324 // +--------------------+ | | 325 // | | | 326 // +--------v-----------<---/ | 327 // | latch >----------/ 328 // | with newbound | 329 // +--------------------+ 330 // | 331 // +--------v-----------+ 332 // | preheader2 |--------------\ 333 // | if (AddRec i != | | 334 // | org bound) | | 335 // +--------------------+ | 336 // | /----------------\ | 337 // +--------v----v------+ | | 338 // | header2 |---\ | | 339 // | conditional branch | | | | 340 // |with false condition| | | | 341 // +--------------------+ | | | 342 // | | | | 343 // +--------v-----------+ | | | 344 // | if.then.BB2 | | | | 345 // +--------------------+ | | | 346 // | | | | 347 // +--------v-----------<---/ | | 348 // | latch2 >----------/ | 349 // | with org bound | | 350 // +--------v-----------+ | 351 // | | 352 // | +---------------+ | 353 // +--> exit <-------/ 354 // +---------------+ 355 356 // Let's create post loop. 357 SmallVector<BasicBlock *, 8> PostLoopBlocks; 358 Loop *PostLoop; 359 ValueToValueMapTy VMap; 360 BasicBlock *PreHeader = L.getLoopPreheader(); 361 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI); 362 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap, 363 ".split", &LI, &DT, PostLoopBlocks); 364 remapInstructionsInBlocks(PostLoopBlocks, VMap); 365 366 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); 367 IRBuilder<> Builder(&PostLoopPreHeader->front()); 368 369 // Update phi nodes in header of post-loop. 370 bool isExitingLatch = 371 (L.getExitingBlock() == L.getLoopLatch()) ? true : false; 372 Value *ExitingCondLCSSAPhi = nullptr; 373 for (PHINode &PN : L.getHeader()->phis()) { 374 // Create LCSSA phi node in preheader of post-loop. 375 PHINode *LCSSAPhi = 376 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 377 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 378 // If the exiting block is loop latch, the phi does not have the update at 379 // last iteration. In this case, update lcssa phi with value from backedge. 380 LCSSAPhi->addIncoming( 381 isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN, 382 L.getExitingBlock()); 383 384 // Update the start value of phi node in post-loop with the LCSSA phi node. 385 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); 386 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi); 387 388 // Find PHI with exiting condition from pre-loop. The PHI should be 389 // SCEVAddRecExpr and have same incoming value from backedge with 390 // ExitingCond. 391 if (!SE.isSCEVable(PN.getType())) 392 continue; 393 394 const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); 395 if (PhiSCEV && ExitingCond.NonPHIAddRecValue == 396 PN.getIncomingValueForBlock(L.getLoopLatch())) 397 ExitingCondLCSSAPhi = LCSSAPhi; 398 } 399 400 // Add conditional branch to check we can skip post-loop in its preheader. 401 Instruction *OrigBI = PostLoopPreHeader->getTerminator(); 402 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; 403 Value *Cond = 404 Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue); 405 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); 406 OrigBI->eraseFromParent(); 407 408 // Create new loop bound and add it into preheader of pre-loop. 409 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV; 410 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV; 411 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred) 412 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV) 413 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); 414 415 SCEVExpander Expander( 416 SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split"); 417 Instruction *InsertPt = SplitLoopPH->getTerminator(); 418 Value *NewBoundValue = 419 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); 420 NewBoundValue->setName("new.bound"); 421 422 // Replace exiting bound value of pre-loop NewBound. 423 ExitingCond.ICmp->setOperand(1, NewBoundValue); 424 425 // Replace SplitCandidateCond.BI's condition of pre-loop by True. 426 LLVMContext &Context = PreHeader->getContext(); 427 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); 428 429 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False. 430 BranchInst *ClonedSplitCandidateBI = 431 cast<BranchInst>(VMap[SplitCandidateCond.BI]); 432 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context)); 433 434 // Replace exit branch target of pre-loop by post-loop's preheader. 435 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0)) 436 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader); 437 else 438 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); 439 440 // Update phi node in exit block of post-loop. 441 Builder.SetInsertPoint(&PostLoopPreHeader->front()); 442 for (PHINode &PN : PostLoop->getExitBlock()->phis()) { 443 for (auto i : seq<int>(0, PN.getNumOperands())) { 444 // Check incoming block is pre-loop's exiting block. 445 if (PN.getIncomingBlock(i) == L.getExitingBlock()) { 446 Value *IncomingValue = PN.getIncomingValue(i); 447 448 // Create LCSSA phi node for incoming value. 449 PHINode *LCSSAPhi = 450 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 451 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 452 LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i)); 453 454 // Replace pre-loop's exiting block by post-loop's preheader. 455 PN.setIncomingBlock(i, PostLoopPreHeader); 456 // Replace incoming value by LCSSAPhi. 457 PN.setIncomingValue(i, LCSSAPhi); 458 // Add a new incoming value with post-loop's exiting block. 459 PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock()); 460 } 461 } 462 } 463 464 // Update dominator tree. 465 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); 466 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); 467 468 // Invalidate cached SE information. 469 SE.forgetLoop(&L); 470 471 // Canonicalize loops. 472 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); 473 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); 474 475 // Add new post-loop to loop pass manager. 476 U.addSiblingLoops(PostLoop); 477 478 return true; 479 } 480 481 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM, 482 LoopStandardAnalysisResults &AR, 483 LPMUpdater &U) { 484 Function &F = *L.getHeader()->getParent(); 485 (void)F; 486 487 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L 488 << "\n"); 489 490 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U)) 491 return PreservedAnalyses::all(); 492 493 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); 494 AR.LI.verify(AR.DT); 495 496 return getLoopPassPreservedAnalyses(); 497 } 498 499 } // end namespace llvm 500