1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h" 10 #include "llvm/ADT/Sequence.h" 11 #include "llvm/Analysis/LoopAnalysisManager.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/ScalarEvolution.h" 14 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 15 #include "llvm/IR/PatternMatch.h" 16 #include "llvm/Transforms/Scalar/LoopPassManager.h" 17 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 18 #include "llvm/Transforms/Utils/Cloning.h" 19 #include "llvm/Transforms/Utils/LoopSimplify.h" 20 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 21 22 #define DEBUG_TYPE "loop-bound-split" 23 24 namespace llvm { 25 26 using namespace PatternMatch; 27 28 namespace { 29 struct ConditionInfo { 30 /// Branch instruction with this condition 31 BranchInst *BI = nullptr; 32 /// ICmp instruction with this condition 33 ICmpInst *ICmp = nullptr; 34 /// Preciate info 35 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; 36 /// AddRec llvm value 37 Value *AddRecValue = nullptr; 38 /// Non PHI AddRec llvm value 39 Value *NonPHIAddRecValue; 40 /// Bound llvm value 41 Value *BoundValue = nullptr; 42 /// AddRec SCEV 43 const SCEVAddRecExpr *AddRecSCEV = nullptr; 44 /// Bound SCEV 45 const SCEV *BoundSCEV = nullptr; 46 47 ConditionInfo() = default; 48 }; 49 } // namespace 50 51 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, 52 ConditionInfo &Cond, const Loop &L) { 53 Cond.ICmp = ICmp; 54 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), 55 m_Value(Cond.BoundValue)))) { 56 const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue); 57 const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue); 58 const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 59 const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV); 60 // Locate AddRec in LHSSCEV and Bound in RHSSCEV. 61 if (!LHSAddRecSCEV && RHSAddRecSCEV) { 62 std::swap(Cond.AddRecValue, Cond.BoundValue); 63 std::swap(AddRecSCEV, BoundSCEV); 64 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); 65 } 66 67 Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 68 Cond.BoundSCEV = BoundSCEV; 69 Cond.NonPHIAddRecValue = Cond.AddRecValue; 70 71 // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with 72 // value from backedge. 73 if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) { 74 PHINode *PN = cast<PHINode>(Cond.AddRecValue); 75 Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch()); 76 } 77 } 78 } 79 80 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, 81 ConditionInfo &Cond, bool IsExitCond) { 82 if (IsExitCond) { 83 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent()); 84 if (isa<SCEVCouldNotCompute>(ExitCount)) 85 return false; 86 87 Cond.BoundSCEV = ExitCount; 88 return true; 89 } 90 91 // For non-exit condtion, if pred is LT, keep existing bound. 92 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT) 93 return true; 94 95 // For non-exit condition, if pre is LE, try to convert it to LT. 96 // Range Range 97 // AddRec <= Bound --> AddRec < Bound + 1 98 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE) 99 return false; 100 101 if (IntegerType *BoundSCEVIntType = 102 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) { 103 unsigned BitWidth = BoundSCEVIntType->getBitWidth(); 104 APInt Max = ICmpInst::isSigned(Cond.Pred) 105 ? APInt::getSignedMaxValue(BitWidth) 106 : APInt::getMaxValue(BitWidth); 107 const SCEV *MaxSCEV = SE.getConstant(Max); 108 // Check Bound < INT_MAX 109 ICmpInst::Predicate Pred = 110 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; 111 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) { 112 const SCEV *BoundPlusOneSCEV = 113 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType)); 114 Cond.BoundSCEV = BoundPlusOneSCEV; 115 Cond.Pred = Pred; 116 return true; 117 } 118 } 119 120 // ToDo: Support ICMP_NE/EQ. 121 122 return false; 123 } 124 125 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, 126 ICmpInst *ICmp, ConditionInfo &Cond, 127 bool IsExitCond) { 128 analyzeICmp(SE, ICmp, Cond, L); 129 130 // The BoundSCEV should be evaluated at loop entry. 131 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) 132 return false; 133 134 // Allowed AddRec as induction variable. 135 if (!Cond.AddRecSCEV) 136 return false; 137 138 if (!Cond.AddRecSCEV->isAffine()) 139 return false; 140 141 const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE); 142 // Allowed constant step. 143 if (!isa<SCEVConstant>(StepRecSCEV)) 144 return false; 145 146 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue(); 147 // Allowed positive step for now. 148 // TODO: Support negative step. 149 if (StepCI->isNegative() || StepCI->isZero()) 150 return false; 151 152 // Calculate upper bound. 153 if (!calculateUpperBound(L, SE, Cond, IsExitCond)) 154 return false; 155 156 return true; 157 } 158 159 static bool isProcessableCondBI(const ScalarEvolution &SE, 160 const BranchInst *BI) { 161 BasicBlock *TrueSucc = nullptr; 162 BasicBlock *FalseSucc = nullptr; 163 ICmpInst::Predicate Pred; 164 Value *LHS, *RHS; 165 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), 166 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) 167 return false; 168 169 if (!SE.isSCEVable(LHS->getType())) 170 return false; 171 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable"); 172 173 if (TrueSucc == FalseSucc) 174 return false; 175 176 return true; 177 } 178 179 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT, 180 ScalarEvolution &SE, ConditionInfo &Cond) { 181 // Skip function with optsize. 182 if (L.getHeader()->getParent()->hasOptSize()) 183 return false; 184 185 // Split only innermost loop. 186 if (!L.isInnermost()) 187 return false; 188 189 // Check loop is in simplified form. 190 if (!L.isLoopSimplifyForm()) 191 return false; 192 193 // Check loop is in LCSSA form. 194 if (!L.isLCSSAForm(DT)) 195 return false; 196 197 // Skip loop that cannot be cloned. 198 if (!L.isSafeToClone()) 199 return false; 200 201 BasicBlock *ExitingBB = L.getExitingBlock(); 202 // Assumed only one exiting block. 203 if (!ExitingBB) 204 return false; 205 206 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); 207 if (!ExitingBI) 208 return false; 209 210 // Allowed only conditional branch with ICmp. 211 if (!isProcessableCondBI(SE, ExitingBI)) 212 return false; 213 214 // Check the condition is processable. 215 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition()); 216 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true)) 217 return false; 218 219 Cond.BI = ExitingBI; 220 return true; 221 } 222 223 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) { 224 // If the conditional branch splits a loop into two halves, we could 225 // generally say it is profitable. 226 // 227 // ToDo: Add more profitable cases here. 228 229 // Check this branch causes diamond CFG. 230 BasicBlock *Succ0 = BI->getSuccessor(0); 231 BasicBlock *Succ1 = BI->getSuccessor(1); 232 233 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); 234 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); 235 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) 236 return false; 237 238 // ToDo: Calculate each successor's instruction cost. 239 240 return true; 241 } 242 243 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, 244 ConditionInfo &ExitingCond, 245 ConditionInfo &SplitCandidateCond) { 246 for (auto *BB : L.blocks()) { 247 // Skip condition of backedge. 248 if (L.getLoopLatch() == BB) 249 continue; 250 251 auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); 252 if (!BI) 253 continue; 254 255 // Check conditional branch with ICmp. 256 if (!isProcessableCondBI(SE, BI)) 257 continue; 258 259 // Skip loop invariant condition. 260 if (L.isLoopInvariant(BI->getCondition())) 261 continue; 262 263 // Check the condition is processable. 264 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition()); 265 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond, 266 /*IsExitCond*/ false)) 267 continue; 268 269 if (ExitingCond.BoundSCEV->getType() != 270 SplitCandidateCond.BoundSCEV->getType()) 271 continue; 272 273 // After transformation, we assume the split condition of the pre-loop is 274 // always true. In order to guarantee it, we need to check the start value 275 // of the split cond AddRec satisfies the split condition. 276 if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred, 277 SplitCandidateCond.AddRecSCEV->getStart(), 278 SplitCandidateCond.BoundSCEV)) 279 continue; 280 281 SplitCandidateCond.BI = BI; 282 return BI; 283 } 284 285 return nullptr; 286 } 287 288 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, 289 ScalarEvolution &SE, LPMUpdater &U) { 290 ConditionInfo SplitCandidateCond; 291 ConditionInfo ExitingCond; 292 293 // Check we can split this loop's bound. 294 if (!canSplitLoopBound(L, DT, SE, ExitingCond)) 295 return false; 296 297 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond)) 298 return false; 299 300 if (!isProfitableToTransform(L, SplitCandidateCond.BI)) 301 return false; 302 303 // Now, we have a split candidate. Let's build a form as below. 304 // +--------------------+ 305 // | preheader | 306 // | set up newbound | 307 // +--------------------+ 308 // | /----------------\ 309 // +--------v----v------+ | 310 // | header |---\ | 311 // | with true condition| | | 312 // +--------------------+ | | 313 // | | | 314 // +--------v-----------+ | | 315 // | if.then.BB | | | 316 // +--------------------+ | | 317 // | | | 318 // +--------v-----------<---/ | 319 // | latch >----------/ 320 // | with newbound | 321 // +--------------------+ 322 // | 323 // +--------v-----------+ 324 // | preheader2 |--------------\ 325 // | if (AddRec i != | | 326 // | org bound) | | 327 // +--------------------+ | 328 // | /----------------\ | 329 // +--------v----v------+ | | 330 // | header2 |---\ | | 331 // | conditional branch | | | | 332 // |with false condition| | | | 333 // +--------------------+ | | | 334 // | | | | 335 // +--------v-----------+ | | | 336 // | if.then.BB2 | | | | 337 // +--------------------+ | | | 338 // | | | | 339 // +--------v-----------<---/ | | 340 // | latch2 >----------/ | 341 // | with org bound | | 342 // +--------v-----------+ | 343 // | | 344 // | +---------------+ | 345 // +--> exit <-------/ 346 // +---------------+ 347 348 // Let's create post loop. 349 SmallVector<BasicBlock *, 8> PostLoopBlocks; 350 Loop *PostLoop; 351 ValueToValueMapTy VMap; 352 BasicBlock *PreHeader = L.getLoopPreheader(); 353 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI); 354 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap, 355 ".split", &LI, &DT, PostLoopBlocks); 356 remapInstructionsInBlocks(PostLoopBlocks, VMap); 357 358 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); 359 IRBuilder<> Builder(&PostLoopPreHeader->front()); 360 361 // Update phi nodes in header of post-loop. 362 bool isExitingLatch = 363 (L.getExitingBlock() == L.getLoopLatch()) ? true : false; 364 Value *ExitingCondLCSSAPhi = nullptr; 365 for (PHINode &PN : L.getHeader()->phis()) { 366 // Create LCSSA phi node in preheader of post-loop. 367 PHINode *LCSSAPhi = 368 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 369 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 370 // If the exiting block is loop latch, the phi does not have the update at 371 // last iteration. In this case, update lcssa phi with value from backedge. 372 LCSSAPhi->addIncoming( 373 isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN, 374 L.getExitingBlock()); 375 376 // Update the start value of phi node in post-loop with the LCSSA phi node. 377 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); 378 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi); 379 380 // Find PHI with exiting condition from pre-loop. The PHI should be 381 // SCEVAddRecExpr and have same incoming value from backedge with 382 // ExitingCond. 383 if (!SE.isSCEVable(PN.getType())) 384 continue; 385 386 const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); 387 if (PhiSCEV && ExitingCond.NonPHIAddRecValue == 388 PN.getIncomingValueForBlock(L.getLoopLatch())) 389 ExitingCondLCSSAPhi = LCSSAPhi; 390 } 391 392 // Add conditional branch to check we can skip post-loop in its preheader. 393 Instruction *OrigBI = PostLoopPreHeader->getTerminator(); 394 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; 395 Value *Cond = 396 Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue); 397 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); 398 OrigBI->eraseFromParent(); 399 400 // Create new loop bound and add it into preheader of pre-loop. 401 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV; 402 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV; 403 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred) 404 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV) 405 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); 406 407 SCEVExpander Expander( 408 SE, L.getHeader()->getDataLayout(), "split"); 409 Instruction *InsertPt = SplitLoopPH->getTerminator(); 410 Value *NewBoundValue = 411 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); 412 NewBoundValue->setName("new.bound"); 413 414 // Replace exiting bound value of pre-loop NewBound. 415 ExitingCond.ICmp->setOperand(1, NewBoundValue); 416 417 // Replace SplitCandidateCond.BI's condition of pre-loop by True. 418 LLVMContext &Context = PreHeader->getContext(); 419 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); 420 421 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False. 422 BranchInst *ClonedSplitCandidateBI = 423 cast<BranchInst>(VMap[SplitCandidateCond.BI]); 424 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context)); 425 426 // Replace exit branch target of pre-loop by post-loop's preheader. 427 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0)) 428 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader); 429 else 430 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); 431 432 // Update phi node in exit block of post-loop. 433 Builder.SetInsertPoint(PostLoopPreHeader, PostLoopPreHeader->begin()); 434 for (PHINode &PN : PostLoop->getExitBlock()->phis()) { 435 for (auto i : seq<int>(0, PN.getNumOperands())) { 436 // Check incoming block is pre-loop's exiting block. 437 if (PN.getIncomingBlock(i) == L.getExitingBlock()) { 438 Value *IncomingValue = PN.getIncomingValue(i); 439 440 // Create LCSSA phi node for incoming value. 441 PHINode *LCSSAPhi = 442 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 443 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 444 LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i)); 445 446 // Replace pre-loop's exiting block by post-loop's preheader. 447 PN.setIncomingBlock(i, PostLoopPreHeader); 448 // Replace incoming value by LCSSAPhi. 449 PN.setIncomingValue(i, LCSSAPhi); 450 // Add a new incoming value with post-loop's exiting block. 451 PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock()); 452 } 453 } 454 } 455 456 // Update dominator tree. 457 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); 458 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); 459 460 // Invalidate cached SE information. 461 SE.forgetLoop(&L); 462 463 // Canonicalize loops. 464 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); 465 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); 466 467 // Add new post-loop to loop pass manager. 468 U.addSiblingLoops(PostLoop); 469 470 return true; 471 } 472 473 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM, 474 LoopStandardAnalysisResults &AR, 475 LPMUpdater &U) { 476 Function &F = *L.getHeader()->getParent(); 477 (void)F; 478 479 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L 480 << "\n"); 481 482 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U)) 483 return PreservedAnalyses::all(); 484 485 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); 486 AR.LI.verify(AR.DT); 487 488 return getLoopPassPreservedAnalyses(); 489 } 490 491 } // end namespace llvm 492