1 #include "llvm/Transforms/Utils/LoopConstrainer.h" 2 #include "llvm/Analysis/LoopInfo.h" 3 #include "llvm/Analysis/ScalarEvolution.h" 4 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 5 #include "llvm/IR/Dominators.h" 6 #include "llvm/Transforms/Utils/Cloning.h" 7 #include "llvm/Transforms/Utils/LoopSimplify.h" 8 #include "llvm/Transforms/Utils/LoopUtils.h" 9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 10 11 using namespace llvm; 12 13 static const char *ClonedLoopTag = "loop_constrainer.loop.clone"; 14 15 #define DEBUG_TYPE "loop-constrainer" 16 17 /// Given a loop with an deccreasing induction variable, is it possible to 18 /// safely calculate the bounds of a new loop using the given Predicate. 19 static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV, 20 const SCEV *Step, ICmpInst::Predicate Pred, 21 unsigned LatchBrExitIdx, Loop *L, 22 ScalarEvolution &SE) { 23 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && 24 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) 25 return false; 26 27 if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) 28 return false; 29 30 assert(SE.isKnownNegative(Step) && "expecting negative step"); 31 32 LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n"); 33 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n"); 34 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n"); 35 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n"); 36 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n"); 37 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n"); 38 39 bool IsSigned = ICmpInst::isSigned(Pred); 40 // The predicate that we need to check that the induction variable lies 41 // within bounds. 42 ICmpInst::Predicate BoundPred = 43 IsSigned ? 
CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; 44 45 auto StartLG = SE.applyLoopGuards(Start, L); 46 auto BoundLG = SE.applyLoopGuards(BoundSCEV, L); 47 48 if (LatchBrExitIdx == 1) 49 return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG); 50 51 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1"); 52 53 const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); 54 unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); 55 APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) 56 : APInt::getMinValue(BitWidth); 57 const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); 58 59 const SCEV *MinusOne = 60 SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType())); 61 62 return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) && 63 SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit); 64 } 65 66 /// Given a loop with an increasing induction variable, is it possible to 67 /// safely calculate the bounds of a new loop using the given Predicate. 68 static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV, 69 const SCEV *Step, ICmpInst::Predicate Pred, 70 unsigned LatchBrExitIdx, Loop *L, 71 ScalarEvolution &SE) { 72 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && 73 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) 74 return false; 75 76 if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) 77 return false; 78 79 LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n"); 80 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n"); 81 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n"); 82 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n"); 83 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n"); 84 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n"); 85 86 bool IsSigned = ICmpInst::isSigned(Pred); 87 // The predicate that we need to check that the induction variable lies 88 // within bounds. 89 ICmpInst::Predicate BoundPred = 90 IsSigned ? 
CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; 91 92 auto StartLG = SE.applyLoopGuards(Start, L); 93 auto BoundLG = SE.applyLoopGuards(BoundSCEV, L); 94 95 if (LatchBrExitIdx == 1) 96 return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG); 97 98 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); 99 100 const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType())); 101 unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); 102 APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) 103 : APInt::getMaxValue(BitWidth); 104 const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); 105 106 return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, 107 SE.getAddExpr(BoundLG, Step)) && 108 SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit)); 109 } 110 111 /// Returns estimate for max latch taken count of the loop of the narrowest 112 /// available type. If the latch block has such estimate, it is returned. 113 /// Otherwise, we use max exit count of whole loop (that is potentially of wider 114 /// type than latch check itself), which is still better than no estimate. 
static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
                                                          const Loop &L) {
  // First ask for the symbolic maximum exit count of the latch block itself.
  const SCEV *FromBlock =
      SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
  // If that is unknown, fall back to the loop-wide symbolic maximum
  // backedge-taken count (which may itself be SCEVCouldNotCompute).
  if (isa<SCEVCouldNotCompute>(FromBlock))
    return SE.getSymbolicMaxBackedgeTakenCount(&L);
  return FromBlock;
}

