#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

using namespace llvm;

/// Metadata kind attached to the latch terminator of every cloned loop so that
/// parseLoopStructure can refuse to process a clone a second time.
static const char *ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"

/// Given a loop with a decreasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
///
/// Returns false when \p Pred is not one of SLT/SGT/ULT/UGT or when
/// \p BoundSCEV is not available at the entry of \p L. Otherwise the answer is
/// whatever ScalarEvolution can prove from the loop entry guards:
///  - LatchBrExitIdx == 1: Start > BoundSCEV;
///  - LatchBrExitIdx == 0: Start > BoundSCEV - 1 and
///                         BoundSCEV > Min - (Step + 1),
/// where '>' and Min are signed or unsigned according to \p Pred.
static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
                                  const SCEV *Step, ICmpInst::Predicate Pred,
                                  unsigned LatchBrExitIdx, Loop *L,
                                  ScalarEvolution &SE) {
  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
    return false;

  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
    return false;

  assert(SE.isKnownNegative(Step) && "expecting negative step");

  LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
  LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
  LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
  LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
  LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");

  bool IsSigned = ICmpInst::isSigned(Pred);
  // The predicate that we need to check that the induction variable lies
  // within bounds.
  ICmpInst::Predicate BoundPred =
      IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;

  if (LatchBrExitIdx == 1)
    return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);

  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");

  const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
  APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
                       : APInt::getMinValue(BitWidth);
  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);

  const SCEV *MinusOne =
      SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));

  return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
         SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
}

/// Given a loop with an increasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
///
/// Returns false when \p Pred is not one of SLT/SGT/ULT/UGT or when
/// \p BoundSCEV is not available at the entry of \p L. Otherwise the answer is
/// whatever ScalarEvolution can prove from the loop entry guards:
///  - LatchBrExitIdx == 1: Start < BoundSCEV;
///  - LatchBrExitIdx == 0: Start < BoundSCEV + Step and
///                         BoundSCEV < Max - (Step - 1),
/// where '<' and Max are signed or unsigned according to \p Pred.
static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
                                  const SCEV *Step, ICmpInst::Predicate Pred,
                                  unsigned LatchBrExitIdx, Loop *L,
                                  ScalarEvolution &SE) {
  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
    return false;

  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
    return false;

  LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
  LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
  LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
  LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
  LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");

  bool IsSigned = ICmpInst::isSigned(Pred);
  // The predicate that we need to check that the induction variable lies
  // within bounds.
  ICmpInst::Predicate BoundPred =
      IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;

  if (LatchBrExitIdx == 1)
    return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);

  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");

  const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
  APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                       : APInt::getMaxValue(BitWidth);
  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);

  return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
                                      SE.getAddExpr(BoundSCEV, Step)) &&
          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
}

/// Returns estimate for max latch taken count of the loop of the narrowest
/// available type. If the latch block has such estimate, it is returned.
/// Otherwise, we use max exit count of whole loop (that is potentially of wider
/// type than latch check itself), which is still better than no estimate.
static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
                                                          const Loop &L) {
  const SCEV *FromBlock =
      SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
  if (isa<SCEVCouldNotCompute>(FromBlock))
    return SE.getSymbolicMaxBackedgeTakenCount(&L);
  return FromBlock;
}

/// Try to recognise \p L as a loop controlled by an affine induction variable
/// that is compared against a bound in the latch.
///
/// On success returns a populated LoopStructure and sets \p FailureReason to
/// nullptr. On failure returns std::nullopt and points \p FailureReason at a
/// static string naming the first check that failed.
///
/// Note that on the success path this emits code into the preheader: the
/// start value of the induction variable is always expanded there, and so is
/// the (possibly adjusted) right-hand side of the latch comparison when
/// FixedRightSCEV is set.
///
/// \param AllowUnsignedLatchCond when false, loops whose (canonicalised) latch
///        predicate is unsigned are rejected.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  bool AllowUnsignedLatchCond,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Index of the latch branch successor that leaves the loop: if successor 0
  // is the header (the backedge), the exit is successor 1, and vice versa.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if the add recurrence is known not to wrap in the signed
  // sense: either it already carries NSW, or sign-extending it to twice the
  // width yields an add recurrence whose start and step are exactly the
  // sign-extended start and step of the original.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // The compared value is the *next* induction value, so the start of the
  // induction variable proper is one step before the add recurrence's start.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within loop (but still being loop invariant),
  // regenerate it as preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  const DataLayout &DL = Preheader->getModule()->getDataLayout();
  SCEVExpander Expander(SE, DL, "loop-constrainer");
  Instruction *Ins = Preheader->getTerminator();

  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;
  Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());

  FailureReason = nullptr;

  return Result;
}

