1 //===- LoopFlatten.cpp - Loop flattening pass------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass flattens pairs nested loops into a single loop. 10 // 11 // The intention is to optimise loop nests like this, which together access an 12 // array linearly: 13 // 14 // for (int i = 0; i < N; ++i) 15 // for (int j = 0; j < M; ++j) 16 // f(A[i*M+j]); 17 // 18 // into one loop: 19 // 20 // for (int i = 0; i < (N*M); ++i) 21 // f(A[i]); 22 // 23 // It can also flatten loops where the induction variables are not used in the 24 // loop. This is only worth doing if the induction variables are only used in an 25 // expression like i*M+j. If they had any other uses, we would have to insert a 26 // div/mod to reconstruct the original values, so this wouldn't be profitable. 27 // 28 // We also need to prove that N*M will not overflow. The preferred solution is 29 // to widen the IV, which avoids overflow checks, so that is tried first. If 30 // the IV cannot be widened, then we try to determine that this new tripcount 31 // expression won't overflow. 32 // 33 // Q: Does LoopFlatten use SCEV? 34 // Short answer: Yes and no. 35 // 36 // Long answer: 37 // For this transformation to be valid, we require all uses of the induction 38 // variables to be linear expressions of the form i*M+j. The different Loop 39 // APIs are used to get some loop components like the induction variable, 40 // compare statement, etc. In addition, we do some pattern matching to find the 41 // linear expressions and other loop components like the loop increment. The 42 // latter are examples of expressions that do use the induction variable, but 43 // are safe to ignore when we check all uses to be of the form i*M+j. We keep 44 // track of all of this in bookkeeping struct FlattenInfo. 45 // We assume the loops to be canonical, i.e. starting at 0 and increment with 46 // 1. This makes RHS of the compare the loop tripcount (with the right 47 // predicate). We use SCEV to then sanity check that this tripcount matches 48 // with the tripcount as computed by SCEV. 49 // 50 //===----------------------------------------------------------------------===// 51 52 #include "llvm/Transforms/Scalar/LoopFlatten.h" 53 54 #include "llvm/ADT/Statistic.h" 55 #include "llvm/Analysis/AssumptionCache.h" 56 #include "llvm/Analysis/LoopInfo.h" 57 #include "llvm/Analysis/LoopNestAnalysis.h" 58 #include "llvm/Analysis/MemorySSAUpdater.h" 59 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 60 #include "llvm/Analysis/ScalarEvolution.h" 61 #include "llvm/Analysis/TargetTransformInfo.h" 62 #include "llvm/Analysis/ValueTracking.h" 63 #include "llvm/IR/Dominators.h" 64 #include "llvm/IR/Function.h" 65 #include "llvm/IR/IRBuilder.h" 66 #include "llvm/IR/Module.h" 67 #include "llvm/IR/PatternMatch.h" 68 #include "llvm/Support/Debug.h" 69 #include "llvm/Support/raw_ostream.h" 70 #include "llvm/Transforms/Scalar/LoopPassManager.h" 71 #include "llvm/Transforms/Utils/Local.h" 72 #include "llvm/Transforms/Utils/LoopUtils.h" 73 #include "llvm/Transforms/Utils/LoopVersioning.h" 74 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 75 #include "llvm/Transforms/Utils/SimplifyIndVar.h" 76 #include <optional> 77 78 using namespace llvm; 79 using namespace llvm::PatternMatch; 80 81 #define DEBUG_TYPE "loop-flatten" 82 83 STATISTIC(NumFlattened, "Number of loops flattened"); 84 85 static cl::opt<unsigned> RepeatedInstructionThreshold( 86 "loop-flatten-cost-threshold", cl::Hidden, cl::init(2), 87 cl::desc("Limit on the cost of instructions that can be repeated due to " 88 "loop flattening")); 89 90 static cl::opt<bool> 91 AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, 92 cl::init(false), 93 cl::desc("Assume that the product of the two iteration " 94 "trip counts will never overflow")); 95 96 static cl::opt<bool> 97 WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true), 98 cl::desc("Widen the loop induction variables, if possible, so " 99 "overflow checks won't reject flattening")); 100 101 static cl::opt<bool> 102 VersionLoops("loop-flatten-version-loops", cl::Hidden, cl::init(true), 103 cl::desc("Version loops if flattened loop could overflow")); 104 105 namespace { 106 // We require all uses of both induction variables to match this pattern: 107 // 108 // (OuterPHI * InnerTripCount) + InnerPHI 109 // 110 // I.e., it needs to be a linear expression of the induction variables and the 111 // inner loop trip count. We keep track of all different expressions on which 112 // checks will be performed in this bookkeeping struct. 113 // 114 struct FlattenInfo { 115 Loop *OuterLoop = nullptr; // The loop pair to be flattened. 116 Loop *InnerLoop = nullptr; 117 118 PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop 119 PHINode *OuterInductionPHI = nullptr; // induction variables, which are 120 // expected to start at zero and 121 // increment by one on each loop. 122 123 Value *InnerTripCount = nullptr; // The product of these two tripcounts 124 Value *OuterTripCount = nullptr; // will be the new flattened loop 125 // tripcount. Also used to recognise a 126 // linear expression that will be replaced. 127 128 SmallPtrSet<Value *, 4> LinearIVUses; // Contains the linear expressions 129 // of the form i*M+j that will be 130 // replaced. 131 132 BinaryOperator *InnerIncrement = nullptr; // Uses of induction variables in 133 BinaryOperator *OuterIncrement = nullptr; // loop control statements that 134 BranchInst *InnerBranch = nullptr; // are safe to ignore. 135 136 BranchInst *OuterBranch = nullptr; // The instruction that needs to be 137 // updated with new tripcount. 138 139 SmallPtrSet<PHINode *, 4> InnerPHIsToTransform; 140 141 bool Widened = false; // Whether this holds the flatten info before or after 142 // widening. 143 144 PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction 145 PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV 146 // has been applied. Used to skip 147 // checks on phi nodes. 148 149 Value *NewTripCount = nullptr; // The tripcount of the flattened loop. 150 151 FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){}; 152 153 bool isNarrowInductionPhi(PHINode *Phi) { 154 // This can't be the narrow phi if we haven't widened the IV first. 155 if (!Widened) 156 return false; 157 return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi; 158 } 159 bool isInnerLoopIncrement(User *U) { 160 return InnerIncrement == U; 161 } 162 bool isOuterLoopIncrement(User *U) { 163 return OuterIncrement == U; 164 } 165 bool isInnerLoopTest(User *U) { 166 return InnerBranch->getCondition() == U; 167 } 168 169 bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { 170 for (User *U : OuterInductionPHI->users()) { 171 if (isOuterLoopIncrement(U)) 172 continue; 173 174 auto IsValidOuterPHIUses = [&] (User *U) -> bool { 175 LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); 176 if (!ValidOuterPHIUses.count(U)) { 177 LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); 178 return false; 179 } 180 LLVM_DEBUG(dbgs() << "Use is optimisable\n"); 181 return true; 182 }; 183 184 if (auto *V = dyn_cast<TruncInst>(U)) { 185 for (auto *K : V->users()) { 186 if (!IsValidOuterPHIUses(K)) 187 return false; 188 } 189 continue; 190 } 191 192 if (!IsValidOuterPHIUses(U)) 193 return false; 194 } 195 return true; 196 } 197 198 bool matchLinearIVUser(User *U, Value *InnerTripCount, 199 SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { 200 LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: "; U->dump()); 201 Value *MatchedMul = nullptr; 202 Value *MatchedItCount = nullptr; 203 204 bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI), 205 m_Value(MatchedMul))) && 206 match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI), 207 m_Value(MatchedItCount))); 208 209 // Matches the same pattern as above, except it also looks for truncs 210 // on the phi, which can be the result of widening the induction variables. 211 bool IsAddTrunc = 212 match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)), 213 m_Value(MatchedMul))) && 214 match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)), 215 m_Value(MatchedItCount))); 216 217 // Matches the pattern ptr+i*M+j, with the two additions being done via GEP. 218 bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)), 219 m_Specific(InnerInductionPHI))) && 220 match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI), 221 m_Value(MatchedItCount))); 222 223 if (!MatchedItCount) 224 return false; 225 226 LLVM_DEBUG(dbgs() << "Matched multiplication: "; MatchedMul->dump()); 227 LLVM_DEBUG(dbgs() << "Matched iteration count: "; MatchedItCount->dump()); 228 229 // The mul should not have any other uses. Widening may leave trivially dead 230 // uses, which can be ignored. 231 if (count_if(MatchedMul->users(), [](User *U) { 232 return !isInstructionTriviallyDead(cast<Instruction>(U)); 233 }) > 1) { 234 LLVM_DEBUG(dbgs() << "Multiply has more than one use\n"); 235 return false; 236 } 237 238 // Look through extends if the IV has been widened. Don't look through 239 // extends if we already looked through a trunc. 240 if (Widened && (IsAdd || IsGEP) && 241 (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) { 242 assert(MatchedItCount->getType() == InnerInductionPHI->getType() && 243 "Unexpected type mismatch in types after widening"); 244 MatchedItCount = isa<SExtInst>(MatchedItCount) 245 ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0) 246 : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0); 247 } 248 249 LLVM_DEBUG(dbgs() << "Looking for inner trip count: "; 250 InnerTripCount->dump()); 251 252 if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) { 253 LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n"); 254 ValidOuterPHIUses.insert(MatchedMul); 255 LinearIVUses.insert(U); 256 return true; 257 } 258 259 LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); 260 return false; 261 } 262 263 bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { 264 Value *SExtInnerTripCount = InnerTripCount; 265 if (Widened && 266 (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount))) 267 SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0); 268 269 for (User *U : InnerInductionPHI->users()) { 270 LLVM_DEBUG(dbgs() << "Checking User: "; U->dump()); 271 if (isInnerLoopIncrement(U)) { 272 LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n"); 273 continue; 274 } 275 276 // After widening the IVs, a trunc instruction might have been introduced, 277 // so look through truncs. 278 if (isa<TruncInst>(U)) { 279 if (!U->hasOneUse()) 280 return false; 281 U = *U->user_begin(); 282 } 283 284 // If the use is in the compare (which is also the condition of the inner 285 // branch) then the compare has been altered by another transformation e.g 286 // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is 287 // a constant. Ignore this use as the compare gets removed later anyway. 288 if (isInnerLoopTest(U)) { 289 LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n"); 290 continue; 291 } 292 293 if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) { 294 LLVM_DEBUG(dbgs() << "Not a linear IV user\n"); 295 return false; 296 } 297 LLVM_DEBUG(dbgs() << "Linear IV users found!\n"); 298 } 299 return true; 300 } 301 }; 302 } // namespace 303 304 static bool 305 setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, 306 SmallPtrSetImpl<Instruction *> &IterationInstructions) { 307 TripCount = TC; 308 IterationInstructions.insert(Increment); 309 LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump()); 310 LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); 311 LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); 312 return true; 313 } 314 315 // Given the RHS of the loop latch compare instruction, verify with SCEV 316 // that this is indeed the loop tripcount. 317 // TODO: This used to be a straightforward check but has grown to be quite 318 // complicated now. It is therefore worth revisiting what the additional 319 // benefits are of this (compared to relying on canonical loops and pattern 320 // matching). 321 static bool verifyTripCount(Value *RHS, Loop *L, 322 SmallPtrSetImpl<Instruction *> &IterationInstructions, 323 PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, 324 BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { 325 const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); 326 if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) { 327 LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); 328 return false; 329 } 330 331 // Evaluating in the trip count's type can not overflow here as the overflow 332 // checks are performed in checkOverflow, but are first tried to avoid by 333 // widening the IV. 334 const SCEV *SCEVTripCount = 335 SE->getTripCountFromExitCount(BackedgeTakenCount, 336 BackedgeTakenCount->getType(), L); 337 338 const SCEV *SCEVRHS = SE->getSCEV(RHS); 339 if (SCEVRHS == SCEVTripCount) 340 return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); 341 ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS); 342 if (ConstantRHS) { 343 const SCEV *BackedgeTCExt = nullptr; 344 if (IsWidened) { 345 const SCEV *SCEVTripCountExt; 346 // Find the extended backedge taken count and extended trip count using 347 // SCEV. One of these should now match the RHS of the compare. 348 BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); 349 SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, 350 RHS->getType(), L); 351 if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { 352 LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); 353 return false; 354 } 355 } 356 // If the RHS of the compare is equal to the backedge taken count we need 357 // to add one to get the trip count. 358 if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { 359 Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(), 360 ConstantRHS->getValue() + 1); 361 return setLoopComponents(NewRHS, TripCount, Increment, 362 IterationInstructions); 363 } 364 return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); 365 } 366 // If the RHS isn't a constant then check that the reason it doesn't match 367 // the SCEV trip count is because the RHS is a ZExt or SExt instruction 368 // (and take the trip count to be the RHS). 369 if (!IsWidened) { 370 LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); 371 return false; 372 } 373 auto *TripCountInst = dyn_cast<Instruction>(RHS); 374 if (!TripCountInst) { 375 LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); 376 return false; 377 } 378 if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) || 379 SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { 380 LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); 381 return false; 382 } 383 return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); 384 } 385 386 // Finds the induction variable, increment and trip count for a simple loop that 387 // we can flatten. 388 static bool findLoopComponents( 389 Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions, 390 PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, 391 BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { 392 LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n"); 393 394 if (!L->isLoopSimplifyForm()) { 395 LLVM_DEBUG(dbgs() << "Loop is not in normal form\n"); 396 return false; 397 } 398 399 // Currently, to simplify the implementation, the Loop induction variable must 400 // start at zero and increment with a step size of one. 401 if (!L->isCanonical(*SE)) { 402 LLVM_DEBUG(dbgs() << "Loop is not canonical\n"); 403 return false; 404 } 405 406 // There must be exactly one exiting block, and it must be the same at the 407 // latch. 408 BasicBlock *Latch = L->getLoopLatch(); 409 if (L->getExitingBlock() != Latch) { 410 LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n"); 411 return false; 412 } 413 414 // Find the induction PHI. If there is no induction PHI, we can't do the 415 // transformation. TODO: could other variables trigger this? Do we have to 416 // search for the best one? 417 InductionPHI = L->getInductionVariable(*SE); 418 if (!InductionPHI) { 419 LLVM_DEBUG(dbgs() << "Could not find induction PHI\n"); 420 return false; 421 } 422 LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump()); 423 424 bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0)); 425 auto IsValidPredicate = [&](ICmpInst::Predicate Pred) { 426 if (ContinueOnTrue) 427 return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT; 428 else 429 return Pred == CmpInst::ICMP_EQ; 430 }; 431 432 // Find Compare and make sure it is valid. getLatchCmpInst checks that the 433 // back branch of the latch is conditional. 434 ICmpInst *Compare = L->getLatchCmpInst(); 435 if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) || 436 Compare->hasNUsesOrMore(2)) { 437 LLVM_DEBUG(dbgs() << "Could not find valid comparison\n"); 438 return false; 439 } 440 BackBranch = cast<BranchInst>(Latch->getTerminator()); 441 IterationInstructions.insert(BackBranch); 442 LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump()); 443 IterationInstructions.insert(Compare); 444 LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump()); 445 446 // Find increment and trip count. 447 // There are exactly 2 incoming values to the induction phi; one from the 448 // pre-header and one from the latch. The incoming latch value is the 449 // increment variable. 450 Increment = 451 cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); 452 if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) && 453 !Increment->hasNUses(1)) { 454 LLVM_DEBUG(dbgs() << "Could not find valid increment\n"); 455 return false; 456 } 457 // The trip count is the RHS of the compare. If this doesn't match the trip 458 // count computed by SCEV then this is because the trip count variable 459 // has been widened so the types don't match, or because it is a constant and 460 // another transformation has changed the compare (e.g. icmp ult %inc, 461 // tripcount -> icmp ult %j, tripcount-1), or both. 462 Value *RHS = Compare->getOperand(1); 463 464 return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount, 465 Increment, BackBranch, SE, IsWidened); 466 } 467 468 static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { 469 // All PHIs in the inner and outer headers must either be: 470 // - The induction PHI, which we are going to rewrite as one induction in 471 // the new loop. This is already checked by findLoopComponents. 472 // - An outer header PHI with all incoming values from outside the loop. 473 // LoopSimplify guarantees we have a pre-header, so we don't need to 474 // worry about that here. 475 // - Pairs of PHIs in the inner and outer headers, which implement a 476 // loop-carried dependency that will still be valid in the new loop. To 477 // be valid, this variable must be modified only in the inner loop. 478 479 // The set of PHI nodes in the outer loop header that we know will still be 480 // valid after the transformation. These will not need to be modified (with 481 // the exception of the induction variable), but we do need to check that 482 // there are no unsafe PHI nodes. 483 SmallPtrSet<PHINode *, 4> SafeOuterPHIs; 484 SafeOuterPHIs.insert(FI.OuterInductionPHI); 485 486 // Check that all PHI nodes in the inner loop header match one of the valid 487 // patterns. 488 for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) { 489 // The induction PHIs break these rules, and that's OK because we treat 490 // them specially when doing the transformation. 491 if (&InnerPHI == FI.InnerInductionPHI) 492 continue; 493 if (FI.isNarrowInductionPhi(&InnerPHI)) 494 continue; 495 496 // Each inner loop PHI node must have two incoming values/blocks - one 497 // from the pre-header, and one from the latch. 498 assert(InnerPHI.getNumIncomingValues() == 2); 499 Value *PreHeaderValue = 500 InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader()); 501 Value *LatchValue = 502 InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch()); 503 504 // The incoming value from the outer loop must be the PHI node in the 505 // outer loop header, with no modifications made in the top of the outer 506 // loop. 507 PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue); 508 if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) { 509 LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n"); 510 return false; 511 } 512 513 // The other incoming value must come from the inner loop, without any 514 // modifications in the tail end of the outer loop. We are in LCSSA form, 515 // so this will actually be a PHI in the inner loop's exit block, which 516 // only uses values from inside the inner loop. 517 PHINode *LCSSAPHI = dyn_cast<PHINode>( 518 OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch())); 519 if (!LCSSAPHI) { 520 LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n"); 521 return false; 522 } 523 524 // The value used by the LCSSA PHI must be the same one that the inner 525 // loop's PHI uses. 526 if (LCSSAPHI->hasConstantValue() != LatchValue) { 527 LLVM_DEBUG( 528 dbgs() << "LCSSA PHI incoming value does not match latch value\n"); 529 return false; 530 } 531 532 LLVM_DEBUG(dbgs() << "PHI pair is safe:\n"); 533 LLVM_DEBUG(dbgs() << " Inner: "; InnerPHI.dump()); 534 LLVM_DEBUG(dbgs() << " Outer: "; OuterPHI->dump()); 535 SafeOuterPHIs.insert(OuterPHI); 536 FI.InnerPHIsToTransform.insert(&InnerPHI); 537 } 538 539 for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { 540 if (FI.isNarrowInductionPhi(&OuterPHI)) 541 continue; 542 if (!SafeOuterPHIs.count(&OuterPHI)) { 543 LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump()); 544 return false; 545 } 546 } 547 548 LLVM_DEBUG(dbgs() << "checkPHIs: OK\n"); 549 return true; 550 } 551 552 static bool 553 checkOuterLoopInsts(FlattenInfo &FI, 554 SmallPtrSetImpl<Instruction *> &IterationInstructions, 555 const TargetTransformInfo *TTI) { 556 // Check for instructions in the outer but not inner loop. If any of these 557 // have side-effects then this transformation is not legal, and if there is 558 // a significant amount of code here which can't be optimised out that it's 559 // not profitable (as these instructions would get executed for each 560 // iteration of the inner loop). 561 InstructionCost RepeatedInstrCost = 0; 562 for (auto *B : FI.OuterLoop->getBlocks()) { 563 if (FI.InnerLoop->contains(B)) 564 continue; 565 566 for (auto &I : *B) { 567 if (!isa<PHINode>(&I) && !I.isTerminator() && 568 !isSafeToSpeculativelyExecute(&I)) { 569 LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have " 570 "side effects: "; 571 I.dump()); 572 return false; 573 } 574 // The execution count of the outer loop's iteration instructions 575 // (increment, compare and branch) will be increased, but the 576 // equivalent instructions will be removed from the inner loop, so 577 // they make a net difference of zero. 578 if (IterationInstructions.count(&I)) 579 continue; 580 // The unconditional branch to the inner loop's header will turn into 581 // a fall-through, so adds no cost. 582 BranchInst *Br = dyn_cast<BranchInst>(&I); 583 if (Br && Br->isUnconditional() && 584 Br->getSuccessor(0) == FI.InnerLoop->getHeader()) 585 continue; 586 // Multiplies of the outer iteration variable and inner iteration 587 // count will be optimised out. 588 if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI), 589 m_Specific(FI.InnerTripCount)))) 590 continue; 591 InstructionCost Cost = 592 TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); 593 LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); 594 RepeatedInstrCost += Cost; 595 } 596 } 597 598 LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: " 599 << RepeatedInstrCost << "\n"); 600 // Bail out if flattening the loops would cause instructions in the outer 601 // loop but not in the inner loop to be executed extra times. 602 if (RepeatedInstrCost > RepeatedInstructionThreshold) { 603 LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n"); 604 return false; 605 } 606 607 LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n"); 608 return true; 609 } 610 611 612 613 // We require all uses of both induction variables to match this pattern: 614 // 615 // (OuterPHI * InnerTripCount) + InnerPHI 616 // 617 // Any uses of the induction variables not matching that pattern would 618 // require a div/mod to reconstruct in the flattened loop, so the 619 // transformation wouldn't be profitable. 620 static bool checkIVUsers(FlattenInfo &FI) { 621 // Check that all uses of the inner loop's induction variable match the 622 // expected pattern, recording the uses of the outer IV. 623 SmallPtrSet<Value *, 4> ValidOuterPHIUses; 624 if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses)) 625 return false; 626 627 // Check that there are no uses of the outer IV other than the ones found 628 // as part of the pattern above. 629 if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses)) 630 return false; 631 632 LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n"; 633 dbgs() << "Found " << FI.LinearIVUses.size() 634 << " value(s) that can be replaced:\n"; 635 for (Value *V : FI.LinearIVUses) { 636 dbgs() << " "; 637 V->dump(); 638 }); 639 return true; 640 } 641 642 // Return an OverflowResult dependant on if overflow of the multiplication of 643 // InnerTripCount and OuterTripCount can be assumed not to happen. 644 static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, 645 AssumptionCache *AC) { 646 Function *F = FI.OuterLoop->getHeader()->getParent(); 647 const DataLayout &DL = F->getDataLayout(); 648 649 // For debugging/testing. 650 if (AssumeNoOverflow) 651 return OverflowResult::NeverOverflows; 652 653 // Check if the multiply could not overflow due to known ranges of the 654 // input values. 655 OverflowResult OR = computeOverflowForUnsignedMul( 656 FI.InnerTripCount, FI.OuterTripCount, 657 SimplifyQuery(DL, DT, AC, 658 FI.OuterLoop->getLoopPreheader()->getTerminator())); 659 if (OR != OverflowResult::MayOverflow) 660 return OR; 661 662 auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) { 663 for (Value *GEPUser : GEP->users()) { 664 auto *GEPUserInst = cast<Instruction>(GEPUser); 665 if (!isa<LoadInst>(GEPUserInst) && 666 !(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1))) 667 continue; 668 if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop)) 669 continue; 670 // The IV is used as the operand of a GEP which dominates the loop 671 // latch, and the IV is at least as wide as the address space of the 672 // GEP. In this case, the GEP would wrap around the address space 673 // before the IV increment wraps, which would be UB. 674 if (GEP->isInBounds() && 675 GEPOperand->getType()->getIntegerBitWidth() >= 676 DL.getPointerTypeSizeInBits(GEP->getType())) { 677 LLVM_DEBUG( 678 dbgs() << "use of linear IV would be UB if overflow occurred: "; 679 GEP->dump()); 680 return true; 681 } 682 } 683 return false; 684 }; 685 686 // Check if any IV user is, or is used by, a GEP that would cause UB if the 687 // multiply overflows. 688 for (Value *V : FI.LinearIVUses) { 689 if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) 690 if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1))) 691 return OverflowResult::NeverOverflows; 692 for (Value *U : V->users()) 693 if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) 694 if (CheckGEP(GEP, V)) 695 return OverflowResult::NeverOverflows; 696 } 697 698 return OverflowResult::MayOverflow; 699 } 700 701 static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, 702 ScalarEvolution *SE, AssumptionCache *AC, 703 const TargetTransformInfo *TTI) { 704 SmallPtrSet<Instruction *, 8> IterationInstructions; 705 if (!findLoopComponents(FI.InnerLoop, IterationInstructions, 706 FI.InnerInductionPHI, FI.InnerTripCount, 707 FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened)) 708 return false; 709 if (!findLoopComponents(FI.OuterLoop, IterationInstructions, 710 FI.OuterInductionPHI, FI.OuterTripCount, 711 FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened)) 712 return false; 713 714 // Both of the loop trip count values must be invariant in the outer loop 715 // (non-instructions are all inherently invariant). 716 if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) { 717 LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n"); 718 return false; 719 } 720 if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) { 721 LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n"); 722 return false; 723 } 724 725 if (!checkPHIs(FI, TTI)) 726 return false; 727 728 // FIXME: it should be possible to handle different types correctly. 729 if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType()) 730 return false; 731 732 if (!checkOuterLoopInsts(FI, IterationInstructions, TTI)) 733 return false; 734 735 // Find the values in the loop that can be replaced with the linearized 736 // induction variable, and check that there are no other uses of the inner 737 // or outer induction variable. If there were, we could still do this 738 // transformation, but we'd have to insert a div/mod to calculate the 739 // original IVs, so it wouldn't be profitable. 740 if (!checkIVUsers(FI)) 741 return false; 742 743 LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n"); 744 return true; 745 } 746 747 static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, 748 ScalarEvolution *SE, AssumptionCache *AC, 749 const TargetTransformInfo *TTI, LPMUpdater *U, 750 MemorySSAUpdater *MSSAU) { 751 Function *F = FI.OuterLoop->getHeader()->getParent(); 752 LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); 753 { 754 using namespace ore; 755 OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(), 756 FI.InnerLoop->getHeader()); 757 OptimizationRemarkEmitter ORE(F); 758 Remark << "Flattened into outer loop"; 759 ORE.emit(Remark); 760 } 761 762 if (!FI.NewTripCount) { 763 FI.NewTripCount = BinaryOperator::CreateMul( 764 FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount", 765 FI.OuterLoop->getLoopPreheader()->getTerminator()->getIterator()); 766 LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; 767 FI.NewTripCount->dump()); 768 } 769 770 // Fix up PHI nodes that take values from the inner loop back-edge, which 771 // we are about to remove. 772 FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); 773 774 // The old Phi will be optimised away later, but for now we can't leave 775 // leave it in an invalid state, so are updating them too. 776 for (PHINode *PHI : FI.InnerPHIsToTransform) 777 PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); 778 779 // Modify the trip count of the outer loop to be the product of the two 780 // trip counts. 781 cast<User>(FI.OuterBranch->getCondition())->setOperand(1, FI.NewTripCount); 782 783 // Replace the inner loop backedge with an unconditional branch to the exit. 784 BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock(); 785 BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); 786 Instruction *Term = InnerExitingBlock->getTerminator(); 787 Instruction *BI = BranchInst::Create(InnerExitBlock, InnerExitingBlock); 788 BI->setDebugLoc(Term->getDebugLoc()); 789 Term->eraseFromParent(); 790 791 // Update the DomTree and MemorySSA. 792 DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); 793 if (MSSAU) 794 MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); 795 796 // Replace all uses of the polynomial calculated from the two induction 797 // variables with the one new one. 798 IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator()); 799 for (Value *V : FI.LinearIVUses) { 800 Value *OuterValue = FI.OuterInductionPHI; 801 if (FI.Widened) 802 OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(), 803 "flatten.trunciv"); 804 805 if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) { 806 // Replace the GEP with one that uses OuterValue as the offset. 807 auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0)); 808 Value *Base = InnerGEP->getOperand(0); 809 // When the base of the GEP doesn't dominate the outer induction phi then 810 // we need to insert the new GEP where the old GEP was. 811 if (!DT->dominates(Base, &*Builder.GetInsertPoint())) 812 Builder.SetInsertPoint(cast<Instruction>(V)); 813 OuterValue = 814 Builder.CreateGEP(GEP->getSourceElementType(), Base, OuterValue, 815 "flatten." + V->getName(), 816 GEP->isInBounds() && InnerGEP->isInBounds()); 817 } 818 819 LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: "; 820 OuterValue->dump()); 821 V->replaceAllUsesWith(OuterValue); 822 } 823 824 // Tell LoopInfo, SCEV and the pass manager that the inner loop has been 825 // deleted, and invalidate any outer loop information. 826 SE->forgetLoop(FI.OuterLoop); 827 SE->forgetBlockAndLoopDispositions(); 828 if (U) 829 U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName()); 830 LI->erase(FI.InnerLoop); 831 832 // Increment statistic value. 833 NumFlattened++; 834 835 return true; 836 } 837 838 static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, 839 ScalarEvolution *SE, AssumptionCache *AC, 840 const TargetTransformInfo *TTI) { 841 if (!WidenIV) { 842 LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n"); 843 return false; 844 } 845 846 LLVM_DEBUG(dbgs() << "Try widening the IVs\n"); 847 Module *M = FI.InnerLoop->getHeader()->getParent()->getParent(); 848 auto &DL = M->getDataLayout(); 849 auto *InnerType = FI.InnerInductionPHI->getType(); 850 auto *OuterType = FI.OuterInductionPHI->getType(); 851 unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits(); 852 auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext()); 853 854 // If both induction types are less than the maximum legal integer width, 855 // promote both to the widest type available so we know calculating 856 // (OuterTripCount * InnerTripCount) as the new trip count is safe. 857 if (InnerType != OuterType || 858 InnerType->getScalarSizeInBits() >= MaxLegalSize || 859 MaxLegalType->getScalarSizeInBits() < 860 InnerType->getScalarSizeInBits() * 2) { 861 LLVM_DEBUG(dbgs() << "Can't widen the IV\n"); 862 return false; 863 } 864 865 SCEVExpander Rewriter(*SE, DL, "loopflatten"); 866 SmallVector<WeakTrackingVH, 4> DeadInsts; 867 unsigned ElimExt = 0; 868 unsigned Widened = 0; 869 870 auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool { 871 PHINode *WidePhi = 872 createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened, 873 true /* HasGuards */, true /* UsePostIncrementRanges */); 874 if (!WidePhi) 875 return false; 876 LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump()); 877 LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump()); 878 Deleted = RecursivelyDeleteDeadPHINode(WideIV.NarrowIV); 879 return true; 880 }; 881 882 bool Deleted; 883 if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false}, Deleted)) 884 return false; 885 // Add the narrow phi to list, so that it will be adjusted later when the 886 // the transformation is performed. 887 if (!Deleted) 888 FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI); 889 890 if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false}, Deleted)) 891 return false; 892 893 assert(Widened && "Widened IV expected"); 894 FI.Widened = true; 895 896 // Save the old/narrow induction phis, which we need to ignore in CheckPHIs. 897 FI.NarrowInnerInductionPHI = FI.InnerInductionPHI; 898 FI.NarrowOuterInductionPHI = FI.OuterInductionPHI; 899 900 // After widening, rediscover all the loop components. 901 return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI); 902 } 903 904 static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, 905 ScalarEvolution *SE, AssumptionCache *AC, 906 const TargetTransformInfo *TTI, LPMUpdater *U, 907 MemorySSAUpdater *MSSAU, 908 const LoopAccessInfo &LAI) { 909 LLVM_DEBUG( 910 dbgs() << "Loop flattening running on outer loop " 911 << FI.OuterLoop->getHeader()->getName() << " and inner loop " 912 << FI.InnerLoop->getHeader()->getName() << " in " 913 << FI.OuterLoop->getHeader()->getParent()->getName() << "\n"); 914 915 if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI)) 916 return false; 917 918 // Check if we can widen the induction variables to avoid overflow checks. 919 bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI); 920 921 // It can happen that after widening of the IV, flattening may not be 922 // possible/happening, e.g. when it is deemed unprofitable. So bail here if 923 // that is the case. 924 // TODO: IV widening without performing the actual flattening transformation 925 // is not ideal. While this codegen change should not matter much, it is an 926 // unnecessary change which is better to avoid. It's unlikely this happens 927 // often, because if it's unprofitibale after widening, it should be 928 // unprofitabe before widening as checked in the first round of checks. But 929 // 'RepeatedInstructionThreshold' is set to only 2, which can probably be 930 // relaxed. Because this is making a code change (the IV widening, but not 931 // the flattening), we return true here. 932 if (FI.Widened && !CanFlatten) 933 return true; 934 935 // If we have widened and can perform the transformation, do that here. 936 if (CanFlatten) 937 return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); 938 939 // Otherwise, if we haven't widened the IV, check if the new iteration 940 // variable might overflow. In this case, we need to version the loop, and 941 // select the original version at runtime if the iteration space is too 942 // large. 943 OverflowResult OR = checkOverflow(FI, DT, AC); 944 if (OR == OverflowResult::AlwaysOverflowsHigh || 945 OR == OverflowResult::AlwaysOverflowsLow) { 946 LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n"); 947 return false; 948 } else if (OR == OverflowResult::MayOverflow) { 949 Module *M = FI.OuterLoop->getHeader()->getParent()->getParent(); 950 const DataLayout &DL = M->getDataLayout(); 951 if (!VersionLoops) { 952 LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n"); 953 return false; 954 } else if (!DL.isLegalInteger( 955 FI.OuterTripCount->getType()->getScalarSizeInBits())) { 956 // If the trip count type isn't legal then it won't be possible to check 957 // for overflow using only a single multiply instruction, so don't 958 // flatten. 959 LLVM_DEBUG( 960 dbgs() << "Can't check overflow efficiently, not flattening\n"); 961 return false; 962 } 963 LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n"); 964 965 // Version the loop. The overflow check isn't a runtime pointer check, so we 966 // pass an empty list of runtime pointer checks, causing LoopVersioning to 967 // emit 'false' as the branch condition, and add our own check afterwards. 968 BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader(); 969 ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr); 970 LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE); 971 LVer.versionLoop(); 972 973 // Check for overflow by calculating the new tripcount using 974 // umul_with_overflow and then checking if it overflowed. 975 BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator()); 976 assert(Br->isConditional() && 977 "Expected LoopVersioning to generate a conditional branch"); 978 assert(match(Br->getCondition(), m_Zero()) && 979 "Expected branch condition to be false"); 980 IRBuilder<> Builder(Br); 981 Function *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, 982 FI.OuterTripCount->getType()); 983 Value *Call = Builder.CreateCall(F, {FI.OuterTripCount, FI.InnerTripCount}, 984 "flatten.mul"); 985 FI.NewTripCount = Builder.CreateExtractValue(Call, 0, "flatten.tripcount"); 986 Value *Overflow = Builder.CreateExtractValue(Call, 1, "flatten.overflow"); 987 Br->setCondition(Overflow); 988 } else { 989 LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); 990 } 991 992 return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); 993 } 994 995 PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, 996 LoopStandardAnalysisResults &AR, 997 LPMUpdater &U) { 998 999 bool Changed = false; 1000 1001 std::optional<MemorySSAUpdater> MSSAU; 1002 if (AR.MSSA) { 1003 MSSAU = MemorySSAUpdater(AR.MSSA); 1004 if (VerifyMemorySSA) 1005 AR.MSSA->verifyMemorySSA(); 1006 } 1007 1008 // The loop flattening pass requires loops to be 1009 // in simplified form, and also needs LCSSA. Running 1010 // this pass will simplify all loops that contain inner loops, 1011 // regardless of whether anything ends up being flattened. 1012 LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr); 1013 for (Loop *InnerLoop : LN.getLoops()) { 1014 auto *OuterLoop = InnerLoop->getParentLoop(); 1015 if (!OuterLoop) 1016 continue; 1017 FlattenInfo FI(OuterLoop, InnerLoop); 1018 Changed |= 1019 FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, 1020 MSSAU ? &*MSSAU : nullptr, LAIM.getInfo(*OuterLoop)); 1021 } 1022 1023 if (!Changed) 1024 return PreservedAnalyses::all(); 1025 1026 if (AR.MSSA && VerifyMemorySSA) 1027 AR.MSSA->verifyMemorySSA(); 1028 1029 auto PA = getLoopPassPreservedAnalyses(); 1030 if (AR.MSSA) 1031 PA.preserve<MemorySSAAnalysis>(); 1032 return PA; 1033 } 1034