1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h" 10 #include "llvm/Analysis/LoopAccessAnalysis.h" 11 #include "llvm/Analysis/LoopAnalysisManager.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/LoopIterator.h" 14 #include "llvm/Analysis/LoopPass.h" 15 #include "llvm/Analysis/MemorySSA.h" 16 #include "llvm/Analysis/MemorySSAUpdater.h" 17 #include "llvm/Analysis/ScalarEvolution.h" 18 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 19 #include "llvm/IR/PatternMatch.h" 20 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 21 #include "llvm/Transforms/Utils/Cloning.h" 22 #include "llvm/Transforms/Utils/LoopSimplify.h" 23 #include "llvm/Transforms/Utils/LoopUtils.h" 24 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 25 26 #define DEBUG_TYPE "loop-bound-split" 27 28 namespace llvm { 29 30 using namespace PatternMatch; 31 32 namespace { 33 struct ConditionInfo { 34 /// Branch instruction with this condition 35 BranchInst *BI; 36 /// ICmp instruction with this condition 37 ICmpInst *ICmp; 38 /// Preciate info 39 ICmpInst::Predicate Pred; 40 /// AddRec llvm value 41 Value *AddRecValue; 42 /// Bound llvm value 43 Value *BoundValue; 44 /// AddRec SCEV 45 const SCEV *AddRecSCEV; 46 /// Bound SCEV 47 const SCEV *BoundSCEV; 48 49 ConditionInfo() 50 : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE), 51 AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr), 52 BoundSCEV(nullptr) {} 53 }; 54 } // namespace 55 56 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, 57 ConditionInfo &Cond) { 58 Cond.ICmp = ICmp; 59 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), 60 m_Value(Cond.BoundValue)))) { 61 Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue); 62 Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue); 63 // Locate AddRec in LHSSCEV and Bound in RHSSCEV. 64 if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) && 65 !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) { 66 std::swap(Cond.AddRecValue, Cond.BoundValue); 67 std::swap(Cond.AddRecSCEV, Cond.BoundSCEV); 68 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); 69 } 70 } 71 } 72 73 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, 74 ConditionInfo &Cond, bool IsExitCond) { 75 if (IsExitCond) { 76 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent()); 77 if (isa<SCEVCouldNotCompute>(ExitCount)) 78 return false; 79 80 Cond.BoundSCEV = ExitCount; 81 return true; 82 } 83 84 // For non-exit condtion, if pred is LT, keep existing bound. 85 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT) 86 return true; 87 88 // For non-exit condition, if pre is LE, try to convert it to LT. 89 // Range Range 90 // AddRec <= Bound --> AddRec < Bound + 1 91 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE) 92 return false; 93 94 if (IntegerType *BoundSCEVIntType = 95 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) { 96 unsigned BitWidth = BoundSCEVIntType->getBitWidth(); 97 APInt Max = ICmpInst::isSigned(Cond.Pred) 98 ? APInt::getSignedMaxValue(BitWidth) 99 : APInt::getMaxValue(BitWidth); 100 const SCEV *MaxSCEV = SE.getConstant(Max); 101 // Check Bound < INT_MAX 102 ICmpInst::Predicate Pred = 103 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; 104 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) { 105 const SCEV *BoundPlusOneSCEV = 106 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType)); 107 Cond.BoundSCEV = BoundPlusOneSCEV; 108 Cond.Pred = Pred; 109 return true; 110 } 111 } 112 113 // ToDo: Support ICMP_NE/EQ. 114 115 return false; 116 } 117 118 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, 119 ICmpInst *ICmp, ConditionInfo &Cond, 120 bool IsExitCond) { 121 analyzeICmp(SE, ICmp, Cond); 122 123 // The BoundSCEV should be evaluated at loop entry. 124 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) 125 return false; 126 127 const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV); 128 // Allowed AddRec as induction variable. 129 if (!AddRecSCEV) 130 return false; 131 132 if (!AddRecSCEV->isAffine()) 133 return false; 134 135 const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE); 136 // Allowed constant step. 137 if (!isa<SCEVConstant>(StepRecSCEV)) 138 return false; 139 140 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue(); 141 // Allowed positive step for now. 142 // TODO: Support negative step. 143 if (StepCI->isNegative() || StepCI->isZero()) 144 return false; 145 146 // Calculate upper bound. 147 if (!calculateUpperBound(L, SE, Cond, IsExitCond)) 148 return false; 149 150 return true; 151 } 152 153 static bool isProcessableCondBI(const ScalarEvolution &SE, 154 const BranchInst *BI) { 155 BasicBlock *TrueSucc = nullptr; 156 BasicBlock *FalseSucc = nullptr; 157 ICmpInst::Predicate Pred; 158 Value *LHS, *RHS; 159 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), 160 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) 161 return false; 162 163 if (!SE.isSCEVable(LHS->getType())) 164 return false; 165 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable"); 166 167 if (TrueSucc == FalseSucc) 168 return false; 169 170 return true; 171 } 172 173 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT, 174 ScalarEvolution &SE, ConditionInfo &Cond) { 175 // Skip function with optsize. 176 if (L.getHeader()->getParent()->hasOptSize()) 177 return false; 178 179 // Split only innermost loop. 180 if (!L.isInnermost()) 181 return false; 182 183 // Check loop is in simplified form. 184 if (!L.isLoopSimplifyForm()) 185 return false; 186 187 // Check loop is in LCSSA form. 188 if (!L.isLCSSAForm(DT)) 189 return false; 190 191 // Skip loop that cannot be cloned. 192 if (!L.isSafeToClone()) 193 return false; 194 195 BasicBlock *ExitingBB = L.getExitingBlock(); 196 // Assumed only one exiting block. 197 if (!ExitingBB) 198 return false; 199 200 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); 201 if (!ExitingBI) 202 return false; 203 204 // Allowed only conditional branch with ICmp. 205 if (!isProcessableCondBI(SE, ExitingBI)) 206 return false; 207 208 // Check the condition is processable. 209 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition()); 210 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true)) 211 return false; 212 213 Cond.BI = ExitingBI; 214 return true; 215 } 216 217 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) { 218 // If the conditional branch splits a loop into two halves, we could 219 // generally say it is profitable. 220 // 221 // ToDo: Add more profitable cases here. 222 223 // Check this branch causes diamond CFG. 224 BasicBlock *Succ0 = BI->getSuccessor(0); 225 BasicBlock *Succ1 = BI->getSuccessor(1); 226 227 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); 228 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); 229 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) 230 return false; 231 232 // ToDo: Calculate each successor's instruction cost. 233 234 return true; 235 } 236 237 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, 238 ConditionInfo &ExitingCond, 239 ConditionInfo &SplitCandidateCond) { 240 for (auto *BB : L.blocks()) { 241 // Skip condition of backedge. 242 if (L.getLoopLatch() == BB) 243 continue; 244 245 auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); 246 if (!BI) 247 continue; 248 249 // Check conditional branch with ICmp. 250 if (!isProcessableCondBI(SE, BI)) 251 continue; 252 253 // Skip loop invariant condition. 254 if (L.isLoopInvariant(BI->getCondition())) 255 continue; 256 257 // Check the condition is processable. 258 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition()); 259 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond, 260 /*IsExitCond*/ false)) 261 continue; 262 263 if (ExitingCond.BoundSCEV->getType() != 264 SplitCandidateCond.BoundSCEV->getType()) 265 continue; 266 267 SplitCandidateCond.BI = BI; 268 return BI; 269 } 270 271 return nullptr; 272 } 273 274 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, 275 ScalarEvolution &SE, LPMUpdater &U) { 276 ConditionInfo SplitCandidateCond; 277 ConditionInfo ExitingCond; 278 279 // Check we can split this loop's bound. 280 if (!canSplitLoopBound(L, DT, SE, ExitingCond)) 281 return false; 282 283 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond)) 284 return false; 285 286 if (!isProfitableToTransform(L, SplitCandidateCond.BI)) 287 return false; 288 289 // Now, we have a split candidate. Let's build a form as below. 290 // +--------------------+ 291 // | preheader | 292 // | set up newbound | 293 // +--------------------+ 294 // | /----------------\ 295 // +--------v----v------+ | 296 // | header |---\ | 297 // | with true condition| | | 298 // +--------------------+ | | 299 // | | | 300 // +--------v-----------+ | | 301 // | if.then.BB | | | 302 // +--------------------+ | | 303 // | | | 304 // +--------v-----------<---/ | 305 // | latch >----------/ 306 // | with newbound | 307 // +--------------------+ 308 // | 309 // +--------v-----------+ 310 // | preheader2 |--------------\ 311 // | if (AddRec i != | | 312 // | org bound) | | 313 // +--------------------+ | 314 // | /----------------\ | 315 // +--------v----v------+ | | 316 // | header2 |---\ | | 317 // | conditional branch | | | | 318 // |with false condition| | | | 319 // +--------------------+ | | | 320 // | | | | 321 // +--------v-----------+ | | | 322 // | if.then.BB2 | | | | 323 // +--------------------+ | | | 324 // | | | | 325 // +--------v-----------<---/ | | 326 // | latch2 >----------/ | 327 // | with org bound | | 328 // +--------v-----------+ | 329 // | | 330 // | +---------------+ | 331 // +--> exit <-------/ 332 // +---------------+ 333 334 // Let's create post loop. 335 SmallVector<BasicBlock *, 8> PostLoopBlocks; 336 Loop *PostLoop; 337 ValueToValueMapTy VMap; 338 BasicBlock *PreHeader = L.getLoopPreheader(); 339 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI); 340 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap, 341 ".split", &LI, &DT, PostLoopBlocks); 342 remapInstructionsInBlocks(PostLoopBlocks, VMap); 343 344 // Add conditional branch to check we can skip post-loop in its preheader. 345 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); 346 IRBuilder<> Builder(PostLoopPreHeader); 347 Instruction *OrigBI = PostLoopPreHeader->getTerminator(); 348 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; 349 Value *Cond = 350 Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue); 351 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); 352 OrigBI->eraseFromParent(); 353 354 // Create new loop bound and add it into preheader of pre-loop. 355 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV; 356 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV; 357 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred) 358 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV) 359 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); 360 361 SCEVExpander Expander( 362 SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split"); 363 Instruction *InsertPt = SplitLoopPH->getTerminator(); 364 Value *NewBoundValue = 365 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); 366 NewBoundValue->setName("new.bound"); 367 368 // Replace exiting bound value of pre-loop NewBound. 369 ExitingCond.ICmp->setOperand(1, NewBoundValue); 370 371 // Replace IV's start value of post-loop by NewBound. 372 for (PHINode &PN : L.getHeader()->phis()) { 373 // Find PHI with exiting condition from pre-loop. 374 if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) { 375 for (Value *Op : PN.incoming_values()) { 376 if (Op == ExitingCond.AddRecValue) { 377 // Find cloned PHI for post-loop. 378 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); 379 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, 380 NewBoundValue); 381 } 382 } 383 } 384 } 385 386 // Replace SplitCandidateCond.BI's condition of pre-loop by True. 387 LLVMContext &Context = PreHeader->getContext(); 388 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); 389 390 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False. 391 BranchInst *ClonedSplitCandidateBI = 392 cast<BranchInst>(VMap[SplitCandidateCond.BI]); 393 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context)); 394 395 // Replace exit branch target of pre-loop by post-loop's preheader. 396 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0)) 397 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader); 398 else 399 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); 400 401 // Update dominator tree. 402 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); 403 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); 404 405 // Invalidate cached SE information. 406 SE.forgetLoop(&L); 407 408 // Canonicalize loops. 409 // TODO: Try to update LCSSA information according to above change. 410 formLCSSA(L, DT, &LI, &SE); 411 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); 412 formLCSSA(*PostLoop, DT, &LI, &SE); 413 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); 414 415 // Add new post-loop to loop pass manager. 416 U.addSiblingLoops(PostLoop); 417 418 return true; 419 } 420 421 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM, 422 LoopStandardAnalysisResults &AR, 423 LPMUpdater &U) { 424 Function &F = *L.getHeader()->getParent(); 425 (void)F; 426 427 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L 428 << "\n"); 429 430 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U)) 431 return PreservedAnalyses::all(); 432 433 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); 434 AR.LI.verify(AR.DT); 435 436 return getLoopPassPreservedAnalyses(); 437 } 438 439 } // end namespace llvm 440