// Add metadata to the loop L to disable loop optimizations. Callers need to
// confirm that optimizing loop L is not beneficial.
static void DisableAllLoopOptsOnLoop(Loop &L) {
  // We do not care about any existing loopID related metadata for L, since we
  // are setting all loop metadata to false.
  LLVMContext &Context = L.getHeader()->getContext();
  // Reserve first location for self reference to the LoopID metadata node.
  MDNode *Dummy = MDNode::get(Context, {});
  MDNode *DisableUnroll = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
  Metadata *FalseVal =
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
  MDNode *DisableVectorize = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
  MDNode *DisableLICMVersioning = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
  MDNode *DisableDistribution = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
  MDNode *NewLoopID =
      MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
                            DisableLICMVersioning, DisableDistribution});
  // Set operand 0 to refer to the loop id itself.
  NewLoopID->replaceOperandWith(0, NewLoopID);
  L.setLoopID(NewLoopID);
}

/// Stores the analyses, the loop to constrain, its parsed structure, the range
/// type and the requested sub-ranges. No IR is touched until run() is called.
// NOTE(review): the callback signature below is inferred from the call
// `LPMAddNewLoop(&New, IsSubloop)` in createClonedLoopStructure - confirm it
// against the declaration in LoopConstrainer.h.
LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
                                 function_ref<void(Loop *, bool)> LPMAddNewLoop,
                                 const LoopStructure &LS, ScalarEvolution &SE,
                                 DominatorTree &DT, Type *T, SubRanges SR)
    : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
      DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
      MainLoopStructure(LS), SR(SR) {}

/// Clone every block of the original loop into the function, recording the
/// clones in \p Result.Blocks and the original-to-clone mapping in
/// \p Result.Map. The cloned latch terminator is tagged with ClonedLoopTag,
/// \p Result.Structure becomes the main loop structure mapped onto the clones,
/// and PHI nodes in exit blocks receive an extra incoming value for each new
/// predecessor.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
                                const char *Tag) const {
  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    Result.Blocks.push_back(Clone);
    Result.Map[BB] = Clone;
  }

  // Values without an entry in the map (i.e. not cloned) map to themselves.
  auto GetClonedValue = [&Result](Value *V) {
    assert(V && "null values not in domain!");
    auto It = Result.Map.find(V);
    if (It == Result.Map.end())
      return V;
    return static_cast<Value *>(It->second);
  };

  auto *ClonedLatch =
      cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
  ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
                                            MDNode::get(Ctx, {}));

  Result.Structure = MainLoopStructure.map(GetClonedValue);
  Result.Structure.Tag = Tag;

  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    BasicBlock *ClonedBB = Result.Blocks[i];
    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];

    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");

    for (Instruction &I : *ClonedBB)
      RemapInstruction(&I, Result.Map,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);

    // Exit blocks will now have one more predecessor and their PHI nodes need
    // to be edited to reflect that.  No phi nodes need to be introduced because
    // the loop is in LCSSA.

    for (auto *SBB : successors(OriginalBB)) {
      if (OriginalLoop.contains(SBB))
        continue; // not an exit block

      for (PHINode &PN : SBB->phis()) {
        Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
        PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
        SE.forgetValue(&PN);
      }
    }
  }
}

/// Rewrite the loop described by \p LS so that it stops iterating once the
/// induction variable reaches \p ExitSubloopAt, and route control to
/// \p ContinuationBlock when iterations of the original range remain.
///
/// Creates the ".exit.selector" and ".pseudo.exit" blocks, replaces the
/// terminator of \p Preheader with a conditional branch, rewrites the latch
/// condition, and returns the new blocks together with the PHI nodes that
/// carry the header PHI values and the final induction value out of the loop.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  // Widen V to RangeTy (sign- or zero-extending according to the latch
  // predicate) unless it already has that type.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation);

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}