/// Try to recognize \p L as a loop that the constrainer can handle and
/// describe it as a LoopStructure.
///
/// The loop must be in LoopSimplify form, must not carry the cloned-loop tag
/// on its latch terminator, and its latch must exit through a conditional
/// branch on an integer icmp in which one operand is an affine add recurrence
/// of \p L with a constant step.
///
/// \param SE ScalarEvolution used for all queries and for code expansion.
/// \param L The loop to analyze.
/// \param AllowUnsignedLatchCond If false, latch conditions that end up with
///        an unsigned predicate are rejected.
/// \param FailureReason Set to a static description when the function fails,
///        and to nullptr when it succeeds.
/// \returns The populated LoopStructure, or std::nullopt on failure. On the
///          success path code is expanded before the preheader terminator
///          (the induction variable start and, when needed, the adjusted
///          right-hand side of the comparison).
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  bool AllowUnsignedLatchCond,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // cloneLoop() attaches ClonedLoopTag to the latch terminator of every clone,
  // so loops produced by a previous run are skipped here.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  // The latch exists at this point; this rejects loops whose latch is not an
  // exiting block.
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Index of the latch branch successor that leaves the loop: 1 when
  // successor 0 is the header (the backedge), 0 otherwise.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if the add recurrence is known not to wrap in the signed
  // sense: either it already carries NSW, or sign-extending the recurrence to
  // twice its width yields a recurrence whose start and step equal the
  // sign-extended start and step of the original.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  // Only constant steps are supported.
  const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  // Note: this checks the original instruction's predicate (`ICI`), not the
  // possibly swapped `Pred`; swapping does not change eq/ne.
  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // The compared recurrence holds the *next* value of the induction variable,
  // so the value on loop entry is its start minus one step.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, this SCEV is expanded in the preheader below and replaces
  // RightValue as the loop exit bound.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within loop (but still being loop invariant),
  // regenerate it in the preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // For an increasing loop the backedge must be taken on "less than": an LT
    // predicate when the exit is successor 1, or a GT predicate when the exit
    // is successor 0.
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // Mirror image of the increasing case: the backedge must be taken on
    // "greater than".
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // All checks passed; from here on code is emitted before the preheader
  // terminator.
  const DataLayout &DL = Preheader->getDataLayout();
  SCEVExpander Expander(SE, DL, "loop-constrainer");
  Instruction *Ins = Preheader->getTerminator();

  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;
  Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());

  FailureReason = nullptr;

  return Result;
}

// Add metadata to the loop L to disable loop optimizations. Callers need to
// confirm that optimizing loop L is not beneficial.
static void DisableAllLoopOptsOnLoop(Loop &L) {
  // We do not care about any existing loopID related metadata for L, since we
  // are setting all loop metadata to false.
  LLVMContext &Context = L.getHeader()->getContext();
  // Reserve first location for self reference to the LoopID metadata node.
  MDNode *Dummy = MDNode::get(Context, {});
  MDNode *DisableUnroll = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
  // i1 false, shared by the "*.enable" entries below.
  Metadata *FalseVal =
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
  MDNode *DisableVectorize = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
  MDNode *DisableLICMVersioning = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
  MDNode *DisableDistribution = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
  MDNode *NewLoopID =
      MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
                            DisableLICMVersioning, DisableDistribution});
  // Set operand 0 to refer to the loop id itself.
  NewLoopID->replaceOperandWith(0, NewLoopID);
  L.setLoopID(NewLoopID);
}

/// Capture everything needed to constrain loop \p L. The enclosing function
/// and the LLVMContext are taken from the loop header; the remaining
/// arguments are stored as given. No IR is modified here.
LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
                                 function_ref<void(Loop *, bool)> LPMAddNewLoop,
                                 const LoopStructure &LS, ScalarEvolution &SE,
                                 DominatorTree &DT, Type *T, SubRanges SR)
    : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
      DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
      MainLoopStructure(LS), SR(SR) {}

