1 #include "llvm/Transforms/Utils/LoopConstrainer.h" 2 #include "llvm/Analysis/LoopInfo.h" 3 #include "llvm/Analysis/ScalarEvolution.h" 4 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 5 #include "llvm/IR/Dominators.h" 6 #include "llvm/Transforms/Utils/Cloning.h" 7 #include "llvm/Transforms/Utils/LoopSimplify.h" 8 #include "llvm/Transforms/Utils/LoopUtils.h" 9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 10 11 using namespace llvm; 12 13 static const char *ClonedLoopTag = "loop_constrainer.loop.clone"; 14 15 #define DEBUG_TYPE "loop-constrainer" 16 17 /// Given a loop with an deccreasing induction variable, is it possible to 18 /// safely calculate the bounds of a new loop using the given Predicate. 19 static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV, 20 const SCEV *Step, ICmpInst::Predicate Pred, 21 unsigned LatchBrExitIdx, Loop *L, 22 ScalarEvolution &SE) { 23 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && 24 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) 25 return false; 26 27 if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) 28 return false; 29 30 assert(SE.isKnownNegative(Step) && "expecting negative step"); 31 32 LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n"); 33 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n"); 34 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n"); 35 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n"); 36 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n"); 37 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n"); 38 39 bool IsSigned = ICmpInst::isSigned(Pred); 40 // The predicate that we need to check that the induction variable lies 41 // within bounds. 42 ICmpInst::Predicate BoundPred = 43 IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; 44 45 if (LatchBrExitIdx == 1) 46 return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); 47 48 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1"); 49 50 const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); 51 unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); 52 APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) 53 : APInt::getMinValue(BitWidth); 54 const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); 55 56 const SCEV *MinusOne = 57 SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); 58 59 return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && 60 SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); 61 } 62 63 /// Given a loop with an increasing induction variable, is it possible to 64 /// safely calculate the bounds of a new loop using the given Predicate. 65 static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV, 66 const SCEV *Step, ICmpInst::Predicate Pred, 67 unsigned LatchBrExitIdx, Loop *L, 68 ScalarEvolution &SE) { 69 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && 70 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) 71 return false; 72 73 if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) 74 return false; 75 76 LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n"); 77 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n"); 78 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n"); 79 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n"); 80 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n"); 81 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n"); 82 83 bool IsSigned = ICmpInst::isSigned(Pred); 84 // The predicate that we need to check that the induction variable lies 85 // within bounds. 86 ICmpInst::Predicate BoundPred = 87 IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; 88 89 if (LatchBrExitIdx == 1) 90 return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); 91 92 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); 93 94 const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType())); 95 unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); 96 APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) 97 : APInt::getMaxValue(BitWidth); 98 const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); 99 100 return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, 101 SE.getAddExpr(BoundSCEV, Step)) && 102 SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); 103 } 104 105 /// Returns estimate for max latch taken count of the loop of the narrowest 106 /// available type. If the latch block has such estimate, it is returned. 107 /// Otherwise, we use max exit count of whole loop (that is potentially of wider 108 /// type than latch check itself), which is still better than no estimate. 109 static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE, 110 const Loop &L) { 111 const SCEV *FromBlock = 112 SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum); 113 if (isa<SCEVCouldNotCompute>(FromBlock)) 114 return SE.getSymbolicMaxBackedgeTakenCount(&L); 115 return FromBlock; 116 } 117 118 std::optional<LoopStructure> 119 LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, 120 bool AllowUnsignedLatchCond, 121 const char *&FailureReason) { 122 if (!L.isLoopSimplifyForm()) { 123 FailureReason = "loop not in LoopSimplify form"; 124 return std::nullopt; 125 } 126 127 BasicBlock *Latch = L.getLoopLatch(); 128 assert(Latch && "Simplified loops only have one latch!"); 129 130 if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { 131 FailureReason = "loop has already been cloned"; 132 return std::nullopt; 133 } 134 135 if (!L.isLoopExiting(Latch)) { 136 FailureReason = "no loop latch"; 137 return std::nullopt; 138 } 139 140 BasicBlock *Header = L.getHeader(); 141 BasicBlock *Preheader = L.getLoopPreheader(); 142 if (!Preheader) { 143 FailureReason = "no preheader"; 144 return std::nullopt; 145 } 146 147 BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); 148 if (!LatchBr || LatchBr->isUnconditional()) { 149 FailureReason = "latch terminator not conditional branch"; 150 return std::nullopt; 151 } 152 153 unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; 154 155 ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); 156 if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { 157 FailureReason = "latch terminator branch not conditional on integral icmp"; 158 return std::nullopt; 159 } 160 161 const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L); 162 if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) { 163 FailureReason = "could not compute latch count"; 164 return std::nullopt; 165 } 166 assert(SE.getLoopDisposition(MaxBETakenCount, &L) == 167 ScalarEvolution::LoopInvariant && 168 "loop variant exit count doesn't make sense!"); 169 170 ICmpInst::Predicate Pred = ICI->getPredicate(); 171 Value *LeftValue = ICI->getOperand(0); 172 const SCEV *LeftSCEV = SE.getSCEV(LeftValue); 173 IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); 174 175 Value *RightValue = ICI->getOperand(1); 176 const SCEV *RightSCEV = SE.getSCEV(RightValue); 177 178 // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. 179 if (!isa<SCEVAddRecExpr>(LeftSCEV)) { 180 if (isa<SCEVAddRecExpr>(RightSCEV)) { 181 std::swap(LeftSCEV, RightSCEV); 182 std::swap(LeftValue, RightValue); 183 Pred = ICmpInst::getSwappedPredicate(Pred); 184 } else { 185 FailureReason = "no add recurrences in the icmp"; 186 return std::nullopt; 187 } 188 } 189 190 auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { 191 if (AR->getNoWrapFlags(SCEV::FlagNSW)) 192 return true; 193 194 IntegerType *Ty = cast<IntegerType>(AR->getType()); 195 IntegerType *WideTy = 196 IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); 197 198 const SCEVAddRecExpr *ExtendAfterOp = 199 dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); 200 if (ExtendAfterOp) { 201 const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); 202 const SCEV *ExtendedStep = 203 SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); 204 205 bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && 206 ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; 207 208 if (NoSignedWrap) 209 return true; 210 } 211 212 // We may have proved this when computing the sign extension above. 213 return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; 214 }; 215 216 // `ICI` is interpreted as taking the backedge if the *next* value of the 217 // induction variable satisfies some constraint. 218 219 const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); 220 if (IndVarBase->getLoop() != &L) { 221 FailureReason = "LHS in cmp is not an AddRec for this loop"; 222 return std::nullopt; 223 } 224 if (!IndVarBase->isAffine()) { 225 FailureReason = "LHS in icmp not induction variable"; 226 return std::nullopt; 227 } 228 const SCEV *StepRec = IndVarBase->getStepRecurrence(SE); 229 if (!isa<SCEVConstant>(StepRec)) { 230 FailureReason = "LHS in icmp not induction variable"; 231 return std::nullopt; 232 } 233 ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); 234 235 if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) { 236 FailureReason = "LHS in icmp needs nsw for equality predicates"; 237 return std::nullopt; 238 } 239 240 assert(!StepCI->isZero() && "Zero step?"); 241 bool IsIncreasing = !StepCI->isNegative(); 242 bool IsSignedPredicate; 243 const SCEV *StartNext = IndVarBase->getStart(); 244 const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); 245 const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); 246 const SCEV *Step = SE.getSCEV(StepCI); 247 248 const SCEV *FixedRightSCEV = nullptr; 249 250 // If RightValue resides within loop (but still being loop invariant), 251 // regenerate it as preheader. 252 if (auto *I = dyn_cast<Instruction>(RightValue)) 253 if (L.contains(I->getParent())) 254 FixedRightSCEV = RightSCEV; 255 256 if (IsIncreasing) { 257 bool DecreasedRightValueByOne = false; 258 if (StepCI->isOne()) { 259 // Try to turn eq/ne predicates to those we can work with. 260 if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) 261 // while (++i != len) { while (++i < len) { 262 // ... ---> ... 263 // } } 264 // If both parts are known non-negative, it is profitable to use 265 // unsigned comparison in increasing loop. This allows us to make the 266 // comparison check against "RightSCEV + 1" more optimistic. 267 if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) && 268 isKnownNonNegativeInLoop(RightSCEV, &L, SE)) 269 Pred = ICmpInst::ICMP_ULT; 270 else 271 Pred = ICmpInst::ICMP_SLT; 272 else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { 273 // while (true) { while (true) { 274 // if (++i == len) ---> if (++i > len - 1) 275 // break; break; 276 // ... ... 277 // } } 278 if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && 279 cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) { 280 Pred = ICmpInst::ICMP_UGT; 281 RightSCEV = 282 SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); 283 DecreasedRightValueByOne = true; 284 } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) { 285 Pred = ICmpInst::ICMP_SGT; 286 RightSCEV = 287 SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); 288 DecreasedRightValueByOne = true; 289 } 290 } 291 } 292 293 bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); 294 bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); 295 bool FoundExpectedPred = 296 (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0); 297 298 if (!FoundExpectedPred) { 299 FailureReason = "expected icmp slt semantically, found something else"; 300 return std::nullopt; 301 } 302 303 IsSignedPredicate = ICmpInst::isSigned(Pred); 304 if (!IsSignedPredicate && !AllowUnsignedLatchCond) { 305 FailureReason = "unsigned latch conditions are explicitly prohibited"; 306 return std::nullopt; 307 } 308 309 if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, 310 LatchBrExitIdx, &L, SE)) { 311 FailureReason = "Unsafe loop bounds"; 312 return std::nullopt; 313 } 314 if (LatchBrExitIdx == 0) { 315 // We need to increase the right value unless we have already decreased 316 // it virtually when we replaced EQ with SGT. 317 if (!DecreasedRightValueByOne) 318 FixedRightSCEV = 319 SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); 320 } else { 321 assert(!DecreasedRightValueByOne && 322 "Right value can be decreased only for LatchBrExitIdx == 0!"); 323 } 324 } else { 325 bool IncreasedRightValueByOne = false; 326 if (StepCI->isMinusOne()) { 327 // Try to turn eq/ne predicates to those we can work with. 328 if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) 329 // while (--i != len) { while (--i > len) { 330 // ... ---> ... 331 // } } 332 // We intentionally don't turn the predicate into UGT even if we know 333 // that both operands are non-negative, because it will only pessimize 334 // our check against "RightSCEV - 1". 335 Pred = ICmpInst::ICMP_SGT; 336 else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { 337 // while (true) { while (true) { 338 // if (--i == len) ---> if (--i < len + 1) 339 // break; break; 340 // ... ... 341 // } } 342 if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && 343 cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { 344 Pred = ICmpInst::ICMP_ULT; 345 RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); 346 IncreasedRightValueByOne = true; 347 } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { 348 Pred = ICmpInst::ICMP_SLT; 349 RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); 350 IncreasedRightValueByOne = true; 351 } 352 } 353 } 354 355 bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); 356 bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); 357 358 bool FoundExpectedPred = 359 (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0); 360 361 if (!FoundExpectedPred) { 362 FailureReason = "expected icmp sgt semantically, found something else"; 363 return std::nullopt; 364 } 365 366 IsSignedPredicate = 367 Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; 368 369 if (!IsSignedPredicate && !AllowUnsignedLatchCond) { 370 FailureReason = "unsigned latch conditions are explicitly prohibited"; 371 return std::nullopt; 372 } 373 374 if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, 375 LatchBrExitIdx, &L, SE)) { 376 FailureReason = "Unsafe bounds"; 377 return std::nullopt; 378 } 379 380 if (LatchBrExitIdx == 0) { 381 // We need to decrease the right value unless we have already increased 382 // it virtually when we replaced EQ with SLT. 383 if (!IncreasedRightValueByOne) 384 FixedRightSCEV = 385 SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); 386 } else { 387 assert(!IncreasedRightValueByOne && 388 "Right value can be increased only for LatchBrExitIdx == 0!"); 389 } 390 } 391 BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); 392 393 assert(!L.contains(LatchExit) && "expected an exit block!"); 394 const DataLayout &DL = Preheader->getModule()->getDataLayout(); 395 SCEVExpander Expander(SE, DL, "loop-constrainer"); 396 Instruction *Ins = Preheader->getTerminator(); 397 398 if (FixedRightSCEV) 399 RightValue = 400 Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins); 401 402 Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins); 403 IndVarStartV->setName("indvar.start"); 404 405 LoopStructure Result; 406 407 Result.Tag = "main"; 408 Result.Header = Header; 409 Result.Latch = Latch; 410 Result.LatchBr = LatchBr; 411 Result.LatchExit = LatchExit; 412 Result.LatchBrExitIdx = LatchBrExitIdx; 413 Result.IndVarStart = IndVarStartV; 414 Result.IndVarStep = StepCI; 415 Result.IndVarBase = LeftValue; 416 Result.IndVarIncreasing = IsIncreasing; 417 Result.LoopExitAt = RightValue; 418 Result.IsSignedPredicate = IsSignedPredicate; 419 Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType()); 420 421 FailureReason = nullptr; 422 423 return Result; 424 } 425 426 // Add metadata to the loop L to disable loop optimizations. Callers need to 427 // confirm that optimizing loop L is not beneficial. 428 static void DisableAllLoopOptsOnLoop(Loop &L) { 429 // We do not care about any existing loopID related metadata for L, since we 430 // are setting all loop metadata to false. 431 LLVMContext &Context = L.getHeader()->getContext(); 432 // Reserve first location for self reference to the LoopID metadata node. 433 MDNode *Dummy = MDNode::get(Context, {}); 434 MDNode *DisableUnroll = MDNode::get( 435 Context, {MDString::get(Context, "llvm.loop.unroll.disable")}); 436 Metadata *FalseVal = 437 ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0)); 438 MDNode *DisableVectorize = MDNode::get( 439 Context, 440 {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal}); 441 MDNode *DisableLICMVersioning = MDNode::get( 442 Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")}); 443 MDNode *DisableDistribution = MDNode::get( 444 Context, 445 {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal}); 446 MDNode *NewLoopID = 447 MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize, 448 DisableLICMVersioning, DisableDistribution}); 449 // Set operand 0 to refer to the loop id itself. 450 NewLoopID->replaceOperandWith(0, NewLoopID); 451 L.setLoopID(NewLoopID); 452 } 453 454 LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI, 455 function_ref<void(Loop *, bool)> LPMAddNewLoop, 456 const LoopStructure &LS, ScalarEvolution &SE, 457 DominatorTree &DT, Type *T, SubRanges SR) 458 : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE), 459 DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T), 460 MainLoopStructure(LS), SR(SR) {} 461 462 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, 463 const char *Tag) const { 464 for (BasicBlock *BB : OriginalLoop.getBlocks()) { 465 BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); 466 Result.Blocks.push_back(Clone); 467 Result.Map[BB] = Clone; 468 } 469 470 auto GetClonedValue = [&Result](Value *V) { 471 assert(V && "null values not in domain!"); 472 auto It = Result.Map.find(V); 473 if (It == Result.Map.end()) 474 return V; 475 return static_cast<Value *>(It->second); 476 }; 477 478 auto *ClonedLatch = 479 cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch())); 480 ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag, 481 MDNode::get(Ctx, {})); 482 483 Result.Structure = MainLoopStructure.map(GetClonedValue); 484 Result.Structure.Tag = Tag; 485 486 for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) { 487 BasicBlock *ClonedBB = Result.Blocks[i]; 488 BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; 489 490 assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); 491 492 for (Instruction &I : *ClonedBB) 493 RemapInstruction(&I, Result.Map, 494 RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); 495 496 // Exit blocks will now have one more predecessor and their PHI nodes need 497 // to be edited to reflect that. No phi nodes need to be introduced because 498 // the loop is in LCSSA. 499 500 for (auto *SBB : successors(OriginalBB)) { 501 if (OriginalLoop.contains(SBB)) 502 continue; // not an exit block 503 504 for (PHINode &PN : SBB->phis()) { 505 Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); 506 PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); 507 SE.forgetValue(&PN); 508 } 509 } 510 } 511 } 512 513 LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( 514 const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, 515 BasicBlock *ContinuationBlock) const { 516 // We start with a loop with a single latch: 517 // 518 // +--------------------+ 519 // | | 520 // | preheader | 521 // | | 522 // +--------+-----------+ 523 // | ----------------\ 524 // | / | 525 // +--------v----v------+ | 526 // | | | 527 // | header | | 528 // | | | 529 // +--------------------+ | 530 // | 531 // ..... | 532 // | 533 // +--------------------+ | 534 // | | | 535 // | latch >----------/ 536 // | | 537 // +-------v------------+ 538 // | 539 // | 540 // | +--------------------+ 541 // | | | 542 // +---> original exit | 543 // | | 544 // +--------------------+ 545 // 546 // We change the control flow to look like 547 // 548 // 549 // +--------------------+ 550 // | | 551 // | preheader >-------------------------+ 552 // | | | 553 // +--------v-----------+ | 554 // | /-------------+ | 555 // | / | | 556 // +--------v--v--------+ | | 557 // | | | | 558 // | header | | +--------+ | 559 // | | | | | | 560 // +--------------------+ | | +-----v-----v-----------+ 561 // | | | | 562 // | | | .pseudo.exit | 563 // | | | | 564 // | | +-----------v-----------+ 565 // | | | 566 // ..... | | | 567 // | | +--------v-------------+ 568 // +--------------------+ | | | | 569 // | | | | | ContinuationBlock | 570 // | latch >------+ | | | 571 // | | | +----------------------+ 572 // +---------v----------+ | 573 // | | 574 // | | 575 // | +---------------^-----+ 576 // | | | 577 // +-----> .exit.selector | 578 // | | 579 // +----------v----------+ 580 // | 581 // +--------------------+ | 582 // | | | 583 // | original exit <----+ 584 // | | 585 // +--------------------+ 586 587 RewrittenRangeInfo RRI; 588 589 BasicBlock *BBInsertLocation = LS.Latch->getNextNode(); 590 RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", 591 &F, BBInsertLocation); 592 RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, 593 BBInsertLocation); 594 595 BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); 596 bool Increasing = LS.IndVarIncreasing; 597 bool IsSignedPredicate = LS.IsSignedPredicate; 598 599 IRBuilder<> B(PreheaderJump); 600 auto NoopOrExt = [&](Value *V) { 601 if (V->getType() == RangeTy) 602 return V; 603 return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) 604 : B.CreateZExt(V, RangeTy, "wide." + V->getName()); 605 }; 606 607 // EnterLoopCond - is it okay to start executing this `LS'? 608 Value *EnterLoopCond = nullptr; 609 auto Pred = 610 Increasing 611 ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) 612 : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); 613 Value *IndVarStart = NoopOrExt(LS.IndVarStart); 614 EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); 615 616 B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); 617 PreheaderJump->eraseFromParent(); 618 619 LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); 620 B.SetInsertPoint(LS.LatchBr); 621 Value *IndVarBase = NoopOrExt(LS.IndVarBase); 622 Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); 623 624 Value *CondForBranch = LS.LatchBrExitIdx == 1 625 ? TakeBackedgeLoopCond 626 : B.CreateNot(TakeBackedgeLoopCond); 627 628 LS.LatchBr->setCondition(CondForBranch); 629 630 B.SetInsertPoint(RRI.ExitSelector); 631 632 // IterationsLeft - are there any more iterations left, given the original 633 // upper bound on the induction variable? If not, we branch to the "real" 634 // exit. 635 Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); 636 Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); 637 B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); 638 639 BranchInst *BranchToContinuation = 640 BranchInst::Create(ContinuationBlock, RRI.PseudoExit); 641 642 // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of 643 // each of the PHI nodes in the loop header. This feeds into the initial 644 // value of the same PHI nodes if/when we continue execution. 645 for (PHINode &PN : LS.Header->phis()) { 646 PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", 647 BranchToContinuation); 648 649 NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); 650 NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), 651 RRI.ExitSelector); 652 RRI.PHIValuesAtPseudoExit.push_back(NewPHI); 653 } 654 655 RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", 656 BranchToContinuation); 657 RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); 658 RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); 659 660 // The latch exit now has a branch from `RRI.ExitSelector' instead of 661 // `LS.Latch'. The PHI nodes need to be updated to reflect that. 662 LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector); 663 664 return RRI; 665 } 666 667 void LoopConstrainer::rewriteIncomingValuesForPHIs( 668 LoopStructure &LS, BasicBlock *ContinuationBlock, 669 const LoopConstrainer::RewrittenRangeInfo &RRI) const { 670 unsigned PHIIndex = 0; 671 for (PHINode &PN : LS.Header->phis()) 672 PN.setIncomingValueForBlock(ContinuationBlock, 673 RRI.PHIValuesAtPseudoExit[PHIIndex++]); 674 675 LS.IndVarStart = RRI.IndVarEnd; 676 } 677 678 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, 679 BasicBlock *OldPreheader, 680 const char *Tag) const { 681 BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); 682 BranchInst::Create(LS.Header, Preheader); 683 684 LS.Header->replacePhiUsesWith(OldPreheader, Preheader); 685 686 return Preheader; 687 } 688 689 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { 690 Loop *ParentLoop = OriginalLoop.getParentLoop(); 691 if (!ParentLoop) 692 return; 693 694 for (BasicBlock *BB : BBs) 695 ParentLoop->addBasicBlockToLoop(BB, LI); 696 } 697 698 Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, 699 ValueToValueMapTy &VM, 700 bool IsSubloop) { 701 Loop &New = *LI.AllocateLoop(); 702 if (Parent) 703 Parent->addChildLoop(&New); 704 else 705 LI.addTopLevelLoop(&New); 706 LPMAddNewLoop(&New, IsSubloop); 707 708 // Add all of the blocks in Original to the new loop. 709 for (auto *BB : Original->blocks()) 710 if (LI.getLoopFor(BB) == Original) 711 New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI); 712 713 // Add all of the subloops to the new loop. 714 for (Loop *SubLoop : *Original) 715 createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true); 716 717 return &New; 718 } 719 720 bool LoopConstrainer::run() { 721 BasicBlock *Preheader = OriginalLoop.getLoopPreheader(); 722 assert(Preheader != nullptr && "precondition!"); 723 724 OriginalPreheader = Preheader; 725 MainLoopPreheader = Preheader; 726 bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; 727 bool Increasing = MainLoopStructure.IndVarIncreasing; 728 IntegerType *IVTy = cast<IntegerType>(RangeTy); 729 730 SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer"); 731 Instruction *InsertPt = OriginalPreheader->getTerminator(); 732 733 // It would have been better to make `PreLoop' and `PostLoop' 734 // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy 735 // constructor. 736 ClonedLoop PreLoop, PostLoop; 737 bool NeedsPreLoop = 738 Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value(); 739 bool NeedsPostLoop = 740 Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value(); 741 742 Value *ExitPreLoopAt = nullptr; 743 Value *ExitMainLoopAt = nullptr; 744 const SCEVConstant *MinusOneS = 745 cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); 746 747 if (NeedsPreLoop) { 748 const SCEV *ExitPreLoopAtSCEV = nullptr; 749 750 if (Increasing) 751 ExitPreLoopAtSCEV = *SR.LowLimit; 752 else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, 753 IsSignedPredicate)) 754 ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); 755 else { 756 LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing " 757 << "preloop exit limit. HighLimit = " 758 << *(*SR.HighLimit) << "\n"); 759 return false; 760 } 761 762 if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) { 763 LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the" 764 << " preloop exit limit " << *ExitPreLoopAtSCEV 765 << " at block " << InsertPt->getParent()->getName() 766 << "\n"); 767 return false; 768 } 769 770 ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); 771 ExitPreLoopAt->setName("exit.preloop.at"); 772 } 773 774 if (NeedsPostLoop) { 775 const SCEV *ExitMainLoopAtSCEV = nullptr; 776 777 if (Increasing) 778 ExitMainLoopAtSCEV = *SR.HighLimit; 779 else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, 780 IsSignedPredicate)) 781 ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); 782 else { 783 LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing " 784 << "mainloop exit limit. LowLimit = " 785 << *(*SR.LowLimit) << "\n"); 786 return false; 787 } 788 789 if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) { 790 LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the" 791 << " main loop exit limit " << *ExitMainLoopAtSCEV 792 << " at block " << InsertPt->getParent()->getName() 793 << "\n"); 794 return false; 795 } 796 797 ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); 798 ExitMainLoopAt->setName("exit.mainloop.at"); 799 } 800 801 // We clone these ahead of time so that we don't have to deal with changing 802 // and temporarily invalid IR as we transform the loops. 803 if (NeedsPreLoop) 804 cloneLoop(PreLoop, "preloop"); 805 if (NeedsPostLoop) 806 cloneLoop(PostLoop, "postloop"); 807 808 RewrittenRangeInfo PreLoopRRI; 809 810 if (NeedsPreLoop) { 811 Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, 812 PreLoop.Structure.Header); 813 814 MainLoopPreheader = 815 createPreheader(MainLoopStructure, Preheader, "mainloop"); 816 PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, 817 ExitPreLoopAt, MainLoopPreheader); 818 rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, 819 PreLoopRRI); 820 } 821 822 BasicBlock *PostLoopPreheader = nullptr; 823 RewrittenRangeInfo PostLoopRRI; 824 825 if (NeedsPostLoop) { 826 PostLoopPreheader = 827 createPreheader(PostLoop.Structure, Preheader, "postloop"); 828 PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, 829 ExitMainLoopAt, PostLoopPreheader); 830 rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, 831 PostLoopRRI); 832 } 833 834 BasicBlock *NewMainLoopPreheader = 835 MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr; 836 BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, 837 PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, 838 PostLoopRRI.ExitSelector, NewMainLoopPreheader}; 839 840 // Some of the above may be nullptr, filter them out before passing to 841 // addToParentLoopIfNeeded. 842 auto NewBlocksEnd = 843 std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); 844 845 addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd)); 846 847 DT.recalculate(F); 848 849 // We need to first add all the pre and post loop blocks into the loop 850 // structures (as part of createClonedLoopStructure), and then update the 851 // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating 852 // LI when LoopSimplifyForm is generated. 853 Loop *PreL = nullptr, *PostL = nullptr; 854 if (!PreLoop.Blocks.empty()) { 855 PreL = createClonedLoopStructure(&OriginalLoop, 856 OriginalLoop.getParentLoop(), PreLoop.Map, 857 /* IsSubLoop */ false); 858 } 859 860 if (!PostLoop.Blocks.empty()) { 861 PostL = 862 createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(), 863 PostLoop.Map, /* IsSubLoop */ false); 864 } 865 866 // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. 867 auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) { 868 formLCSSARecursively(*L, DT, &LI, &SE); 869 simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true); 870 // Pre/post loops are slow paths, we do not need to perform any loop 871 // optimizations on them. 872 if (!IsOriginalLoop) 873 DisableAllLoopOptsOnLoop(*L); 874 }; 875 if (PreL) 876 CanonicalizeLoop(PreL, false); 877 if (PostL) 878 CanonicalizeLoop(PostL, false); 879 CanonicalizeLoop(&OriginalLoop, true); 880 881 /// At this point: 882 /// - We've broken a "main loop" out of the loop in a way that the "main loop" 883 /// runs with the induction variable in a subset of [Begin, End). 884 /// - There is no overflow when computing "main loop" exit limit. 885 /// - Max latch taken count of the loop is limited. 886 /// It guarantees that induction variable will not overflow iterating in the 887 /// "main loop". 888 if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase)) 889 if (IsSignedPredicate) 890 cast<BinaryOperator>(MainLoopStructure.IndVarBase) 891 ->setHasNoSignedWrap(true); 892 /// TODO: support unsigned predicate. 893 /// To add NUW flag we need to prove that both operands of BO are 894 /// non-negative. E.g: 895 /// ... 896 /// %iv.next = add nsw i32 %iv, -1 897 /// %cmp = icmp ult i32 %iv.next, %n 898 /// br i1 %cmp, label %loopexit, label %loop 899 /// 900 /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will 901 /// overflow, therefore NUW flag is not legal here. 902 903 return true; 904 } 905