/// Make each header PHI of \p LS take, on the edge from \p ContinuationBlock,
/// the corresponding value collected in \p RRI.PHIValuesAtPseudoExit (matched
/// by position), and record \p RRI.IndVarEnd as the new induction start.
void LoopConstrainer::rewriteIncomingValuesForPHIs(
    LoopStructure &LS, BasicBlock *ContinuationBlock,
    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
  unsigned PHIIndex = 0;
  for (PHINode &PN : LS.Header->phis())
    PN.setIncomingValueForBlock(ContinuationBlock,
                                RRI.PHIValuesAtPseudoExit[PHIIndex++]);

  LS.IndVarStart = RRI.IndVarEnd;
}

/// Create a block named \p Tag placed before the header of \p LS that branches
/// unconditionally to that header, and redirect the header's PHI uses of
/// \p OldPreheader to the new block. Returns the new block.
BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
                                             BasicBlock *OldPreheader,
                                             const char *Tag) const {
  BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
  BranchInst::Create(LS.Header, Preheader);

  LS.Header->replacePhiUsesWith(OldPreheader, Preheader);

  return Preheader;
}

/// Register every block in \p BBs with the parent of the original loop, if
/// the original loop has a parent.
void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
  Loop *ParentLoop = OriginalLoop.getParentLoop();
  if (!ParentLoop)
    return;

  for (BasicBlock *BB : BBs)
    ParentLoop->addBasicBlockToLoop(BB, LI);
}

/// Build a Loop object mirroring \p Original (recursively including its
/// subloops) out of the cloned blocks found in \p VM, attach it under
/// \p Parent (or as a top-level loop when \p Parent is null), and report it
/// through the LPMAddNewLoop callback. Returns the new loop.
Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
                                                 ValueToValueMapTy &VM,
                                                 bool IsSubloop) {
  Loop &New = *LI.AllocateLoop();
  if (Parent)
    Parent->addChildLoop(&New);
  else
    LI.addTopLevelLoop(&New);
  LPMAddNewLoop(&New, IsSubloop);

  // Add all of the blocks in Original to the new loop.
  for (auto *BB : Original->blocks())
    if (LI.getLoopFor(BB) == Original)
      New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);

  // Add all of the subloops to the new loop.
  for (Loop *SubLoop : *Original)
    createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);

  return &New;
}

/// Split the original loop into an optional pre-loop, the main loop and an
/// optional post-loop according to the limits in SR.
///
/// Returns false, before any loop is cloned or rewritten, when an exit limit
/// cannot be computed without possible overflow or cannot be safely expanded
/// in the preheader. Returns true once the transformation has been applied,
/// the dominator tree recalculated, and all resulting loops brought back into
/// LCSSA and LoopSimplify form.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
  assert(Preheader != nullptr && "precondition!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;
  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy = cast<IntegerType>(RangeTy);

  SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "preloop exit limit. HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "mainloop exit limit. LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  if (NeedsPreLoop) {
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  if (NeedsPostLoop) {
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  /// At this point:
  /// - We've broken a "main loop" out of the loop in a way that the "main loop"
  /// runs with the induction variable in a subset of [Begin, End).
  /// - There is no overflow when computing "main loop" exit limit.
  /// - Max latch taken count of the loop is limited.
  /// It guarantees that induction variable will not overflow iterating in the
  /// "main loop".
  // NOTE(review): the isa/cast target types below are inferred from the
  // setHasNoSignedWrap call - confirm against the upstream source.
  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
    if (IsSignedPredicate)
      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
          ->setHasNoSignedWrap(true);
  /// TODO: support unsigned predicate.
  /// To add NUW flag we need to prove that both operands of BO are
  /// non-negative. E.g:
  /// ...
  /// %iv.next = add nsw i32 %iv, -1
  /// %cmp = icmp ult i32 %iv.next, %n
  /// br i1 %cmp, label %loopexit, label %loop
  ///
  /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
  /// overflow, therefore NUW flag is not legal here.

  return true;
}