/// Clone every block of the original loop into \p Result.
///
/// Each clone is created in function F with the name suffix "." + \p Tag and
/// recorded in Result.Blocks and Result.Map. The cloned latch terminator is
/// tagged with ClonedLoopTag metadata, Result.Structure becomes the main loop
/// structure mapped onto the clones, and PHI nodes in the exit blocks of the
/// original loop receive an extra incoming value for each cloned exiting
/// block.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
                                const char *Tag) const {
  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    Result.Blocks.push_back(Clone);
    Result.Map[BB] = Clone;
  }

  // Map a value to its clone; values without an entry in the map (i.e. not
  // cloned) map to themselves.
  auto GetClonedValue = [&Result](Value *V) {
    assert(V && "null values not in domain!");
    auto It = Result.Map.find(V);
    if (It == Result.Map.end())
      return V;
    return static_cast<Value *>(It->second);
  };

  // Mark the clone so that parseLoopStructure() refuses to process it again.
  auto *ClonedLatch =
      cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
  ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
                                            MDNode::get(Ctx, {}));

  Result.Structure = MainLoopStructure.map(GetClonedValue);
  Result.Structure.Tag = Tag;

  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    BasicBlock *ClonedBB = Result.Blocks[i];
    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];

    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");

    // Redirect operands of the cloned instructions to the cloned values.
    for (Instruction &I : *ClonedBB)
      RemapInstruction(&I, Result.Map,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);

    // Exit blocks will now have one more predecessor and their PHI nodes need
    // to be edited to reflect that.  No phi nodes need to be introduced because
    // the loop is in LCSSA.

    for (auto *SBB : successors(OriginalBB)) {
      if (OriginalLoop.contains(SBB))
        continue; // not an exit block

      for (PHINode &PN : SBB->phis()) {
        Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
        PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
        // The PHI changed, so drop whatever SCEV has cached for it.
        SE.forgetValue(&PN);
      }
    }
  }
}

/// Rewrite the loop described by \p LS so that it stops iterating once its
/// induction variable reaches \p ExitSubloopAt, and hand control over to
/// \p ContinuationBlock when iterations of the original range remain.
///
/// Two blocks are created, "<tag>.exit.selector" and "<tag>.pseudo.exit". The
/// terminator of \p Preheader is replaced by a conditional branch that either
/// enters the loop or goes to the pseudo exit, and the latch branch condition
/// is replaced by a comparison against \p ExitSubloopAt.
///
/// \returns The new blocks, the PHI nodes created in the pseudo exit for each
///          header PHI, and the PHI holding the final induction variable
///          value.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  // Both new blocks are inserted right after the latch in the function's
  // block list.
  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  // Returns V unchanged if it already has type RangeTy; otherwise emits a
  // sign- or zero-extension (matching the predicate's signedness) at the
  // builder's current insertion point.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  // Replace the preheader's terminator with the guarded entry into the loop.
  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // The latch now exits to the exit selector, and keeps looping only while the
  // induction variable has not reached ExitSubloopAt.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  // If the exit is successor 0, a true condition must mean "exit", so the
  // take-backedge condition is negated.
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation->getIterator());

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  // Value of the induction variable on arrival at the pseudo exit: its start
  // value if the loop was skipped, its last value otherwise.
  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation->getIterator());
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}

/// Make the header PHIs of \p LS take, on the edge from \p ContinuationBlock,
/// the values computed at the pseudo exit of the preceding loop (\p RRI), and
/// make that loop's final induction variable value the start value of \p LS.
/// The PHIs are paired by position: the i-th header PHI receives
/// RRI.PHIValuesAtPseudoExit[i].
void LoopConstrainer::rewriteIncomingValuesForPHIs(
    LoopStructure &LS, BasicBlock *ContinuationBlock,
    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
  unsigned PHIIndex = 0;
  for (PHINode &PN : LS.Header->phis())
    PN.setIncomingValueForBlock(ContinuationBlock,
                                RRI.PHIValuesAtPseudoExit[PHIIndex++]);

  LS.IndVarStart = RRI.IndVarEnd;
}

/// Create a block named \p Tag placed immediately before the header of \p LS
/// that branches unconditionally to that header, and retarget the header PHI
/// entries that referred to \p OldPreheader to the new block. The terminator
/// of \p OldPreheader is not touched here.
BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
                                             BasicBlock *OldPreheader,
                                             const char *Tag) const {
  BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
  BranchInst::Create(LS.Header, Preheader);

  LS.Header->replacePhiUsesWith(OldPreheader, Preheader);

  return Preheader;
}

/// Register \p BBs with the parent loop of the original loop in LoopInfo.
/// Does nothing when the original loop has no parent loop.
void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
  Loop *ParentLoop = OriginalLoop.getParentLoop();
  if (!ParentLoop)
    return;

  for (BasicBlock *BB : BBs)
    ParentLoop->addBasicBlockToLoop(BB, LI);
}

/// Build the LoopInfo structure for a clone of \p Original.
///
/// A new Loop is allocated and attached under \p Parent (or as a top-level
/// loop when \p Parent is null), reported through LPMAddNewLoop together with
/// \p IsSubloop, and populated with the clones (looked up in \p VM) of the
/// blocks whose innermost loop is \p Original. Subloops are handled
/// recursively.
///
/// \returns The newly created loop.
Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
                                                 ValueToValueMapTy &VM,
                                                 bool IsSubloop) {
  Loop &New = *LI.AllocateLoop();
  if (Parent)
    Parent->addChildLoop(&New);
  else
    LI.addTopLevelLoop(&New);
  LPMAddNewLoop(&New, IsSubloop);

  // Add all of the blocks in Original to the new loop.
  for (auto *BB : Original->blocks())
    if (LI.getLoopFor(BB) == Original)
      New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);

  // Add all of the subloops to the new loop.
  for (Loop *SubLoop : *Original)
    createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);

  return &New;
}

/// Split the original loop into an optional pre-loop, the main loop and an
/// optional post-loop according to the sub-ranges in SR.
///
/// \returns false if an exit limit cannot be proven free of overflow or
///          cannot be safely expanded before the preheader terminator; true
///          once the transformation has been carried out.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
  assert(Preheader != nullptr && "precondition!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;
  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy = cast<IntegerType>(RangeTy);

  SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  // For an increasing induction variable the pre-loop covers the part below
  // the low limit and the post-loop the part above the high limit; for a
  // decreasing one the roles of the two limits are swapped.
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    // In the decreasing case the limit is "HighLimit - 1", which is only used
    // if HighLimit is known not to be the minimum value.
    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "preloop exit limit.  HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  // NOTE(review): if one of the checks below fails after the pre-loop limit
  // was expanded above, the function returns false while the expanded code
  // may already have been emitted into the preheader.
  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "mainloop exit limit.  LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  if (NeedsPreLoop) {
    // The original preheader now leads into the pre-loop; the main loop gets
    // a fresh preheader that serves as the pre-loop's continuation block.
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  if (NeedsPostLoop) {
    // The post-loop's new preheader is the continuation block of the main
    // loop.
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  // The CFG was rewritten above without updating the dominator tree, so
  // rebuild it for the whole function.
  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  /// At this point:
  /// - We've broken a "main loop" out of the loop in a way that the "main loop"
  /// runs with the induction variable in a subset of [Begin, End).
  /// - There is no overflow when computing "main loop" exit limit.
  /// - Max latch taken count of the loop is limited.
  /// It guarantees that induction variable will not overflow iterating in the
  /// "main loop".
  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
    if (IsSignedPredicate)
      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
          ->setHasNoSignedWrap(true);
  /// TODO: support unsigned predicate.
  /// To add NUW flag we need to prove that both operands of BO are
  /// non-negative. E.g:
  /// ...
  /// %iv.next = add nsw i32 %iv, -1
  /// %cmp = icmp ult i32 %iv.next, %n
  /// br i1 %cmp, label %loopexit, label %loop
  ///
  /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
  /// overflow, therefore NUW flag is not legal here.

  return true;
}