//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library. First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression. These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
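
// A note on notation (an illustration, not exhaustive): for a loop
// "for (i = 0; i != n; ++i)", SCEV represents the induction variable i as the
// add recurrence {0,+,1}<%loop>, and a derived expression such as 4*i + 7 as
// {7,+,4}<%loop>; see SCEV::print below for the exact syntax.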

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

using namespace llvm;
using namespace PatternMatch;
using namespace SCEVPatternMatch;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumExitCountsComputed,
          "Number of loop exits with predictable exit counts");
STATISTIC(NumExitCountsNotComputed,
          "Number of loop exits without predictable exit counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by brute force");

#ifdef EXPENSIVE_CHECKS
bool llvm::VerifySCEV = true;
#else
bool llvm::VerifySCEV = false;
#endif

static cl::opt<unsigned>
    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                            cl::desc("Maximum number of iterations SCEV will "
                                     "symbolically execute a constant-derived "
                                     "loop"),
                            cl::init(100));

static cl::opt<bool, true> VerifySCEVOpt(
    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));

static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(32));

static cl::opt<unsigned> AddOpsInlineThreshold(
    "scev-addops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining addition operands into a SCEV"),
    cl::init(500));

static cl::opt<unsigned> MaxSCEVCompareDepth(
    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
    cl::init(32));

static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
    cl::init(2));

static cl::opt<unsigned> MaxValueCompareDepth(
    "scalar-evolution-max-value-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive value complexity comparisons"),
    cl::init(2));

static cl::opt<unsigned>
    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
                  cl::desc("Maximum depth of recursive arithmetics"),
                  cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

static cl::opt<unsigned>
    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
                 cl::init(8));

static cl::opt<unsigned>
    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
                  cl::desc("Max coefficients in AddRec during evolving"),
                  cl::init(8));

static cl::opt<unsigned>
    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
                      cl::desc("Size of the expression which is considered huge"),
                      cl::init(4096));

static cl::opt<unsigned> RangeIterThreshold(
    "scev-range-iter-threshold", cl::Hidden,
    cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
    cl::init(32));

static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
    "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
    cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));

static cl::opt<bool> ClassifyExpressions(
    "scalar-evolution-classify-expressions", cl::Hidden, cl::init(true),
    cl::desc("When printing analysis, include information on every "
             "instruction"));

May " 240 "be costly in terms of compile time")); 241 242 static cl::opt<unsigned> MaxPhiSCCAnalysisSize( 243 "scalar-evolution-max-scc-analysis-depth", cl::Hidden, 244 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " 245 "Phi strongly connected components"), 246 cl::init(8)); 247 248 static cl::opt<bool> 249 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, 250 cl::desc("Handle <= and >= in finite loops"), 251 cl::init(true)); 252 253 static cl::opt<bool> UseContextForNoWrapFlagInference( 254 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, 255 cl::desc("Infer nuw/nsw flags using context where suitable"), 256 cl::init(true)); 257 258 //===----------------------------------------------------------------------===// 259 // SCEV class definitions 260 //===----------------------------------------------------------------------===// 261 262 //===----------------------------------------------------------------------===// 263 // Implementation of the SCEV class. 264 // 265 266 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 267 LLVM_DUMP_METHOD void SCEV::dump() const { 268 print(dbgs()); 269 dbgs() << '\n'; 270 } 271 #endif 272 273 void SCEV::print(raw_ostream &OS) const { 274 switch (getSCEVType()) { 275 case scConstant: 276 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false); 277 return; 278 case scVScale: 279 OS << "vscale"; 280 return; 281 case scPtrToInt: { 282 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this); 283 const SCEV *Op = PtrToInt->getOperand(); 284 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to " 285 << *PtrToInt->getType() << ")"; 286 return; 287 } 288 case scTruncate: { 289 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this); 290 const SCEV *Op = Trunc->getOperand(); 291 OS << "(trunc " << *Op->getType() << " " << *Op << " to " 292 << *Trunc->getType() << ")"; 293 return; 294 } 295 case scZeroExtend: { 296 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this); 297 const SCEV *Op = ZExt->getOperand(); 298 OS << "(zext " << *Op->getType() << " " << *Op << " to " 299 << *ZExt->getType() << ")"; 300 return; 301 } 302 case scSignExtend: { 303 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this); 304 const SCEV *Op = SExt->getOperand(); 305 OS << "(sext " << *Op->getType() << " " << *Op << " to " 306 << *SExt->getType() << ")"; 307 return; 308 } 309 case scAddRecExpr: { 310 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this); 311 OS << "{" << *AR->getOperand(0); 312 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) 313 OS << ",+," << *AR->getOperand(i); 314 OS << "}<"; 315 if (AR->hasNoUnsignedWrap()) 316 OS << "nuw><"; 317 if (AR->hasNoSignedWrap()) 318 OS << "nsw><"; 319 if (AR->hasNoSelfWrap() && 320 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) 321 OS << "nw><"; 322 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false); 323 OS << ">"; 324 return; 325 } 326 case scAddExpr: 327 case scMulExpr: 328 case scUMaxExpr: 329 case scSMaxExpr: 330 case scUMinExpr: 331 case scSMinExpr: 332 case scSequentialUMinExpr: { 333 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this); 334 const char *OpStr = nullptr; 335 switch (NAry->getSCEVType()) { 336 case scAddExpr: OpStr = " + "; break; 337 case scMulExpr: OpStr = " * "; break; 338 case scUMaxExpr: OpStr = " umax "; break; 339 case scSMaxExpr: OpStr = " smax "; break; 340 case scUMinExpr: 341 OpStr = " umin "; 342 break; 343 case scSMinExpr: 344 OpStr = " smin "; 345 break; 
    case scSequentialUMinExpr:
      OpStr = " umin_seq ";
      break;
    default:
      llvm_unreachable("There are no other nary expression types.");
    }
    OS << "("
       << llvm::interleaved(llvm::make_pointee_range(NAry->operands()), OpStr)
       << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
      break;
    default:
      // Nothing to print for other nary expressions.
      break;
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown:
    cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

Type *SCEV::getType() const {
  switch (getSCEVType()) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scVScale:
    return cast<SCEVVScale>(this)->getType();
  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
    return cast<SCEVAddRecExpr>(this)->getType();
  case scMulExpr:
    return cast<SCEVMulExpr>(this)->getType();
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return cast<SCEVMinMaxExpr>(this)->getType();
  case scSequentialUMinExpr:
    return cast<SCEVSequentialMinMaxExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

ArrayRef<const SCEV *> SCEV::operands() const {
  switch (getSCEVType()) {
  case scConstant:
  case scVScale:
  case scUnknown:
    return {};
  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->operands();
  case scAddRecExpr:
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr:
    return cast<SCEVNAryExpr>(this)->operands();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->operands();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool SCEV::isZero() const { return match(this, m_scev_Zero()); }

bool SCEV::isOne() const { return match(this, m_scev_One()); }

bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }

bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative; this matches things like (-42 * V).
  return SC->getAPInt().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

const SCEV *ScalarEvolution::getVScale(Type *Ty) {
  FoldingSetNodeID ID;
  ID.AddInteger(scVScale);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;
  SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) {
  const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
  if (EC.isScalable())
    Res = getMulExpr(Res, getVScale(Ty));
  return Res;
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                           const SCEV *op, Type *ty)
    : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}

SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
                                   Type *ITy)
    : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}

SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
                                           SCEVTypes SCEVTy, const SCEV *op,
                                           Type *ty)
    : SCEVCastExpr(ID, SCEVTy, op, ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
                                   Type *ty)
    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Replace the value pointer in case someone is still using this SCEVUnknown.
  setValPtr(New);
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad hoc) relation used to order
/// operands in SCEV expressions.
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
                                  Value *RV, unsigned Depth) {
  if (Depth > MaxValueCompareDepth)
    return 0;

  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    if (auto L = LGV->getLinkage() - RGV->getLinkage())
      return L;

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count. This
  // is pretty loose.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned Idx : seq(LNumOps)) {
      int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
                                          RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  return 0;
}

// Return negative, zero, or positive, if LHS is less than, equal to, or
// greater than RHS, respectively. A three-way result allows recursive
// comparisons to be more efficient.
// If the max analysis depth was reached, return std::nullopt, since we do not
// know for sure whether the two operands are equivalent.
static std::optional<int>
CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  if (Depth > MaxSCEVCompareDepth)
    return std::nullopt;

  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (LType) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    int X =
        CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
    return X;
  }

  case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    return LA.ult(RA) ? -1 : 1;
  }

  case scVScale: {
    const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
    const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
    return LTy->getBitWidth() - RTy->getBitWidth();
  }

  case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

    // There is always a dominance between two recs that are used by one SCEV,
    // so we can safely sort recs by loop header dominance. We require such
    // order in getAddExpr.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
      assert(LHead != RHead && "Two loops share the same header?");
      if (DT.dominates(LHead, RHead))
        return 1;
      assert(DT.dominates(RHead, LHead) &&
             "No dominance between recurrences used by one SCEV?");
      return -1;
    }

    [[fallthrough]];
  }

  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
  case scSequentialUMinExpr: {
    ArrayRef<const SCEV *> LOps = LHS->operands();
    ArrayRef<const SCEV *> ROps = RHS->operands();

    // Lexicographically compare n-ary-like expressions.
    unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
      if (X != 0)
        return X;
    }
    return 0;
  }

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value. When this routine is
/// finished, we know that any duplicates in the vector are consecutive and that
/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine. In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
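///
/// For example (illustrative): for an add with operands [%b, %a, %b], grouping
/// makes the duplicates consecutive, e.g. [%a, %b, %b], so getAddExpr can
/// combine them into %a + 2 * %b; likewise (%a + %b) and (%b + %a) end up with
/// the same canonical operand order and unique to a single SCEV.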
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI, DominatorTree &DT) {
  if (Ops.size() < 2) return;  // Noop

  // Whether LHS has provably less complexity than RHS.
  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
    auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
    return Complexity && *Complexity < 0;
  };
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (IsLessComplex(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
    return IsLessComplex(LHS, RHS);
  });

  // Now that we are sorted by complexity, group elements of the same
  // complexity. Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice. Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;  // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}

/// Returns true if \p Ops contains a huge SCEV (an expression whose subtree
/// contains at least HugeExprThreshold nodes).
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  return any_of(Ops, [](const SCEV *S) {
    return S->getExpressionSize() >= HugeExprThreshold;
  });
}

/// Performs a number of common optimizations on the passed \p Ops. If the
/// whole expression reduces down to a single operand, it will be returned.
///
/// The following optimizations are performed:
/// * Fold constants using the \p Fold function.
/// * Remove identity constants satisfying \p IsIdentity.
/// * If a constant satisfies \p IsAbsorber, return it.
/// * Sort operands by complexity.
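///
/// For example (illustrative), for an add with Ops = [2, %x, 3]: the constants
/// fold to 5, which is neither the add identity (0) nor an absorber, so Ops
/// becomes [5, %x] and nullptr is returned for the caller to finish building
/// the expression.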
template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
static const SCEV *
constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
                        SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
                        IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
  const SCEVConstant *Folded = nullptr;
  for (unsigned Idx = 0; Idx < Ops.size();) {
    const SCEV *Op = Ops[Idx];
    if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
      if (!Folded)
        Folded = C;
      else
        Folded = cast<SCEVConstant>(
            SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
      Ops.erase(Ops.begin() + Idx);
      continue;
    }
    ++Idx;
  }

  if (Ops.empty()) {
    assert(Folded && "Must have folded value");
    return Folded;
  }

  if (Folded && IsAbsorber(Folded->getAPInt()))
    return Folded;

  GroupByComplexity(Ops, &LI, DT);
  if (Folded && !IsIdentity(Folded->getAPInt()))
    Ops.insert(Ops.begin(), Folded);

  return Ops.size() == 1 ? Ops[0] : nullptr;
}

//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// Compute BC(It, K). The result has width W. Assumes K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value. We must be prepared for
  // overflow. Hence, we must ensure that the result of our computation is
  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula. However,
  // this formula can be implemented much more efficiently. The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic. To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse. Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T. The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits. This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula requires less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
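  //
  // For example (illustrative numbers): with K = 3 and W = 32, K! = 6 =
  // 2^1 * 3, so T = 1 and the odd factor is 3. The product
  // It * (It - 1) * (It - 2) is computed at W + T = 33 bits, shifted right by
  // T = 1, truncated back to 32 bits, and multiplied by 0xAAAAAAAB, the
  // multiplicative inverse of 3 modulo 2^32 (3 * 0xAAAAAAAB == 2^33 + 1).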
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway. We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    unsigned TwoFactors = countr_zero(i);
    T += TwoFactors;
    OddFactorial *= (i >> TwoFactors);
  }

  // We need at least W + T bits for the multiplication step.
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt MultiplyFactor = OddFactorial.multiplicativeInverse();

  // Calculate the product, at width T+W.
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T.
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.
  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// Return the value of this chain of recurrences at the specified iteration
/// number. We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it. In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
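///
/// For example (illustrative): {5,+,3,+,2} evaluated at iteration n is
/// 5*BC(n, 0) + 3*BC(n, 1) + 2*BC(n, 2) = 5 + 3n + n(n-1) = n^2 + 2n + 5,
/// i.e. 5, 8, 13, 20, ... for n = 0, 1, 2, 3.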
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  return evaluateAtIteration(operands(), It, SE);
}

const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
                                    const SCEV *It, ScalarEvolution &SE) {
  assert(Operands.size() > 0);
  const SCEV *Result = Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
                                                     unsigned Depth) {
  assert(Depth <= 1 &&
         "getLosslessPtrToIntExpr() should self-recurse at most once.");

  // We could be called with an integer-typed operand during SCEV rewrites.
  // Since the operand is an integer already, just perform zext/trunc/self cast.
  if (!Op->getType()->isPointerTy())
    return Op;

  // What would be an ID for such a SCEV cast expression?
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToInt);
  ID.AddPointer(Op);

  void *IP = nullptr;

  // Is there already an expression for such a cast?
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // If not, is this expression something we can't reduce any further?
  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
    // Perform some basic constant folding. If the operand of the ptr2int cast
    // is a null pointer, don't create a ptr2int SCEV expression (that will be
    // left as-is), but produce a zero constant.
    // NOTE: We could handle a more general case, but lack motivational cases.
    if (isa<ConstantPointerNull>(U->getValue()))
      return getZero(IntPtrTy);

    // Create an explicit cast node.
    // We can reuse the existing insert position since if we get here,
    // we won't have made any changes which would invalidate it.
    SCEV *S = new (SCEVAllocator)
        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
                       "non-SCEVUnknowns.");

  // Otherwise, we've got some expression that is more complex than just a
  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
  // arbitrary expression; we want to have SCEVPtrToIntExpr of an SCEVUnknown
  // only, and the expressions must otherwise be integer-typed.
  // So sink the cast down to the SCEVUnknowns.

  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
  /// which computes a pointer-typed value, and rewrites the whole expression
  /// tree so that *all* the computations are done on integers, and the only
  /// pointer-typed operands in the expression are SCEVUnknown.
  class SCEVPtrToIntSinkingRewriter
      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;

  public:
    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}

    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
      SCEVPtrToIntSinkingRewriter Rewriter(SE);
      return Rewriter.visit(Scev);
    }

    const SCEV *visit(const SCEV *S) {
      Type *STy = S->getType();
      // If the expression is not pointer-typed, just keep it as-is.
      if (!STy->isPointerTy())
        return S;
      // Else, recursively sink the cast down into it.
      return Base::visit(S);
    }

    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (const auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (const auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      assert(Expr->getType()->isPointerTy() &&
             "Should only reach pointer-typed SCEVUnknowns.");
      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
    }
  };
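
  // For example (illustrative): rewriting the pointer-typed SCEV
  // (%base + 8 * %i) produces ((8 * %i) + (ptrtoint %base)), leaving the
  // cast only on the SCEVUnknown %base while all other computation is
  // integer-typed.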
  // And actually perform the cast sinking.
  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}

const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  return getTruncateOrZeroExtend(IntOp, Ty);
}

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that ID was inserted into the cache by one of the recursive
    // calls above. So if we find it now, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRangeMin(Step));
  }
  return nullptr;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop
// does not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRangeMax(Step));
}
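
// For example (illustrative): for an 8-bit recurrence with a known-positive
// step whose maximum value is 1, getSignedOverflowLimitForStep computes
// INT8_MIN - 1, which wraps to 127, with predicate SLT: as long as the
// recurrence is <s 127 before an increment, adding the step cannot
// sign-overflow.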

namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;

} // end anonymous namespace

// The recurrence AR has been shown to have no signed/unsigned wrap or
// something close to it. Typically, if we can prove NSW/NUW for AR, then we
// can just as easily prove NSW/NUW for its preincrement or postincrement
// sibling. This allows normalizing a sign/zero extended AddRec as such:
//   {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list. Note that
  // SA might have repeated ops, like %a + %a + ..., so only remove one.
  SmallVector<const SCEV *, 4> DiffOps(SA->operands());
  for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
    if (*It == Step) {
      DiffOps.erase(It);
      break;
    }

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
      ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}

// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE,
                                        unsigned Depth) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
                                             Depth),
                        (SE->*GetExtendExpr)(PreStart, Ty, Depth));
}

// Try to prove away overflow by looking at "nearby" add recurrences. A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//  If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//  If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//  If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration.
// Therefore we only need to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here. It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
        return true;
    }
  }

  return false;
}

// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  return APInt(BitWidth, 0);
}

// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
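// For example (illustrative): for {11,+,4} the step has two known trailing
// zero bits, so D = 11 mod 4 = 3 and {11,+,4} can be treated as 3 + {8,+,4};
// every value of {8,+,4} is a multiple of 4, so adding D < 4 only fills in
// the low two bits and cannot wrap, signed or unsigned.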
1522 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, 1523 const APInt &ConstantStart, 1524 const SCEV *Step) { 1525 const unsigned BitWidth = ConstantStart.getBitWidth(); 1526 const uint32_t TZ = SE.getMinTrailingZeros(Step); 1527 if (TZ) 1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth) 1529 : ConstantStart; 1530 return APInt(BitWidth, 0); 1531 } 1532 1533 static void insertFoldCacheEntry( 1534 const ScalarEvolution::FoldID &ID, const SCEV *S, 1535 DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache, 1536 DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>> 1537 &FoldCacheUser) { 1538 auto I = FoldCache.insert({ID, S}); 1539 if (!I.second) { 1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache 1541 // entry. 1542 auto &UserIDs = FoldCacheUser[I.first->second]; 1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs"); 1544 for (unsigned I = 0; I != UserIDs.size(); ++I) 1545 if (UserIDs[I] == ID) { 1546 std::swap(UserIDs[I], UserIDs.back()); 1547 break; 1548 } 1549 UserIDs.pop_back(); 1550 I.first->second = S; 1551 } 1552 FoldCacheUser[S].push_back(ID); 1553 } 1554 1555 const SCEV * 1556 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { 1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && 1558 "This is not an extending conversion!"); 1559 assert(isSCEVable(Ty) && 1560 "This is not a conversion to a SCEVable type!"); 1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!"); 1562 Ty = getEffectiveSCEVType(Ty); 1563 1564 FoldID ID(scZeroExtend, Op, Ty); 1565 if (const SCEV *S = FoldCache.lookup(ID)) 1566 return S; 1567 1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth); 1569 if (!isa<SCEVZeroExtendExpr>(S)) 1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser); 1571 return S; 1572 } 1573 1574 const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, 1575 unsigned Depth) { 1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && 1577 "This is not an extending conversion!"); 1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); 1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!"); 1580 1581 // Fold if the operand is constant. 1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) 1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty))); 1584 1585 // zext(zext(x)) --> zext(x) 1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) 1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); 1588 1589 // Before doing any expensive analysis, check to see if we've already 1590 // computed a SCEV for this Op and Ty. 1591 FoldingSetNodeID ID; 1592 ID.AddInteger(scZeroExtend); 1593 ID.AddPointer(Op); 1594 ID.AddPointer(Ty); 1595 void *IP = nullptr; 1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 1597 if (Depth > MaxCastDepth) { 1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), 1599 Op, Ty); 1600 UniqueSCEVs.InsertNode(S, IP); 1601 registerUser(S, Op); 1602 return S; 1603 } 1604 1605 // zext(trunc(x)) --> zext(x) or x or trunc(x) 1606 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { 1607 // It's possible the bits taken off by the truncate were all zero bits. If 1608 // so, we should be able to simplify this further. 
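    // For example, if X is an i16 with unsigned range [0, 100), truncating to
    // i8 loses no bits: the truncated-then-extended range equals the original
    // range, so zext(trunc(X)) collapses to a single trunc-or-zext of X.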
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty, Depth);
  }

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants). This allows analysis of something like
  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap()) {
        Start =
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for overflow.

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
            CastedMaxBECount, MaxBECount->getType(), Depth);
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
                                        SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
          const SCEV *WideMaxBECount =
              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
          const SCEV *OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getZeroExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getSignExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }

      // Normally, in the cases we can prove no-overflow via a
      // backedge guarding condition, we can also compute a backedge
      // taken count for the loop. The exceptions are assumptions and
      // guards present in the loop -- SCEV is not great at exploiting
      // these to compute max backedge taken counts, but can still use
      // these to prove lack of overflow. Use this fact to avoid
      // doing extra work that may not pay off.
      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
          !AC.assumptions().empty()) {

        auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
        if (AR->hasNoUnsignedWrap()) {
          // Same as nuw case above - duplicated here to avoid a compile time
          // issue. It's not clear that the order of checks matters, but it's
          // one of two possible causes for a change which was reverted. Be
          // conservative for the moment.
          Start =
              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
          Step = getZeroExtendExpr(Step, Ty, Depth + 1);
          return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
        }

        // For a negative step, we can extend the operands iff doing so only
        // traverses values in the range zext([0,UINT_MAX]).
        if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRangeMin(Step));
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
            // Cache knowledge of AR NW, which is propagated to this
            // AddRec. Negative step causes unsigned wrap, but it
            // still can't self-wrap.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }

      // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
      // if D + (C - D + Step * n) could be proven to not unsigned wrap
      // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SZExtD, SZExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }

      if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
        Start =
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
    }

  // zext(A % B) --> zext(A) % zext(B)
  {
    const SCEV *LHS;
    const SCEV *RHS;
    if (matchURem(Op, LHS, RHS))
      return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
                         getZeroExtendExpr(RHS, Ty, Depth + 1));
  }

  // zext(A / B) --> zext(A) / zext(B).
  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
    return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
                       getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));

  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
    // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
    if (SA->hasNoUnsignedWrap()) {
      // If the addition does not unsign overflow then we can, by definition,
      // commute the zero extension with the addition operation.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SA->operands())
        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }

    // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
    // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    //
    // Address arithmetic often contains expressions like
    // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
    // This transformation is useful while proving that such expressions are
    // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
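    // In the example above, the trailing-zero count of (4 * X) is at least 2,
    // so D = 5 & 3 = 1 and zext(5 + (4 * X)) becomes
    // (zext(1) + zext(4 + (4 * X)))<nuw><nsw>: the residual keeps its low two
    // bits zero, so adding D = 1 back on the outside cannot carry.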
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) { 1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); 1808 if (D != 0) { 1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); 1810 const SCEV *SResidual = 1811 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); 1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); 1813 return getAddExpr(SZExtD, SZExtR, 1814 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), 1815 Depth + 1); 1816 } 1817 } 1818 } 1819 1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) { 1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw> 1822 if (SM->hasNoUnsignedWrap()) { 1823 // If the multiply does not unsign overflow then we can, by definition, 1824 // commute the zero extension with the multiply operation. 1825 SmallVector<const SCEV *, 4> Ops; 1826 for (const auto *Op : SM->operands()) 1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); 1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1); 1829 } 1830 1831 // zext(2^K * (trunc X to iN)) to iM -> 1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw> 1833 // 1834 // Proof: 1835 // 1836 // zext(2^K * (trunc X to iN)) to iM 1837 // = zext((trunc X to iN) << K) to iM 1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM 1839 // (because shl removes the top K bits) 1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM 1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>. 1842 // 1843 if (SM->getNumOperands() == 2) 1844 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0))) 1845 if (MulLHS->getAPInt().isPowerOf2()) 1846 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) { 1847 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) - 1848 MulLHS->getAPInt().logBase2(); 1849 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); 1850 return getMulExpr( 1851 getZeroExtendExpr(MulLHS, Ty), 1852 getZeroExtendExpr( 1853 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty), 1854 SCEV::FlagNUW, Depth + 1); 1855 } 1856 } 1857 1858 // zext(umin(x, y)) -> umin(zext(x), zext(y)) 1859 // zext(umax(x, y)) -> umax(zext(x), zext(y)) 1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) { 1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op); 1862 SmallVector<const SCEV *, 4> Operands; 1863 for (auto *Operand : MinMax->operands()) 1864 Operands.push_back(getZeroExtendExpr(Operand, Ty)); 1865 if (isa<SCEVUMinExpr>(MinMax)) 1866 return getUMinExpr(Operands); 1867 return getUMaxExpr(Operands); 1868 } 1869 1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y)) 1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) { 1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!"); 1873 SmallVector<const SCEV *, 4> Operands; 1874 for (auto *Operand : MinMax->operands()) 1875 Operands.push_back(getZeroExtendExpr(Operand, Ty)); 1876 return getUMinExpr(Operands, /*Sequential*/ true); 1877 } 1878 1879 // The cast wasn't folded; create an explicit cast node. 1880 // Recompute the insert position, as it may have been invalidated. 
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), 1883 Op, Ty); 1884 UniqueSCEVs.InsertNode(S, IP); 1885 registerUser(S, Op); 1886 return S; 1887 } 1888 1889 const SCEV * 1890 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { 1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && 1892 "This is not an extending conversion!"); 1893 assert(isSCEVable(Ty) && 1894 "This is not a conversion to a SCEVable type!"); 1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!"); 1896 Ty = getEffectiveSCEVType(Ty); 1897 1898 FoldID ID(scSignExtend, Op, Ty); 1899 if (const SCEV *S = FoldCache.lookup(ID)) 1900 return S; 1901 1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth); 1903 if (!isa<SCEVSignExtendExpr>(S)) 1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser); 1905 return S; 1906 } 1907 1908 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, 1909 unsigned Depth) { 1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && 1911 "This is not an extending conversion!"); 1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); 1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!"); 1914 Ty = getEffectiveSCEVType(Ty); 1915 1916 // Fold if the operand is constant. 1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) 1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty))); 1919 1920 // sext(sext(x)) --> sext(x) 1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) 1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1); 1923 1924 // sext(zext(x)) --> zext(x) 1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) 1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); 1927 1928 // Before doing any expensive analysis, check to see if we've already 1929 // computed a SCEV for this Op and Ty. 1930 FoldingSetNodeID ID; 1931 ID.AddInteger(scSignExtend); 1932 ID.AddPointer(Op); 1933 ID.AddPointer(Ty); 1934 void *IP = nullptr; 1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 1936 // Limit recursion depth. 1937 if (Depth > MaxCastDepth) { 1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), 1939 Op, Ty); 1940 UniqueSCEVs.InsertNode(S, IP); 1941 registerUser(S, Op); 1942 return S; 1943 } 1944 1945 // sext(trunc(x)) --> sext(x) or x or trunc(x) 1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { 1947 // It's possible the bits taken off by the truncate were all sign bits. If 1948 // so, we should be able to simplify this further. 1949 const SCEV *X = ST->getOperand(); 1950 ConstantRange CR = getSignedRange(X); 1951 unsigned TruncBits = getTypeSizeInBits(ST->getType()); 1952 unsigned NewBits = getTypeSizeInBits(Ty); 1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains( 1954 CR.sextOrTrunc(NewBits))) 1955 return getTruncateOrSignExtend(X, Ty, Depth); 1956 } 1957 1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { 1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> 1960 if (SA->hasNoSignedWrap()) { 1961 // If the addition does not sign overflow then we can, by definition, 1962 // commute the sign extension with the addition operation. 
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SA->operands())
        Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
    }

    // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
    // if D + (C - D + x + y + ...) could be proven to not signed wrap
    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    //
    // For instance, this will bring two seemingly different expressions:
    //     1 + sext(5 + 20 * %x + 24 * %y)  and
    //         sext(6 + 20 * %x + 24 * %y)
    // to the same form:
    //     2 + sext(4 + 20 * %x + 24 * %y)
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SSExtD, SSExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }
  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants). This allows analysis of something like
  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoSignedWrap()) {
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
      }

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
            CastedMaxBECount, MaxBECount->getType(), Depth);
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
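          // That is, compare the narrow sum, sign-extended after the fact,
          // against the same sum recomputed with each operand first extended
          // to a type twice as wide; if the two agree, the narrow computation
          // cannot have signed-overflowed.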
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
                                        SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
          const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
          const SCEV *WideMaxBECount =
              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
          const SCEV *OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getSignExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (SAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NSW, which is propagated to this AddRec.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getSignExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as unsigned.
          // This covers loops that count up with an unsigned step.
          OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (SAdd == OperandExtendedAdd) {
            // If AR wraps around then
            //
            //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
            // => SAdd != OperandExtendedAdd
            //
            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
            // (SAdd == OperandExtendedAdd => AR is NW)

            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);

            // Return the expression with the addrec on the outside.
            Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                             Depth + 1);
            Step = getZeroExtendExpr(Step, Ty, Depth + 1);
            return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
          }
        }
      }

      auto NewFlags = proveNoSignedWrapViaInduction(AR);
      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      if (AR->hasNoSignedWrap()) {
        // Same as nsw case above - duplicated here to avoid a compile time
        // issue. It's not clear that the order of checks matters, but it's
        // one of two possible causes for a change which was reverted. Be
        // conservative for the moment.
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }

      // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
      // if D + (C - D + Step * n) could be proven to not signed wrap
      // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SSExtD, SSExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }

      if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
        Start =
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
        Step = getSignExtendExpr(Step, Ty, Depth + 1);
        return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
      }
    }

  // If the input value is provably non-negative and we could not simplify
  // away the sext, build a zext instead.
  if (isKnownNonNegative(Op))
    return getZeroExtendExpr(Op, Ty, Depth + 1);

  // sext(smin(x, y)) -> smin(sext(x), sext(y))
  // sext(smax(x, y)) -> smax(sext(x), sext(y))
  if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
    auto *MinMax = cast<SCEVMinMaxExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    for (auto *Operand : MinMax->operands())
      Operands.push_back(getSignExtendExpr(Operand, Ty));
    if (isa<SCEVSMinExpr>(MinMax))
      return getSMinExpr(Operands);
    return getSMaxExpr(Operands);
  }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, { Op });
  return S;
}

const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
                                         Type *Ty) {
  switch (Kind) {
  case scTruncate:
    return getTruncateExpr(Op, Ty);
  case scZeroExtend:
    return getZeroExtendExpr(Op, Ty);
  case scSignExtend:
    return getSignExtendExpr(Op, Ty);
  case scPtrToInt:
    return getPtrToIntExpr(Op, Ty);
  default:
    llvm_unreachable("Not a SCEV cast expression!");
  }
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                              Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getAPInt().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // Force the cast to be folded into the operands of an addrec.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Ops;
    for (const SCEV *Op : AR->operands())
      Ops.push_back(getAnyExtendExpr(Op, Ty));
    return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
  }

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// Process the given Ops list, which is a list of operands to be added under
/// the given scale, and update the given map. This is a helper function for
/// getAddExpr. As an example of what it does, given a sequence of operands
/// that would form an add expression like this:
///
///      m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///      (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///      13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
static bool
CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
                             SmallVectorImpl<const SCEV *> &NewOps,
                             APInt &AccumulatedConstant,
                             ArrayRef<const SCEV *> Ops, const APInt &Scale,
                             ScalarEvolution &SE) {
  bool Interesting = false;

  // Iterate over the add operands. They are sorted, with constants first.
  unsigned i = 0;
  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
    ++i;
    // Pull a buried constant out to the outside.
    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
      Interesting = true;
    AccumulatedConstant += Scale * C->getAPInt();
  }

  // Next comes everything else. We're especially interested in multiplies
  // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) { 2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]); 2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) { 2262 APInt NewScale = 2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt(); 2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) { 2265 // A multiplication of a constant with another add; recurse. 2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1)); 2267 Interesting |= 2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, 2269 Add->operands(), NewScale, SE); 2270 } else { 2271 // A multiplication of a constant with some other value. Update 2272 // the map. 2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands())); 2274 const SCEV *Key = SE.getMulExpr(MulOps); 2275 auto Pair = M.insert({Key, NewScale}); 2276 if (Pair.second) { 2277 NewOps.push_back(Pair.first->first); 2278 } else { 2279 Pair.first->second += NewScale; 2280 // The map already had an entry for this value, which may indicate 2281 // a folding opportunity. 2282 Interesting = true; 2283 } 2284 } 2285 } else { 2286 // An ordinary operand. Update the map. 2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = 2288 M.insert({Ops[i], Scale}); 2289 if (Pair.second) { 2290 NewOps.push_back(Pair.first->first); 2291 } else { 2292 Pair.first->second += Scale; 2293 // The map already had an entry for this value, which may indicate 2294 // a folding opportunity. 2295 Interesting = true; 2296 } 2297 } 2298 } 2299 2300 return Interesting; 2301 } 2302 2303 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, 2304 const SCEV *LHS, const SCEV *RHS, 2305 const Instruction *CtxI) { 2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *, 2307 SCEV::NoWrapFlags, unsigned); 2308 switch (BinOp) { 2309 default: 2310 llvm_unreachable("Unsupported binary op"); 2311 case Instruction::Add: 2312 Operation = &ScalarEvolution::getAddExpr; 2313 break; 2314 case Instruction::Sub: 2315 Operation = &ScalarEvolution::getMinusSCEV; 2316 break; 2317 case Instruction::Mul: 2318 Operation = &ScalarEvolution::getMulExpr; 2319 break; 2320 } 2321 2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) = 2323 Signed ? &ScalarEvolution::getSignExtendExpr 2324 : &ScalarEvolution::getZeroExtendExpr; 2325 2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS) 2327 auto *NarrowTy = cast<IntegerType>(LHS->getType()); 2328 auto *WideTy = 2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2); 2330 2331 const SCEV *A = (this->*Extension)( 2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0); 2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0); 2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0); 2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0); 2336 if (A == B) 2337 return true; 2338 // Can we use context to prove the fact we need? 2339 if (!CtxI) 2340 return false; 2341 // TODO: Support mul. 2342 if (BinOp == Instruction::Mul) 2343 return false; 2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS); 2345 // TODO: Lift this limitation. 2346 if (!RHSC) 2347 return false; 2348 APInt C = RHSC->getAPInt(); 2349 unsigned NumBits = C.getBitWidth(); 2350 bool IsSub = (BinOp == Instruction::Sub); 2351 bool IsNegativeConst = (Signed && C.isNegative()); 2352 // Compute the direction and magnitude by which we need to check overflow. 
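  // For example, (X - 5) and a signed (X + (-5)) both move the value down,
  // while a signed (X - (-5)) is really an addition and moves it up; the XOR
  // below captures exactly this case analysis.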
2353 bool OverflowDown = IsSub ^ IsNegativeConst; 2354 APInt Magnitude = C; 2355 if (IsNegativeConst) { 2356 if (C == APInt::getSignedMinValue(NumBits)) 2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't 2358 // want to deal with that. 2359 return false; 2360 Magnitude = -C; 2361 } 2362 2363 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; 2364 if (OverflowDown) { 2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS. 2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits) 2367 : APInt::getMinValue(NumBits); 2368 APInt Limit = Min + Magnitude; 2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI); 2370 } else { 2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude. 2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits) 2373 : APInt::getMaxValue(NumBits); 2374 APInt Limit = Max - Magnitude; 2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI); 2376 } 2377 } 2378 2379 std::optional<SCEV::NoWrapFlags> 2380 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp( 2381 const OverflowingBinaryOperator *OBO) { 2382 // It cannot be done any better. 2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap()) 2384 return std::nullopt; 2385 2386 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap; 2387 2388 if (OBO->hasNoUnsignedWrap()) 2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2390 if (OBO->hasNoSignedWrap()) 2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); 2392 2393 bool Deduced = false; 2394 2395 if (OBO->getOpcode() != Instruction::Add && 2396 OBO->getOpcode() != Instruction::Sub && 2397 OBO->getOpcode() != Instruction::Mul) 2398 return std::nullopt; 2399 2400 const SCEV *LHS = getSCEV(OBO->getOperand(0)); 2401 const SCEV *RHS = getSCEV(OBO->getOperand(1)); 2402 2403 const Instruction *CtxI = 2404 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr; 2405 if (!OBO->hasNoUnsignedWrap() && 2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(), 2407 /* Signed */ false, LHS, RHS, CtxI)) { 2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2409 Deduced = true; 2410 } 2411 2412 if (!OBO->hasNoSignedWrap() && 2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(), 2414 /* Signed */ true, LHS, RHS, CtxI)) { 2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); 2416 Deduced = true; 2417 } 2418 2419 if (Deduced) 2420 return Flags; 2421 return std::nullopt; 2422 } 2423 2424 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and 2425 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of 2426 // can't-overflow flags for the operation if possible. 2427 static SCEV::NoWrapFlags 2428 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, 2429 const ArrayRef<const SCEV *> Ops, 2430 SCEV::NoWrapFlags Flags) { 2431 using namespace std::placeholders; 2432 2433 using OBO = OverflowingBinaryOperator; 2434 2435 bool CanAnalyze = 2436 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; 2437 (void)CanAnalyze; 2438 assert(CanAnalyze && "don't call from other places!"); 2439 2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; 2441 SCEV::NoWrapFlags SignOrUnsignWrap = 2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); 2443 2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. 
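  // Rationale: non-negative values mean the same thing signed and unsigned,
  // so if a sum of non-negative operands cannot exceed the signed maximum, it
  // cannot wrap around the unsigned range either.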
2445 auto IsKnownNonNegative = [&](const SCEV *S) { 2446 return SE->isKnownNonNegative(S); 2447 }; 2448 2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) 2450 Flags = 2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); 2452 2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); 2454 2455 if (SignOrUnsignWrap != SignOrUnsignMask && 2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 && 2457 isa<SCEVConstant>(Ops[0])) { 2458 2459 auto Opcode = [&] { 2460 switch (Type) { 2461 case scAddExpr: 2462 return Instruction::Add; 2463 case scMulExpr: 2464 return Instruction::Mul; 2465 default: 2466 llvm_unreachable("Unexpected SCEV op."); 2467 } 2468 }(); 2469 2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt(); 2471 2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow. 2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { 2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 2475 Opcode, C, OBO::NoSignedWrap); 2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) 2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); 2478 } 2479 2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow. 2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { 2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 2483 Opcode, C, OBO::NoUnsignedWrap); 2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) 2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2486 } 2487 } 2488 2489 // <0,+,nonnegative><nw> is also nuw 2490 // TODO: Add corresponding nsw case 2491 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) && 2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 && 2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1])) 2494 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2495 2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW 2497 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && 2498 Ops.size() == 2) { 2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0])) 2500 if (UDiv->getOperand(1) == Ops[1]) 2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1])) 2503 if (UDiv->getOperand(1) == Ops[0]) 2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2505 } 2506 2507 return Flags; 2508 } 2509 2510 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { 2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader()); 2512 } 2513 2514 /// Get a canonical add expression, or something simpler if possible. 
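/// For example, the constants in (2 + %x + 3) fold to (5 + %x), and repeated
/// operands become multiplies: (%x + %y + %x) simplifies to ((2 * %x) + %y).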
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty add!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "SCEVAddExpr operand types don't match!");
  unsigned NumPtrs = count_if(
      Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
  assert(NumPtrs <= 1 && "add has at most one pointer operand");
#endif

  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [](const APInt &C1, const APInt &C2) { return C1 + C2; },
      [](const APInt &C) { return C.isZero(); }, // identity
      [](const APInt &C) { return false; });     // absorber
  if (Folded)
    return Folded;

  unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
    return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
  };

  // Limit recursion depth.
  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
    return getOrCreateAddExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
    if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
      Add->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  // Okay, check to see if the same value occurs in the operand list more than
  // once. If so, merge them together into a multiply expression. Since we
  // sorted the list, these values are required to be adjacent.
  Type *Ty = Ops[0]->getType();
  bool FoundMatch = false;
  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
    if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
      // Scan ahead to count how many equal operands there are.
      unsigned Count = 2;
      while (i+Count != e && Ops[i+Count] == Ops[i])
        ++Count;
      // Merge the values into a multiply.
      const SCEV *Scale = getConstant(Ty, Count);
      const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
      if (Ops.size() == Count)
        return Mul;
      Ops[i] = Mul;
      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
      --i; e -= Count - 1;
      FoundMatch = true;
    }
  if (FoundMatch)
    return getAddExpr(Ops, OrigFlags, Depth + 1);

  // Check for truncates. If all the operands are truncated from the same
  // type, see if factoring out the truncate would permit the result to be
  // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
  // if the contents of the resulting outer trunc fold to something simple.
  auto FindTruncSrcType = [&]() -> Type * {
    // We're ultimately looking to fold an addrec of truncs and muls of only
    // constants and truncs, so if we find any other types of SCEV
    // as operands of the addrec then we bail and return nullptr here.
    // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx])) 2593 return T->getOperand()->getType(); 2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { 2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1); 2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp)) 2597 return T->getOperand()->getType(); 2598 } 2599 return nullptr; 2600 }; 2601 if (auto *SrcType = FindTruncSrcType()) { 2602 SmallVector<const SCEV *, 8> LargeOps; 2603 bool Ok = true; 2604 // Check all the operands to see if they can be represented in the 2605 // source type of the truncate. 2606 for (const SCEV *Op : Ops) { 2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) { 2608 if (T->getOperand()->getType() != SrcType) { 2609 Ok = false; 2610 break; 2611 } 2612 LargeOps.push_back(T->getOperand()); 2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) { 2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType)); 2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) { 2616 SmallVector<const SCEV *, 8> LargeMulOps; 2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { 2618 if (const SCEVTruncateExpr *T = 2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) { 2620 if (T->getOperand()->getType() != SrcType) { 2621 Ok = false; 2622 break; 2623 } 2624 LargeMulOps.push_back(T->getOperand()); 2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) { 2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); 2627 } else { 2628 Ok = false; 2629 break; 2630 } 2631 } 2632 if (Ok) 2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1)); 2634 } else { 2635 Ok = false; 2636 break; 2637 } 2638 } 2639 if (Ok) { 2640 // Evaluate the expression in the larger type. 2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); 2642 // If it folds to something simple, use it. Otherwise, don't. 2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) 2644 return getTruncateExpr(Fold, Ty); 2645 } 2646 } 2647 2648 if (Ops.size() == 2) { 2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and 2650 // C2 can be folded in a way that allows retaining wrapping flags of (X + 2651 // C1). 2652 const SCEV *A = Ops[0]; 2653 const SCEV *B = Ops[1]; 2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B); 2655 auto *C = dyn_cast<SCEVConstant>(A); 2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) { 2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt(); 2658 auto C2 = C->getAPInt(); 2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap; 2660 2661 APInt ConstAdd = C1 + C2; 2662 auto AddFlags = AddExpr->getNoWrapFlags(); 2663 // Adding a smaller constant is NUW if the original AddExpr was NUW. 2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) && 2665 ConstAdd.ule(C1)) { 2666 PreservedFlags = 2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW); 2668 } 2669 2670 // Adding a constant with the same sign and small magnitude is NSW, if the 2671 // original AddExpr was NSW. 
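      // For example, if (%x + 10)<nsw> and we fold in -3, the result
      // (%x + 7) is still <nsw>: 7 has the same sign as 10 and no larger
      // magnitude, so it stays inside the range already proven not to wrap.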
      if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
          C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
          ConstAdd.abs().ule(C1.abs())) {
        PreservedFlags =
            ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
      }

      if (PreservedFlags != SCEV::FlagAnyWrap) {
        SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
        NewOps[0] = getConstant(ConstAdd);
        return getAddExpr(NewOps, PreservedFlags);
      }
    }
  }

  // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
  if (Ops.size() == 2) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
    if (Mul && Mul->getNumOperands() == 2 &&
        Mul->getOperand(0)->isAllOnesValue()) {
      const SCEV *X;
      const SCEV *Y;
      if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
        return getMulExpr(Y, getUDivExpr(X, Y));
      }
    }
  }

  // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;

  // If there are add operands they would be next.
  if (Idx < Ops.size()) {
    bool DeletedAdd = false;
    // If the original flags and all inlined SCEVAddExprs are NUW, use the
    // common NUW flag for the expression after inlining. Other flags cannot
    // be preserved, because they may depend on the original order of
    // operations.
    SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
      if (Ops.size() > AddOpsInlineThreshold ||
          Add->getNumOperands() > AddOpsInlineThreshold)
        break;
      // If we have an add, expand the add operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Add->operands());
      DeletedAdd = true;
      CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
    }

    // If we deleted at least one add, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to re-sort and
    // re-simplify any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops, CommonFlags, Depth + 1);
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // Check to see if there are any folding opportunities present with
  // operands multiplied by constant values.
  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
    uint64_t BitWidth = getTypeSizeInBits(Ty);
    SmallDenseMap<const SCEV *, APInt, 16> M;
    SmallVector<const SCEV *, 8> NewOps;
    APInt AccumulatedConstant(BitWidth, 0);
    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                     Ops, APInt(BitWidth, 1), *this)) {
      struct APIntCompare {
        bool operator()(const APInt &LHS, const APInt &RHS) const {
          return LHS.ult(RHS);
        }
      };

      // Some interesting folding opportunity is present, so it's worthwhile to
      // re-generate the operands list. Group the operands by constant scale,
      // to avoid multiplying by the same constant scale multiple times.
      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
      for (const SCEV *NewOp : NewOps)
        MulOpLists[M.find(NewOp)->second].push_back(NewOp);
      // Re-generate the operands list.
2756 Ops.clear(); 2757 if (AccumulatedConstant != 0) 2758 Ops.push_back(getConstant(AccumulatedConstant)); 2759 for (auto &MulOp : MulOpLists) { 2760 if (MulOp.first == 1) { 2761 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1)); 2762 } else if (MulOp.first != 0) { 2763 Ops.push_back(getMulExpr( 2764 getConstant(MulOp.first), 2765 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1), 2766 SCEV::FlagAnyWrap, Depth + 1)); 2767 } 2768 } 2769 if (Ops.empty()) 2770 return getZero(Ty); 2771 if (Ops.size() == 1) 2772 return Ops[0]; 2773 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); 2774 } 2775 } 2776 2777 // If we are adding something to a multiply expression, make sure the 2778 // something is not already an operand of the multiply. If so, merge it into 2779 // the multiply. 2780 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) { 2781 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]); 2782 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { 2783 const SCEV *MulOpSCEV = Mul->getOperand(MulOp); 2784 if (isa<SCEVConstant>(MulOpSCEV)) 2785 continue; 2786 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) 2787 if (MulOpSCEV == Ops[AddOp]) { 2788 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) 2789 const SCEV *InnerMul = Mul->getOperand(MulOp == 0); 2790 if (Mul->getNumOperands() != 2) { 2791 // If the multiply has more than two operands, we must get the 2792 // Y*Z term. 2793 SmallVector<const SCEV *, 4> MulOps( 2794 Mul->operands().take_front(MulOp)); 2795 append_range(MulOps, Mul->operands().drop_front(MulOp + 1)); 2796 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); 2797 } 2798 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; 2799 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); 2800 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV, 2801 SCEV::FlagAnyWrap, Depth + 1); 2802 if (Ops.size() == 2) return OuterMul; 2803 if (AddOp < Idx) { 2804 Ops.erase(Ops.begin()+AddOp); 2805 Ops.erase(Ops.begin()+Idx-1); 2806 } else { 2807 Ops.erase(Ops.begin()+Idx); 2808 Ops.erase(Ops.begin()+AddOp-1); 2809 } 2810 Ops.push_back(OuterMul); 2811 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); 2812 } 2813 2814 // Check this multiply against other multiplies being added together. 2815 for (unsigned OtherMulIdx = Idx+1; 2816 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]); 2817 ++OtherMulIdx) { 2818 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]); 2819 // If MulOp occurs in OtherMul, we can fold the two multiplies 2820 // together. 
2821 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); 2822 OMulOp != e; ++OMulOp) 2823 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { 2824 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) 2825 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); 2826 if (Mul->getNumOperands() != 2) { 2827 SmallVector<const SCEV *, 4> MulOps( 2828 Mul->operands().take_front(MulOp)); 2829 append_range(MulOps, Mul->operands().drop_front(MulOp+1)); 2830 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); 2831 } 2832 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); 2833 if (OtherMul->getNumOperands() != 2) { 2834 SmallVector<const SCEV *, 4> MulOps( 2835 OtherMul->operands().take_front(OMulOp)); 2836 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1)); 2837 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); 2838 } 2839 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; 2840 const SCEV *InnerMulSum = 2841 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); 2842 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, 2843 SCEV::FlagAnyWrap, Depth + 1); 2844 if (Ops.size() == 2) return OuterMul; 2845 Ops.erase(Ops.begin()+Idx); 2846 Ops.erase(Ops.begin()+OtherMulIdx-1); 2847 Ops.push_back(OuterMul); 2848 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); 2849 } 2850 } 2851 } 2852 } 2853 2854 // If there are any add recurrences in the operands list, see if any other 2855 // added values are loop invariant. If so, we can fold them into the 2856 // recurrence. 2857 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) 2858 ++Idx; 2859 2860 // Scan over all recurrences, trying to fold loop invariants into them. 2861 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { 2862 // Scan all of the other operands to this add and add them to the vector if 2863 // they are loop invariant w.r.t. the recurrence. 2864 SmallVector<const SCEV *, 8> LIOps; 2865 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); 2866 const Loop *AddRecLoop = AddRec->getLoop(); 2867 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2868 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { 2869 LIOps.push_back(Ops[i]); 2870 Ops.erase(Ops.begin()+i); 2871 --i; --e; 2872 } 2873 2874 // If we found some loop invariants, fold them into the recurrence. 2875 if (!LIOps.empty()) { 2876 // Compute nowrap flags for the addition of the loop-invariant ops and 2877 // the addrec. Temporarily push it as an operand for that purpose. These 2878 // flags are valid in the scope of the addrec only. 2879 LIOps.push_back(AddRec); 2880 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps); 2881 LIOps.pop_back(); 2882 2883 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} 2884 LIOps.push_back(AddRec->getStart()); 2885 2886 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands()); 2887 2888 // It is not in general safe to propagate flags valid on an add within 2889 // the addrec scope to one outside it. We must prove that the inner 2890 // scope is guaranteed to execute if the outer one does to be able to 2891 // safely propagate. We know the program is undefined if poison is 2892 // produced on the inner scoped addrec. We also know that *for this use* 2893 // the outer scoped add can't overflow (because of the flags we just 2894 // computed for the inner scoped add) without the program being undefined. 
      // Proving that entry to the outer scope necessitates entry to the
      // inner scope thus proves the program undefined if the flags would be
      // violated in the outer scope.
      SCEV::NoWrapFlags AddFlags = Flags;
      if (AddFlags != SCEV::FlagAnyWrap) {
        auto *DefI = getDefiningScopeBound(LIOps);
        auto *ReachI = &*AddRecLoop->getHeader()->begin();
        if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
          AddFlags = SCEV::FlagAnyWrap;
      }
      AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);

      // Build the new addrec. Propagate the NUW and NSW flags if both the
      // outer add and the inner addrec are guaranteed to have no overflow.
      // Always propagate NW.
      Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, add the folded AddRec to the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being added together. If so, we can fold them.
    for (unsigned OtherIdx = Idx+1;
         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      // We expect the AddRecExpr's to be sorted in reverse dominance order,
      // so that the 1st found AddRecExpr is dominated by all others.
      assert(DT.dominates(
                 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
                 AddRec->getLoop()->getHeader()) &&
             "AddRecExprs are not sorted in reverse dominance order?");
      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
        // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
        SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
             ++OtherIdx) {
          const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
          if (OtherAddRec->getLoop() == AddRecLoop) {
            for (unsigned i = 0, e = OtherAddRec->getNumOperands();
                 i != e; ++i) {
              if (i >= AddRecOps.size()) {
                append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
                break;
              }
              SmallVector<const SCEV *, 2> TwoOps = {
                  AddRecOps[i], OtherAddRec->getOperand(i)};
              AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
            }
            Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
          }
        }
        // Step size has changed, so we cannot guarantee no self-wraparound.
        Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
        return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
      }
    }

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an add expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
}

const SCEV *
ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
                                    SCEV::NoWrapFlags Flags) {
  FoldingSetNodeID ID;
  ID.AddInteger(scAddExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVAddExpr *S =
      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    llvm::uninitialized_copy(Ops, O);
    S = new (SCEVAllocator)
        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Ops);
  }
  S->setNoWrapFlags(Flags);
  return S;
}

const SCEV *
ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
                                       const Loop *L, SCEV::NoWrapFlags Flags) {
  FoldingSetNodeID ID;
  ID.AddInteger(scAddRecExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  ID.AddPointer(L);
  void *IP = nullptr;
  SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    llvm::uninitialized_copy(Ops, O);
    S = new (SCEVAllocator)
        SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
    UniqueSCEVs.InsertNode(S, IP);
    LoopUsers[L].push_back(S);
    registerUser(S, Ops);
  }
  setNoWrapFlags(S, Flags);
  return S;
}

const SCEV *
ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
                                    SCEV::NoWrapFlags Flags) {
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVMulExpr *S =
      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    llvm::uninitialized_copy(Ops, O);
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Ops);
  }
  S->setNoWrapFlags(Flags);
  return S;
}

static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return
/// will be garbage. Overflow is not cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At each iteration i, we take the i-th term of the numerator, n-(i-1),
  // and divide by the i-th term of the denominator, i. This division will
  // always produce an integral result, and helps reduce the chance of
  // overflow in the intermediate computations. However, we can still
  // overflow even when the final result would fit.
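  // For example, Choose(5, 2) computes r = 1*5/1 = 5, then r = 5*4/2 = 10,
  // which is C(5,2); each intermediate division is exact because the running
  // product r*(n-i+1) always equals i times a binomial coefficient.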

  if (n == 0 || n == k) return 1;
  if (k > n) return 0;

  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}

/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
  struct FindConstantInAddMulChain {
    bool FoundConstant = false;

    bool follow(const SCEV *S) {
      FoundConstant |= isa<SCEVConstant>(S);
      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
    }

    bool isDone() const {
      return FoundConstant;
    }
  };

  FindConstantInAddMulChain F;
  SCEVTraversal<FindConstantInAddMulChain> ST(F);
  ST.visitAll(StartExpr);
  return F.FoundConstant;
}

/// Get a canonical multiply expression, or something simpler if possible.
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = Ops[0]->getType();
  assert(!ETy->isPointerTy());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(Ops[i]->getType() == ETy &&
           "SCEVMulExpr operand types don't match!");
#endif

  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [](const APInt &C1, const APInt &C2) { return C1 * C2; },
      [](const APInt &C) { return C.isOne(); },   // identity
      [](const APInt &C) { return C.isZero(); }); // absorber
  if (Folded)
    return Folded;

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
  };

  // Limit recursion depth.
  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
    return getOrCreateMulExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
    if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    if (Ops.size() == 2) {
      // C1*(C2+V) -> C1*C2 + C1*V
      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
        // If any of Add's ops are Adds or Muls with a constant, apply this
        // transformation as well.
        //
        // TODO: There are some cases where this transformation is not
        // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
        // this transformation should be narrowed down.
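        // For example, 2*(3 + %x) becomes 6 + 2*%x, exposing the constant 6
        // to further folding with surrounding add operands.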
        if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
          const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
                                       SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
                                       SCEV::FlagAnyWrap, Depth + 1);
          return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
        }

      if (Ops[0]->isAllOnesValue()) {
        // If we have a mul by -1 of an add, try distributing the -1 among the
        // add operands.
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          SmallVector<const SCEV *, 4> NewOps;
          bool AnyFolded = false;
          for (const SCEV *AddOp : Add->operands()) {
            const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
                                         Depth + 1);
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
          // Negation preserves a recurrence's no self-wrap property.
          SmallVector<const SCEV *, 4> Operands;
          for (const SCEV *AddRecOp : AddRec->operands())
            Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
                                          Depth + 1));
          // Let M be the minimum representable signed value. An AddRec with
          // nsw multiplied by -1 can have signed overflow if and only if it
          // takes a value of M: M * (-1) would stay M and (M + 1) * (-1)
          // would be the maximum signed value. In all other cases signed
          // overflow is impossible.
          auto FlagsMask = SCEV::FlagNW;
          if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
            auto MinInt =
                APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
            if (getSignedRangeMin(AddRec) != MinInt)
              FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
          }
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(FlagsMask));
        }
      }
    }
  }

  // Skip over the non-multiply expressions until we get to a multiply.
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands, inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      if (Ops.size() > MulOpsInlineThreshold)
        break;
      // If we have a mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Mul->operands());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
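    // For example, in %x * {2,+,3}<L> with %x invariant in L, %x lands in
    // LIOps and the product folds to {2*%x,+,3*%x}<L> below.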
    SmallVector<const SCEV *, 8> LIOps;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      // NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
      SmallVector<const SCEV *, 4> NewOps;
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);

      // If both the mul and addrec are nuw, we can preserve nuw.
      // If both the mul and addrec are nsw, we can only preserve nsw if either
      // a) they are also nuw, or
      // b) all multiplications of addrec operands with scale are nsw.
      SCEV::NoWrapFlags Flags =
          AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));

      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
                                    SCEV::FlagAnyWrap, Depth + 1));

        if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
          ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
              Instruction::Mul, getSignedRange(Scale),
              OverflowingBinaryOperator::NoSignedWrap);
          if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
            Flags = clearFlags(Flags, SCEV::FlagNSW);
        }
      }

      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being multiplied together. If so, we can fold them.

    //   {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
    //       choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
    //     ]]],+,...up to x=2n}.
    // Note that the arguments to choose() are always integers with values
    // known at compile time, never SCEV objects.
    //
    // The implementation avoids pointless extra computations when the two
    // addrec's are of different length (mathematically, it's equivalent to
    // an infinite stream of zeros on the right).
    bool OpsModified = false;
    for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
          dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
        continue;

      // Limit max number of arguments to avoid creation of unreasonably big
      // SCEVAddRecs with very complex operands.
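      // Even a small instance of the formula grows the recurrence: e.g.
      // {1,+,1}<L> * {1,+,1}<L>, i.e. (n+1)*(n+1), yields {1,+,3,+,2}<L>
      // (since 1 + 3*n + 2*C(n,2) = n*n + 2*n + 1), so products of long
      // addrecs grow quickly, hence the cap below.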
3302 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 > 3303 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec})) 3304 continue; 3305 3306 bool Overflow = false; 3307 Type *Ty = AddRec->getType(); 3308 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; 3309 SmallVector<const SCEV*, 7> AddRecOps; 3310 for (int x = 0, xe = AddRec->getNumOperands() + 3311 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { 3312 SmallVector <const SCEV *, 7> SumOps; 3313 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { 3314 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); 3315 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), 3316 ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); 3317 z < ze && !Overflow; ++z) { 3318 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); 3319 uint64_t Coeff; 3320 if (LargerThan64Bits) 3321 Coeff = umul_ov(Coeff1, Coeff2, Overflow); 3322 else 3323 Coeff = Coeff1*Coeff2; 3324 const SCEV *CoeffTerm = getConstant(Ty, Coeff); 3325 const SCEV *Term1 = AddRec->getOperand(y-z); 3326 const SCEV *Term2 = OtherAddRec->getOperand(z); 3327 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2, 3328 SCEV::FlagAnyWrap, Depth + 1)); 3329 } 3330 } 3331 if (SumOps.empty()) 3332 SumOps.push_back(getZero(Ty)); 3333 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1)); 3334 } 3335 if (!Overflow) { 3336 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), 3337 SCEV::FlagAnyWrap); 3338 if (Ops.size() == 2) return NewAddRec; 3339 Ops[Idx] = NewAddRec; 3340 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; 3341 OpsModified = true; 3342 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); 3343 if (!AddRec) 3344 break; 3345 } 3346 } 3347 if (OpsModified) 3348 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); 3349 3350 // Otherwise couldn't fold anything into this recurrence. Move onto the 3351 // next one. 3352 } 3353 3354 // Okay, it looks like we really DO need an mul expr. Check to see if we 3355 // already have one, otherwise create a new one. 3356 return getOrCreateMulExpr(Ops, ComputeFlags(Ops)); 3357 } 3358 3359 /// Represents an unsigned remainder expression based on unsigned division. 3360 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS, 3361 const SCEV *RHS) { 3362 assert(getEffectiveSCEVType(LHS->getType()) == 3363 getEffectiveSCEVType(RHS->getType()) && 3364 "SCEVURemExpr operand types don't match!"); 3365 3366 // Short-circuit easy cases 3367 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { 3368 // If constant is one, the result is trivial 3369 if (RHSC->getValue()->isOne()) 3370 return getZero(LHS->getType()); // X urem 1 --> 0 3371 3372 // If constant is a power of two, fold into a zext(trunc(LHS)). 3373 if (RHSC->getAPInt().isPowerOf2()) { 3374 Type *FullTy = LHS->getType(); 3375 Type *TruncTy = 3376 IntegerType::get(getContext(), RHSC->getAPInt().logBase2()); 3377 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy); 3378 } 3379 } 3380 3381 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y) 3382 const SCEV *UDiv = getUDivExpr(LHS, RHS); 3383 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW); 3384 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW); 3385 } 3386 3387 /// Get a canonical unsigned division expression, or something simpler if 3388 /// possible. 
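/// For example, dividing by a constant power of two can fold the division
/// away entirely: {0,+,4}<L> /u 4 may become {0,+,1}<L> when the recurrence
/// can be shown not to wrap.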
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  assert(!LHS->getType()->isPointerTy() &&
         "SCEVUDivExpr operand can't be pointer!");
  assert(LHS->getType() == RHS->getType() &&
         "SCEVUDivExpr operand types don't match!");

  FoldingSetNodeID ID;
  ID.AddInteger(scUDivExpr);
  ID.AddPointer(LHS);
  ID.AddPointer(RHS);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // 0 udiv Y == 0
  if (match(LHS, m_scev_Zero()))
    return LHS;

  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    if (RHSC->getValue()->isOne())
      return LHS;  // X udiv 1 --> X
    // If the denominator is zero, the result of the udiv is undefined. Don't
    // try to analyze it, because the resolution chosen here may differ from
    // the resolution chosen in other parts of the compiler.
    if (!RHSC->getValue()->isZero()) {
      // Determine if the division can be folded into the operands of the
      // LHS expression.
      // TODO: Generalize this to non-constants by using known-bits information.
      Type *Ty = LHS->getType();
      unsigned LZ = RHSC->getAPInt().countl_zero();
      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
      // For non-power-of-two values, effectively round the value up to the
      // nearest power of two.
      if (!RHSC->getAPInt().isPowerOf2())
        ++MaxShiftAmt;
      IntegerType *ExtTy =
          IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
        if (const SCEVConstant *Step =
                dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
          // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
          const APInt &StepInt = Step->getAPInt();
          const APInt &DivInt = RHSC->getAPInt();
          if (!StepInt.urem(DivInt) &&
              getZeroExtendExpr(AR, ExtTy) ==
                  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                                getZeroExtendExpr(Step, ExtTy),
                                AR->getLoop(), SCEV::FlagAnyWrap)) {
            SmallVector<const SCEV *, 4> Operands;
            for (const SCEV *Op : AR->operands())
              Operands.push_back(getUDivExpr(Op, RHS));
            return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
          }
          // Get a canonical UDivExpr for a recurrence:
          // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
          // We can currently only fold X%N if X is constant.
          const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
          if (StartC && !DivInt.urem(StepInt) &&
              getZeroExtendExpr(AR, ExtTy) ==
                  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                                getZeroExtendExpr(Step, ExtTy),
                                AR->getLoop(), SCEV::FlagAnyWrap)) {
            const APInt &StartInt = StartC->getAPInt();
            const APInt &StartRem = StartInt.urem(StepInt);
            if (StartRem != 0) {
              const SCEV *NewLHS =
                  getAddRecExpr(getConstant(StartInt - StartRem), Step,
                                AR->getLoop(), SCEV::FlagNW);
              if (LHS != NewLHS) {
                LHS = NewLHS;

                // Reset the ID to include the new LHS, and check if it is
                // already cached.
                ID.clear();
                ID.AddInteger(scUDivExpr);
                ID.AddPointer(LHS);
                ID.AddPointer(RHS);
                IP = nullptr;
                if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
                  return S;
              }
            }
          }
        }
      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
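      // For example, (8 * %x) /u 4 may fold to 2 * %x once zero-extension
      // shows the multiply cannot wrap and 8 /u 4 folds cleanly.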
3475 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) { 3476 SmallVector<const SCEV *, 4> Operands; 3477 for (const SCEV *Op : M->operands()) 3478 Operands.push_back(getZeroExtendExpr(Op, ExtTy)); 3479 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) 3480 // Find an operand that's safely divisible. 3481 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { 3482 const SCEV *Op = M->getOperand(i); 3483 const SCEV *Div = getUDivExpr(Op, RHSC); 3484 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) { 3485 Operands = SmallVector<const SCEV *, 4>(M->operands()); 3486 Operands[i] = Div; 3487 return getMulExpr(Operands); 3488 } 3489 } 3490 } 3491 3492 // (A/B)/C --> A/(B*C) if safe and B*C can be folded. 3493 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) { 3494 if (auto *DivisorConstant = 3495 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) { 3496 bool Overflow = false; 3497 APInt NewRHS = 3498 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow); 3499 if (Overflow) { 3500 return getConstant(RHSC->getType(), 0, false); 3501 } 3502 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS)); 3503 } 3504 } 3505 3506 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. 3507 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) { 3508 SmallVector<const SCEV *, 4> Operands; 3509 for (const SCEV *Op : A->operands()) 3510 Operands.push_back(getZeroExtendExpr(Op, ExtTy)); 3511 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { 3512 Operands.clear(); 3513 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { 3514 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); 3515 if (isa<SCEVUDivExpr>(Op) || 3516 getMulExpr(Op, RHS) != A->getOperand(i)) 3517 break; 3518 Operands.push_back(Op); 3519 } 3520 if (Operands.size() == A->getNumOperands()) 3521 return getAddExpr(Operands); 3522 } 3523 } 3524 3525 // Fold if both operands are constant. 3526 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) 3527 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt())); 3528 } 3529 } 3530 3531 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C. 3532 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS); 3533 AE && AE->getNumOperands() == 2) { 3534 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) { 3535 const APInt &NegC = VC->getAPInt(); 3536 if (NegC.isNegative() && !NegC.isMinSignedValue()) { 3537 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1)); 3538 if (MME && MME->getNumOperands() == 2 && 3539 isa<SCEVConstant>(MME->getOperand(0)) && 3540 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC && 3541 MME->getOperand(1) == RHS) 3542 return getZero(LHS->getType()); 3543 } 3544 } 3545 } 3546 3547 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs 3548 // changes). Make sure we get a new one. 
3549 IP = nullptr; 3550 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 3551 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), 3552 LHS, RHS); 3553 UniqueSCEVs.InsertNode(S, IP); 3554 registerUser(S, {LHS, RHS}); 3555 return S; 3556 } 3557 3558 APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { 3559 APInt A = C1->getAPInt().abs(); 3560 APInt B = C2->getAPInt().abs(); 3561 uint32_t ABW = A.getBitWidth(); 3562 uint32_t BBW = B.getBitWidth(); 3563 3564 if (ABW > BBW) 3565 B = B.zext(ABW); 3566 else if (ABW < BBW) 3567 A = A.zext(BBW); 3568 3569 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B)); 3570 } 3571 3572 /// Get a canonical unsigned division expression, or something simpler if 3573 /// possible. There is no representation for an exact udiv in SCEV IR, but we 3574 /// can attempt to remove factors from the LHS and RHS. We can't do this when 3575 /// it's not exact because the udiv may be clearing bits. 3576 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, 3577 const SCEV *RHS) { 3578 // TODO: we could try to find factors in all sorts of things, but for now we 3579 // just deal with u/exact (multiply, constant). See SCEVDivision towards the 3580 // end of this file for inspiration. 3581 3582 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); 3583 if (!Mul || !Mul->hasNoUnsignedWrap()) 3584 return getUDivExpr(LHS, RHS); 3585 3586 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { 3587 // If the mulexpr multiplies by a constant, then that constant must be the 3588 // first element of the mulexpr. 3589 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) { 3590 if (LHSCst == RHSCst) { 3591 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands())); 3592 return getMulExpr(Operands); 3593 } 3594 3595 // We can't just assume that LHSCst divides RHSCst cleanly, it could be 3596 // that there's a factor provided by one of the other terms. We need to 3597 // check. 3598 APInt Factor = gcd(LHSCst, RHSCst); 3599 if (!Factor.isIntN(1)) { 3600 LHSCst = 3601 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor))); 3602 RHSCst = 3603 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor))); 3604 SmallVector<const SCEV *, 2> Operands; 3605 Operands.push_back(LHSCst); 3606 append_range(Operands, Mul->operands().drop_front()); 3607 LHS = getMulExpr(Operands); 3608 RHS = RHSCst; 3609 Mul = dyn_cast<SCEVMulExpr>(LHS); 3610 if (!Mul) 3611 return getUDivExactExpr(LHS, RHS); 3612 } 3613 } 3614 } 3615 3616 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) { 3617 if (Mul->getOperand(i) == RHS) { 3618 SmallVector<const SCEV *, 2> Operands; 3619 append_range(Operands, Mul->operands().take_front(i)); 3620 append_range(Operands, Mul->operands().drop_front(i + 1)); 3621 return getMulExpr(Operands); 3622 } 3623 } 3624 3625 return getUDivExpr(LHS, RHS); 3626 } 3627 3628 /// Get an add recurrence expression for the specified loop. Simplify the 3629 /// expression as much as possible. 
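/// For example, a step that is itself a recurrence on the same loop is
/// flattened: {S,+,{A,+,B}<L>}<L> becomes the single chrec {S,+,A,+,B}<L>.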
const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
                                           const Loop *L,
                                           SCEV::NoWrapFlags Flags) {
  SmallVector<const SCEV *, 4> Operands;
  Operands.push_back(Start);
  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
    if (StepChrec->getLoop() == L) {
      append_range(Operands, StepChrec->operands());
      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
    }

  Operands.push_back(Step);
  return getAddRecExpr(Operands, L, Flags);
}

/// Get an add recurrence expression for the specified loop. Simplify the
/// expression as much as possible.
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
                               const Loop *L, SCEV::NoWrapFlags Flags) {
  if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
  for (const SCEV *Op : llvm::drop_begin(Operands)) {
    assert(getEffectiveSCEVType(Op->getType()) == ETy &&
           "SCEVAddRecExpr operand types don't match!");
    assert(!Op->getType()->isPointerTy() && "Step must be integer");
  }
  for (const SCEV *Op : Operands)
    assert(isAvailableAtLoopEntry(Op, L) &&
           "SCEVAddRecExpr operand is not available at loop entry!");
#endif

  if (Operands.back()->isZero()) {
    Operands.pop_back();
    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
  }

  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
  // use that information to infer NUW and NSW flags. However, computing a
  // BE count requires calling getAddRecExpr, so we may not yet have a
  // meaningful BE count at this point (and if we don't, we'd be stuck
  // with a SCEVCouldNotCompute as the cached BE count).

  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);

  // Canonicalize nested AddRecs by nesting them in order of loop depth.
  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
    const Loop *NestedLoop = NestedAR->getLoop();
    if (L->contains(NestedLoop)
            ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
            : (!NestedLoop->contains(L) &&
               DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
      Operands[0] = NestedAR->getStart();
      // AddRecs require their operands be loop-invariant with respect to their
      // loops. Don't perform this transformation if it would break this
      // requirement.
      bool AllInvariant = all_of(
          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });

      if (AllInvariant) {
        // Create a recurrence for the outer loop with the same step size.
        //
        // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
        // inner recurrence has the same property.
        SCEV::NoWrapFlags OuterFlags =
            maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());

        NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
          return isLoopInvariant(Op, NestedLoop);
        });

        if (AllInvariant) {
          // Ok, both add recurrences are valid after the transformation.
          //
          // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
          // the outer recurrence has the same property.
3709 SCEV::NoWrapFlags InnerFlags = 3710 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); 3711 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); 3712 } 3713 } 3714 // Reset Operands to its original state. 3715 Operands[0] = NestedAR; 3716 } 3717 } 3718 3719 // Okay, it looks like we really DO need an addrec expr. Check to see if we 3720 // already have one, otherwise create a new one. 3721 return getOrCreateAddRecExpr(Operands, L, Flags); 3722 } 3723 3724 const SCEV * 3725 ScalarEvolution::getGEPExpr(GEPOperator *GEP, 3726 const SmallVectorImpl<const SCEV *> &IndexExprs) { 3727 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand()); 3728 // getSCEV(Base)->getType() has the same address space as Base->getType() 3729 // because SCEV::getType() preserves the address space. 3730 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType()); 3731 GEPNoWrapFlags NW = GEP->getNoWrapFlags(); 3732 if (NW != GEPNoWrapFlags::none()) { 3733 // We'd like to propagate flags from the IR to the corresponding SCEV nodes, 3734 // but to do that, we have to ensure that said flag is valid in the entire 3735 // defined scope of the SCEV. 3736 // TODO: non-instructions have global scope. We might be able to prove 3737 // some global scope cases 3738 auto *GEPI = dyn_cast<Instruction>(GEP); 3739 if (!GEPI || !isSCEVExprNeverPoison(GEPI)) 3740 NW = GEPNoWrapFlags::none(); 3741 } 3742 3743 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap; 3744 if (NW.hasNoUnsignedSignedWrap()) 3745 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW); 3746 if (NW.hasNoUnsignedWrap()) 3747 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW); 3748 3749 Type *CurTy = GEP->getType(); 3750 bool FirstIter = true; 3751 SmallVector<const SCEV *, 4> Offsets; 3752 for (const SCEV *IndexExpr : IndexExprs) { 3753 // Compute the (potentially symbolic) offset in bytes for this index. 3754 if (StructType *STy = dyn_cast<StructType>(CurTy)) { 3755 // For a struct, add the member offset. 3756 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue(); 3757 unsigned FieldNo = Index->getZExtValue(); 3758 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo); 3759 Offsets.push_back(FieldOffset); 3760 3761 // Update CurTy to the type of the field at Index. 3762 CurTy = STy->getTypeAtIndex(Index); 3763 } else { 3764 // Update CurTy to its element type. 3765 if (FirstIter) { 3766 assert(isa<PointerType>(CurTy) && 3767 "The first index of a GEP indexes a pointer"); 3768 CurTy = GEP->getSourceElementType(); 3769 FirstIter = false; 3770 } else { 3771 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0); 3772 } 3773 // For an array, add the element offset, explicitly scaled. 3774 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy); 3775 // Getelementptr indices are signed. 3776 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy); 3777 3778 // Multiply the index by the element size to compute the element offset. 3779 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap); 3780 Offsets.push_back(LocalOffset); 3781 } 3782 } 3783 3784 // Handle degenerate case of GEP without offsets. 3785 if (Offsets.empty()) 3786 return BaseExpr; 3787 3788 // Add the offsets together, assuming nsw if inbounds. 3789 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap); 3790 // Add the base address and the offset. We cannot use the nsw flag, as the 3791 // base address is unsigned. However, if we know that the offset is 3792 // non-negative, we can use nuw. 
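  // (For example, a base near the top of the signed range plus a small
  // positive offset wraps in the signed sense while remaining valid unsigned
  // pointer arithmetic, which is why nsw would be unsound here.)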
3793 bool NUW = NW.hasNoUnsignedWrap() || 3794 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset)); 3795 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap; 3796 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap); 3797 assert(BaseExpr->getType() == GEPExpr->getType() && 3798 "GEP should not change type mid-flight."); 3799 return GEPExpr; 3800 } 3801 3802 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType, 3803 ArrayRef<const SCEV *> Ops) { 3804 FoldingSetNodeID ID; 3805 ID.AddInteger(SCEVType); 3806 for (const SCEV *Op : Ops) 3807 ID.AddPointer(Op); 3808 void *IP = nullptr; 3809 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP); 3810 } 3811 3812 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) { 3813 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap; 3814 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags)); 3815 } 3816 3817 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, 3818 SmallVectorImpl<const SCEV *> &Ops) { 3819 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!"); 3820 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); 3821 if (Ops.size() == 1) return Ops[0]; 3822 #ifndef NDEBUG 3823 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 3824 for (unsigned i = 1, e = Ops.size(); i != e; ++i) { 3825 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 3826 "Operand types don't match!"); 3827 assert(Ops[0]->getType()->isPointerTy() == 3828 Ops[i]->getType()->isPointerTy() && 3829 "min/max should be consistently pointerish"); 3830 } 3831 #endif 3832 3833 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr; 3834 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr; 3835 3836 const SCEV *Folded = constantFoldAndGroupOps( 3837 *this, LI, DT, Ops, 3838 [&](const APInt &C1, const APInt &C2) { 3839 switch (Kind) { 3840 case scSMaxExpr: 3841 return APIntOps::smax(C1, C2); 3842 case scSMinExpr: 3843 return APIntOps::smin(C1, C2); 3844 case scUMaxExpr: 3845 return APIntOps::umax(C1, C2); 3846 case scUMinExpr: 3847 return APIntOps::umin(C1, C2); 3848 default: 3849 llvm_unreachable("Unknown SCEV min/max opcode"); 3850 } 3851 }, 3852 [&](const APInt &C) { 3853 // identity 3854 if (IsMax) 3855 return IsSigned ? C.isMinSignedValue() : C.isMinValue(); 3856 else 3857 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue(); 3858 }, 3859 [&](const APInt &C) { 3860 // absorber 3861 if (IsMax) 3862 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue(); 3863 else 3864 return IsSigned ? C.isMinSignedValue() : C.isMinValue(); 3865 }); 3866 if (Folded) 3867 return Folded; 3868 3869 // Check if we have created the same expression before. 3870 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) { 3871 return S; 3872 } 3873 3874 // Find the first operation of the same kind 3875 unsigned Idx = 0; 3876 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind) 3877 ++Idx; 3878 3879 // Check to see if one of the operands is of the same kind. If so, expand its 3880 // operands onto our operand list, and recurse to simplify. 3881 if (Idx < Ops.size()) { 3882 bool DeletedAny = false; 3883 while (Ops[Idx]->getSCEVType() == Kind) { 3884 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]); 3885 Ops.erase(Ops.begin()+Idx); 3886 append_range(Ops, SMME->operands()); 3887 DeletedAny = true; 3888 } 3889 3890 if (DeletedAny) 3891 return getMinMaxExpr(Kind, Ops); 3892 } 3893 3894 // Okay, check to see if the same value occurs in the operand list twice. If 3895 // so, delete one. 
Since we sorted the list, these values are required to 3896 // be adjacent. 3897 llvm::CmpInst::Predicate GEPred = 3898 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; 3899 llvm::CmpInst::Predicate LEPred = 3900 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; 3901 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred; 3902 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred; 3903 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) { 3904 if (Ops[i] == Ops[i + 1] || 3905 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) { 3906 // X op Y op Y --> X op Y 3907 // X op Y --> X, if we know X, Y are ordered appropriately 3908 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2); 3909 --i; 3910 --e; 3911 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i], 3912 Ops[i + 1])) { 3913 // X op Y --> Y, if we know X, Y are ordered appropriately 3914 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1); 3915 --i; 3916 --e; 3917 } 3918 } 3919 3920 if (Ops.size() == 1) return Ops[0]; 3921 3922 assert(!Ops.empty() && "Reduced smax down to nothing!"); 3923 3924 // Okay, it looks like we really DO need an expr. Check to see if we 3925 // already have one, otherwise create a new one. 3926 FoldingSetNodeID ID; 3927 ID.AddInteger(Kind); 3928 for (const SCEV *Op : Ops) 3929 ID.AddPointer(Op); 3930 void *IP = nullptr; 3931 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); 3932 if (ExistingSCEV) 3933 return ExistingSCEV; 3934 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 3935 llvm::uninitialized_copy(Ops, O); 3936 SCEV *S = new (SCEVAllocator) 3937 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); 3938 3939 UniqueSCEVs.InsertNode(S, IP); 3940 registerUser(S, Ops); 3941 return S; 3942 } 3943 3944 namespace { 3945 3946 class SCEVSequentialMinMaxDeduplicatingVisitor final 3947 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, 3948 std::optional<const SCEV *>> { 3949 using RetVal = std::optional<const SCEV *>; 3950 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>; 3951 3952 ScalarEvolution &SE; 3953 const SCEVTypes RootKind; // Must be a sequential min/max expression. 3954 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind. 3955 SmallPtrSet<const SCEV *, 16> SeenOps; 3956 3957 bool canRecurseInto(SCEVTypes Kind) const { 3958 // We can only recurse into the SCEV expression of the same effective type 3959 // as the type of our root SCEV expression. 3960 return RootKind == Kind || NonSequentialRootKind == Kind; 3961 }; 3962 3963 RetVal visitAnyMinMaxExpr(const SCEV *S) { 3964 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) && 3965 "Only for min/max expressions."); 3966 SCEVTypes Kind = S->getSCEVType(); 3967 3968 if (!canRecurseInto(Kind)) 3969 return S; 3970 3971 auto *NAry = cast<SCEVNAryExpr>(S); 3972 SmallVector<const SCEV *> NewOps; 3973 bool Changed = visit(Kind, NAry->operands(), NewOps); 3974 3975 if (!Changed) 3976 return S; 3977 if (NewOps.empty()) 3978 return std::nullopt; 3979 3980 return isa<SCEVSequentialMinMaxExpr>(S) 3981 ? SE.getSequentialMinMaxExpr(Kind, NewOps) 3982 : SE.getMinMaxExpr(Kind, NewOps); 3983 } 3984 3985 RetVal visit(const SCEV *S) { 3986 // Has the whole operand been seen already? 
3987 if (!SeenOps.insert(S).second) 3988 return std::nullopt; 3989 return Base::visit(S); 3990 } 3991 3992 public: 3993 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE, 3994 SCEVTypes RootKind) 3995 : SE(SE), RootKind(RootKind), 3996 NonSequentialRootKind( 3997 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( 3998 RootKind)) {} 3999 4000 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps, 4001 SmallVectorImpl<const SCEV *> &NewOps) { 4002 bool Changed = false; 4003 SmallVector<const SCEV *> Ops; 4004 Ops.reserve(OrigOps.size()); 4005 4006 for (const SCEV *Op : OrigOps) { 4007 RetVal NewOp = visit(Op); 4008 if (NewOp != Op) 4009 Changed = true; 4010 if (NewOp) 4011 Ops.emplace_back(*NewOp); 4012 } 4013 4014 if (Changed) 4015 NewOps = std::move(Ops); 4016 return Changed; 4017 } 4018 4019 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; } 4020 4021 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; } 4022 4023 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; } 4024 4025 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; } 4026 4027 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; } 4028 4029 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; } 4030 4031 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; } 4032 4033 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; } 4034 4035 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; } 4036 4037 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; } 4038 4039 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) { 4040 return visitAnyMinMaxExpr(Expr); 4041 } 4042 4043 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) { 4044 return visitAnyMinMaxExpr(Expr); 4045 } 4046 4047 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) { 4048 return visitAnyMinMaxExpr(Expr); 4049 } 4050 4051 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) { 4052 return visitAnyMinMaxExpr(Expr); 4053 } 4054 4055 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 4056 return visitAnyMinMaxExpr(Expr); 4057 } 4058 4059 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; } 4060 4061 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; } 4062 }; 4063 4064 } // namespace 4065 4066 static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) { 4067 switch (Kind) { 4068 case scConstant: 4069 case scVScale: 4070 case scTruncate: 4071 case scZeroExtend: 4072 case scSignExtend: 4073 case scPtrToInt: 4074 case scAddExpr: 4075 case scMulExpr: 4076 case scUDivExpr: 4077 case scAddRecExpr: 4078 case scUMaxExpr: 4079 case scSMaxExpr: 4080 case scUMinExpr: 4081 case scSMinExpr: 4082 case scUnknown: 4083 // If any operand is poison, the whole expression is poison. 4084 return true; 4085 case scSequentialUMinExpr: 4086 // FIXME: if the *first* operand is poison, the whole expression is poison. 4087 return false; // Pessimistically, say that it does not propagate poison. 4088 case scCouldNotCompute: 4089 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 4090 } 4091 llvm_unreachable("Unknown SCEV kind!"); 4092 } 4093 4094 namespace { 4095 // The only way poison may be introduced in a SCEV expression is from a 4096 // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown, 4097 // not SCEVConstant). 
Notably, nowrap flags in SCEV nodes can *not*
// introduce poison -- they encode guaranteed, non-speculated knowledge.
//
// Additionally, all SCEV nodes propagate poison from inputs to outputs,
// with the notable exception of umin_seq, where only poison from the first
// operand is (unconditionally) propagated.
struct SCEVPoisonCollector {
  bool LookThroughMaybePoisonBlocking;
  SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
  SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
      : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}

  bool follow(const SCEV *S) {
    if (!LookThroughMaybePoisonBlocking &&
        !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
      return false;

    if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
      if (!isGuaranteedNotToBePoison(SU->getValue()))
        MaybePoison.insert(SU);
    }
    return true;
  }
  bool isDone() const { return false; }
};
} // namespace

/// Return true if V is poison given that AssumedPoison is already poison.
static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
  // First collect all SCEVs that might cause AssumedPoison to be poison.
  // We need to look through potentially poison-blocking operations here,
  // because we want to find all SCEVs that *might* result in poison, not only
  // those that are *required* to.
  SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
  visitAll(AssumedPoison, PC1);

  // AssumedPoison is never poison. As the assumption is false, the implication
  // is true. Don't bother walking the other SCEV in this case.
  if (PC1.MaybePoison.empty())
    return true;

  // Collect all SCEVs in S that, if poison, *will* result in S being poison
  // as well. We cannot look through potentially poison-blocking operations
  // here, as their arguments only *may* make the result poison.
  SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
  visitAll(S, PC2);

  // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
  // it will also make S poison by being part of PC2.MaybePoison.
  return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
}

void ScalarEvolution::getPoisonGeneratingValues(
    SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
  SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
  visitAll(S, PC);
  for (const SCEVUnknown *SU : PC.MaybePoison)
    Result.insert(SU->getValue());
}

bool ScalarEvolution::canReuseInstruction(
    const SCEV *S, Instruction *I,
    SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
  // If the instruction cannot be poison, it's always safe to reuse.
  if (programUndefinedIfPoison(I))
    return true;

  // Otherwise, it is possible that I is more poisonous than S. Collect the
  // poison-contributors of S, and then check whether I has any additional
  // poison-contributors. Poison that is contributed through poison-generating
  // flags is handled by dropping those flags instead.
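  // For example, reusing %a = add nuw i64 %x, %y for the SCEV (%x + %y) is
  // only sound once nuw is dropped: the flag could make %a poison on inputs
  // where the flag-free SCEV addition is well-defined.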
  SmallPtrSet<const Value *, 8> PoisonVals;
  getPoisonGeneratingValues(PoisonVals, S);

  SmallVector<Value *> Worklist;
  SmallPtrSet<Value *, 8> Visited;
  Worklist.push_back(I);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    // Avoid walking large instruction graphs.
    if (Visited.size() > 16)
      return false;

    // Either the value can't be poison, or S would also be poison if it is.
    if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
      continue;

    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return false;

    // Disjoint or instructions are interpreted as adds by SCEV. However, we
    // can't replace an arbitrary add with disjoint or, even if we drop the
    // flag. We would need to convert the or into an add.
    if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
      if (PDI->isDisjoint())
        return false;

    // FIXME: Ignore vscale, even though it technically could be poison. Do
    // this because SCEV currently assumes it can't be poison. Remove this
    // special case once we properly model when vscale can be poison.
    if (auto *II = dyn_cast<IntrinsicInst>(I);
        II && II->getIntrinsicID() == Intrinsic::vscale)
      continue;

    if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
      return false;

    // If the instruction can't create poison, we can recurse to its operands.
    if (I->hasPoisonGeneratingAnnotations())
      DropPoisonGeneratingInsts.push_back(I);

    llvm::append_range(Worklist, I->operands());
  }
  return true;
}

const SCEV *
ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
                                         SmallVectorImpl<const SCEV *> &Ops) {
  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
         "Not a SCEVSequentialMinMaxExpr!");
  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  if (Ops.size() == 1)
    return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "Operand types don't match!");
    assert(Ops[0]->getType()->isPointerTy() ==
               Ops[i]->getType()->isPointerTy() &&
           "min/max should be consistently pointerish");
  }
#endif

  // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
  // so we can *NOT* do any kind of sorting of the expressions!

  // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
    return S;

  // FIXME: there are *some* simplifications that we can do here.

  // Keep only the first instance of an operand.
  {
    SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
    bool Changed = Deduplicator.visit(Kind, Ops, Ops);
    if (Changed)
      return getSequentialMinMaxExpr(Kind, Ops);
  }

  // Check to see if one of the operands is of the same kind. If so, expand its
  // operands onto our operand list, and recurse to simplify.
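  // For example, (%a umin_seq (%b umin_seq %c)) flattens to
  // (%a umin_seq %b umin_seq %c); relative operand order is preserved since
  // the expression is not commutative.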
4256 { 4257 unsigned Idx = 0; 4258 bool DeletedAny = false; 4259 while (Idx < Ops.size()) { 4260 if (Ops[Idx]->getSCEVType() != Kind) { 4261 ++Idx; 4262 continue; 4263 } 4264 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]); 4265 Ops.erase(Ops.begin() + Idx); 4266 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(), 4267 SMME->operands().end()); 4268 DeletedAny = true; 4269 } 4270 4271 if (DeletedAny) 4272 return getSequentialMinMaxExpr(Kind, Ops); 4273 } 4274 4275 const SCEV *SaturationPoint; 4276 ICmpInst::Predicate Pred; 4277 switch (Kind) { 4278 case scSequentialUMinExpr: 4279 SaturationPoint = getZero(Ops[0]->getType()); 4280 Pred = ICmpInst::ICMP_ULE; 4281 break; 4282 default: 4283 llvm_unreachable("Not a sequential min/max type."); 4284 } 4285 4286 for (unsigned i = 1, e = Ops.size(); i != e; ++i) { 4287 if (!isGuaranteedNotToCauseUB(Ops[i])) 4288 continue; 4289 // We can replace %x umin_seq %y with %x umin %y if either: 4290 // * %y being poison implies %x is also poison. 4291 // * %x cannot be the saturating value (e.g. zero for umin). 4292 if (::impliesPoison(Ops[i], Ops[i - 1]) || 4293 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1], 4294 SaturationPoint)) { 4295 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]}; 4296 Ops[i - 1] = getMinMaxExpr( 4297 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind), 4298 SeqOps); 4299 Ops.erase(Ops.begin() + i); 4300 return getSequentialMinMaxExpr(Kind, Ops); 4301 } 4302 // Fold %x umin_seq %y to %x if %x ule %y. 4303 // TODO: We might be able to prove the predicate for a later operand. 4304 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) { 4305 Ops.erase(Ops.begin() + i); 4306 return getSequentialMinMaxExpr(Kind, Ops); 4307 } 4308 } 4309 4310 // Okay, it looks like we really DO need an expr. Check to see if we 4311 // already have one, otherwise create a new one. 
4312 FoldingSetNodeID ID; 4313 ID.AddInteger(Kind); 4314 for (const SCEV *Op : Ops) 4315 ID.AddPointer(Op); 4316 void *IP = nullptr; 4317 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); 4318 if (ExistingSCEV) 4319 return ExistingSCEV; 4320 4321 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 4322 llvm::uninitialized_copy(Ops, O); 4323 SCEV *S = new (SCEVAllocator) 4324 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); 4325 4326 UniqueSCEVs.InsertNode(S, IP); 4327 registerUser(S, Ops); 4328 return S; 4329 } 4330 4331 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { 4332 SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; 4333 return getSMaxExpr(Ops); 4334 } 4335 4336 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { 4337 return getMinMaxExpr(scSMaxExpr, Ops); 4338 } 4339 4340 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) { 4341 SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; 4342 return getUMaxExpr(Ops); 4343 } 4344 4345 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { 4346 return getMinMaxExpr(scUMaxExpr, Ops); 4347 } 4348 4349 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, 4350 const SCEV *RHS) { 4351 SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; 4352 return getSMinExpr(Ops); 4353 } 4354 4355 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) { 4356 return getMinMaxExpr(scSMinExpr, Ops); 4357 } 4358 4359 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS, 4360 bool Sequential) { 4361 SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; 4362 return getUMinExpr(Ops, Sequential); 4363 } 4364 4365 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops, 4366 bool Sequential) { 4367 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops) 4368 : getMinMaxExpr(scUMinExpr, Ops); 4369 } 4370 4371 const SCEV * 4372 ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { 4373 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); 4374 if (Size.isScalable()) 4375 Res = getMulExpr(Res, getVScale(IntTy)); 4376 return Res; 4377 } 4378 4379 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { 4380 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); 4381 } 4382 4383 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { 4384 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); 4385 } 4386 4387 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, 4388 StructType *STy, 4389 unsigned FieldNo) { 4390 // We can bypass creating a target-independent constant expression and then 4391 // folding it back into a ConstantInt. This is just a compile-time 4392 // optimization. 4393 const StructLayout *SL = getDataLayout().getStructLayout(STy); 4394 assert(!SL->getSizeInBits().isScalable() && 4395 "Cannot get offset for structure containing scalable vector types"); 4396 return getConstant(IntTy, SL->getElementOffset(FieldNo)); 4397 } 4398 4399 const SCEV *ScalarEvolution::getUnknown(Value *V) { 4400 // Don't attempt to do anything other than create a SCEVUnknown object 4401 // here. createSCEV only calls getUnknown after checking for all other 4402 // interesting possibilities, and any other code that calls getUnknown 4403 // is doing so in order to hide a value from SCEV canonicalization. 

  FoldingSetNodeID ID;
  ID.AddInteger(scUnknown);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
    assert(cast<SCEVUnknown>(S)->getValue() == V &&
           "Stale SCEVUnknown in uniquing map!");
    return S;
  }
  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
                                            FirstUnknown);
  FirstUnknown = cast<SCEVUnknown>(S);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

//===----------------------------------------------------------------------===//
//            Basic SCEV Analysis and PHI Idiom Recognition Code
//

/// Test if values of the given type are analyzable within the SCEV
/// framework. This includes integer types and pointer types; pointers are
/// analyzed through their index type (see getEffectiveSCEVType).
bool ScalarEvolution::isSCEVable(Type *Ty) const {
  // Integers and pointers are always SCEVable.
  return Ty->isIntOrPtrTy();
}

/// Return the size in bits of the specified type, for which isSCEVable must
/// return true.
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
  assert(isSCEVable(Ty) && "Type is not SCEVable!");
  if (Ty->isPointerTy())
    return getDataLayout().getIndexTypeSizeInBits(Ty);
  return getDataLayout().getTypeSizeInBits(Ty);
}

/// Return a type with the same bitwidth as the given type and which represents
/// how SCEV will treat the given type, for which isSCEVable must return
/// true. For pointer types, this is the pointer index sized integer type.
Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
  assert(isSCEVable(Ty) && "Type is not SCEVable!");

  if (Ty->isIntegerTy())
    return Ty;

  // The only other supported type is pointer.
  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
  return getDataLayout().getIndexType(Ty);
}

Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
}

bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
                                                        const SCEV *B) {
  // For a valid use point to exist, the defining scope of one operand
  // must dominate the other.
  bool PreciseA, PreciseB;
  auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
  auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
  if (!PreciseA || !PreciseB)
    // Can't tell.
    return false;
  return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
         DT.dominates(ScopeB, ScopeA);
}

const SCEV *ScalarEvolution::getCouldNotCompute() {
  return CouldNotCompute.get();
}

bool ScalarEvolution::checkValidity(const SCEV *S) const {
  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
    auto *SU = dyn_cast<SCEVUnknown>(S);
    return SU && SU->getValue() == nullptr;
  });

  return !ContainsNulls;
}

bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
  HasRecMapType::iterator I = HasRecMap.find(S);
  if (I != HasRecMap.end())
    return I->second;

  bool FoundAddRec =
      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
  HasRecMap.insert({S, FoundAddRec});
  return FoundAddRec;
}

/// Return the ValueOffsetPair set for \p S.
\p S can be represented 4500 /// by the value and offset from any ValueOffsetPair in the set. 4501 ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) { 4502 ExprValueMapType::iterator SI = ExprValueMap.find_as(S); 4503 if (SI == ExprValueMap.end()) 4504 return {}; 4505 return SI->second.getArrayRef(); 4506 } 4507 4508 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) 4509 /// cannot be used separately. eraseValueFromMap should be used to remove 4510 /// V from ValueExprMap and ExprValueMap at the same time. 4511 void ScalarEvolution::eraseValueFromMap(Value *V) { 4512 ValueExprMapType::iterator I = ValueExprMap.find_as(V); 4513 if (I != ValueExprMap.end()) { 4514 auto EVIt = ExprValueMap.find(I->second); 4515 bool Removed = EVIt->second.remove(V); 4516 (void) Removed; 4517 assert(Removed && "Value not in ExprValueMap?"); 4518 ValueExprMap.erase(I); 4519 } 4520 } 4521 4522 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) { 4523 // A recursive query may have already computed the SCEV. It should be 4524 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily 4525 // inferred nowrap flags. 4526 auto It = ValueExprMap.find_as(V); 4527 if (It == ValueExprMap.end()) { 4528 ValueExprMap.insert({SCEVCallbackVH(V, this), S}); 4529 ExprValueMap[S].insert(V); 4530 } 4531 } 4532 4533 /// Return an existing SCEV if it exists, otherwise analyze the expression and 4534 /// create a new one. 4535 const SCEV *ScalarEvolution::getSCEV(Value *V) { 4536 assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); 4537 4538 if (const SCEV *S = getExistingSCEV(V)) 4539 return S; 4540 return createSCEVIter(V); 4541 } 4542 4543 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { 4544 assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); 4545 4546 ValueExprMapType::iterator I = ValueExprMap.find_as(V); 4547 if (I != ValueExprMap.end()) { 4548 const SCEV *S = I->second; 4549 assert(checkValidity(S) && 4550 "existing SCEV has not been properly invalidated"); 4551 return S; 4552 } 4553 return nullptr; 4554 } 4555 4556 /// Return a SCEV corresponding to -V = -1*V 4557 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, 4558 SCEV::NoWrapFlags Flags) { 4559 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) 4560 return getConstant( 4561 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue()))); 4562 4563 Type *Ty = V->getType(); 4564 Ty = getEffectiveSCEVType(Ty); 4565 return getMulExpr(V, getMinusOne(Ty), Flags); 4566 } 4567 4568 /// If Expr computes ~A, return A else return nullptr 4569 static const SCEV *MatchNotExpr(const SCEV *Expr) { 4570 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); 4571 if (!Add || Add->getNumOperands() != 2 || 4572 !Add->getOperand(0)->isAllOnesValue()) 4573 return nullptr; 4574 4575 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); 4576 if (!AddRHS || AddRHS->getNumOperands() != 2 || 4577 !AddRHS->getOperand(0)->isAllOnesValue()) 4578 return nullptr; 4579 4580 return AddRHS->getOperand(1); 4581 } 4582 4583 /// Return a SCEV corresponding to ~V = -1-V 4584 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { 4585 assert(!V->getType()->isPointerTy() && "Can't negate pointer"); 4586 4587 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) 4588 return getConstant( 4589 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue()))); 4590 4591 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y) 4592 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) { 4593 
    auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
      SmallVector<const SCEV *, 2> MatchedOperands;
      for (const SCEV *Operand : MME->operands()) {
        const SCEV *Matched = MatchNotExpr(Operand);
        if (!Matched)
          return (const SCEV *)nullptr;
        MatchedOperands.push_back(Matched);
      }
      return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
                           MatchedOperands);
    };
    if (const SCEV *Replaced = MatchMinMaxNegation(MME))
      return Replaced;
  }

  Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  return getMinusSCEV(getMinusOne(Ty), V);
}

const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
  assert(P->getType()->isPointerTy());

  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
    // The base of an AddRec is the first operand.
    SmallVector<const SCEV *> Ops{AddRec->operands()};
    Ops[0] = removePointerBase(Ops[0]);
    // Don't try to transfer nowrap flags for now. We could in some cases
    // (for example, if the pointer operand of the AddRec is a SCEVUnknown).
    return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }
  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
    // The base of an Add is the pointer operand.
    SmallVector<const SCEV *> Ops{Add->operands()};
    const SCEV **PtrOp = nullptr;
    for (const SCEV *&AddOp : Ops) {
      if (AddOp->getType()->isPointerTy()) {
        assert(!PtrOp && "Cannot have multiple pointer ops");
        PtrOp = &AddOp;
      }
    }
    assert(PtrOp && "Must have pointer op");
    *PtrOp = removePointerBase(*PtrOp);
    // Don't try to transfer nowrap flags for now. We could in some cases
    // (for example, if the pointer operand of the Add is a SCEVUnknown).
    return getAddExpr(Ops);
  }
  // Any other expression must be a pointer base.
  return getZero(P->getType());
}

const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                          SCEV::NoWrapFlags Flags,
                                          unsigned Depth) {
  // Fast path: X - X --> 0.
  if (LHS == RHS)
    return getZero(LHS->getType());

  // If we subtract two pointers with different pointer bases, bail.
  // Eventually, we're going to add an assertion to getMulExpr that we
  // can't multiply by a pointer.
  if (RHS->getType()->isPointerTy()) {
    if (!LHS->getType()->isPointerTy() ||
        getPointerBase(LHS) != getPointerBase(RHS))
      return getCouldNotCompute();
    LHS = removePointerBase(LHS);
    RHS = removePointerBase(RHS);
  }

  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
  // makes it so that we cannot make much use of NUW.
  auto AddFlags = SCEV::FlagAnyWrap;
  const bool RHSIsNotMinSigned =
      !getSignedRangeMin(RHS).isMinSignedValue();
  if (hasFlags(Flags, SCEV::FlagNSW)) {
    // Let M be the minimum representable signed value. Then (-1)*RHS
    // signed-wraps if and only if RHS is M. That can happen even for
    // an NSW subtraction because e.g. (-1)*M signed-wraps even though
    // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
    // (-1)*RHS, we need to prove that RHS != M.
    //
    // If LHS is non-negative and we know that LHS - RHS does not
    // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
    // either by proving that RHS > M or that LHS >= 0.
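    //
    // A concrete i8 illustration of the hole being closed here: with
    // M = -128, the subtraction -1 - (-128) = 127 is representable, yet the
    // negation (-1) * (-128) = 128 wraps to -128. Proving RHS != -128, or
    // LHS >= 0 (which together with nsw on LHS - RHS rules out RHS == -128),
    // is exactly what lets us keep NSW below.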
    if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
      AddFlags = SCEV::FlagNSW;
    }
  }

  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
  // RHS is NSW and LHS >= 0.
  //
  // The difficulty here is that the NSW flag may have been proven
  // relative to a loop that is to be found in a recurrence in LHS and
  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
  // larger scope than intended.
  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;

  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
}

const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
                                                     unsigned Depth) {
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or zero extend with non-integer arguments!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
    return getTruncateExpr(V, Ty, Depth);
  return getZeroExtendExpr(V, Ty, Depth);
}

const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
                                                     unsigned Depth) {
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or sign extend with non-integer arguments!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
    return getTruncateExpr(V, Ty, Depth);
  return getSignExtendExpr(V, Ty, Depth);
}

const SCEV *
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or zero extend with non-integer arguments!");
  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
         "getNoopOrZeroExtend cannot truncate!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getZeroExtendExpr(V, Ty);
}

const SCEV *
ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or sign extend with non-integer arguments!");
  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
         "getNoopOrSignExtend cannot truncate!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getSignExtendExpr(V, Ty);
}

const SCEV *
ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or any extend with non-integer arguments!");
  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
         "getNoopOrAnyExtend cannot truncate!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getAnyExtendExpr(V, Ty);
}

const SCEV *
ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or noop with non-integer arguments!");
  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
         "getTruncateOrNoop cannot extend!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getTruncateExpr(V, Ty);
}

const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
                                                        const SCEV *RHS) {
  const SCEV *PromotedLHS = LHS;
  const SCEV *PromotedRHS = RHS;

  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
  else
    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());

  return getUMaxExpr(PromotedLHS, PromotedRHS);
}

const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
                                                        const SCEV *RHS,
                                                        bool Sequential) {
  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
  return getUMinFromMismatchedTypes(Ops, Sequential);
}

const SCEV *
ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
                                            bool Sequential) {
  assert(!Ops.empty() && "At least one operand must be provided!");
  // Trivial case.
  if (Ops.size() == 1)
    return Ops[0];

  // Find the max type first.
  Type *MaxType = nullptr;
  for (const auto *S : Ops)
    if (MaxType)
      MaxType = getWiderType(MaxType, S->getType());
    else
      MaxType = S->getType();
  assert(MaxType && "Failed to find maximum type!");

  // Extend all ops to max type.
  SmallVector<const SCEV *, 2> PromotedOps;
  for (const auto *S : Ops)
    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));

  // Generate umin.
  return getUMinExpr(PromotedOps, Sequential);
}

const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
  // A pointer operand may evaluate to a nonpointer expression, such as null.
  if (!V->getType()->isPointerTy())
    return V;

  while (true) {
    if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
      V = AddRec->getStart();
    } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
      const SCEV *PtrOp = nullptr;
      for (const SCEV *AddOp : Add->operands()) {
        if (AddOp->getType()->isPointerTy()) {
          assert(!PtrOp && "Cannot have multiple pointer ops");
          PtrOp = AddOp;
        }
      }
      assert(PtrOp && "Must have pointer op");
      V = PtrOp;
    } else // Not something we can look further into.
      return V;
  }
}

/// Push users of the given Instruction onto the given Worklist.
static void PushDefUseChildren(Instruction *I,
                               SmallVectorImpl<Instruction *> &Worklist,
                               SmallPtrSetImpl<Instruction *> &Visited) {
  // Push the def-use children onto the Worklist stack.
  for (User *U : I->users()) {
    auto *UserInsn = cast<Instruction>(U);
    if (Visited.insert(UserInsn).second)
      Worklist.push_back(UserInsn);
  }
}

namespace {

/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
/// use its start expression. For AddRecs of other loops, either use the AddRec
/// itself (when IgnoreOtherLoops is true) or give up on the rewrite.
/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
/// done.
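/// For instance (an illustrative sketch): visiting {%a,+,%b}<L> yields its
/// start %a, while an AddRec of a different loop is either kept as-is (when
/// IgnoreOtherLoops is true) or makes the rewrite return SCEVCouldNotCompute.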
class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
                             bool IgnoreOtherLoops = true) {
    SCEVInitRewriter Rewriter(L, SE);
    const SCEV *Result = Rewriter.visit(S);
    if (Rewriter.hasSeenLoopVariantSCEVUnknown())
      return SE.getCouldNotCompute();
    return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
               ? SE.getCouldNotCompute()
               : Result;
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    if (!SE.isLoopInvariant(Expr, L))
      SeenLoopVariantSCEVUnknown = true;
    return Expr;
  }

  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    // Only re-write AddRecExprs for this loop.
    if (Expr->getLoop() == L)
      return Expr->getStart();
    SeenOtherLoops = true;
    return Expr;
  }

  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }

  bool hasSeenOtherLoops() { return SeenOtherLoops; }

private:
  explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}

  const Loop *L;
  bool SeenLoopVariantSCEVUnknown = false;
  bool SeenOtherLoops = false;
};

/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
/// use its post-increment expression; otherwise use the AddRec itself.
/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
/// done.
class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
    SCEVPostIncRewriter Rewriter(L, SE);
    const SCEV *Result = Rewriter.visit(S);
    return Rewriter.hasSeenLoopVariantSCEVUnknown()
               ? SE.getCouldNotCompute()
               : Result;
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    if (!SE.isLoopInvariant(Expr, L))
      SeenLoopVariantSCEVUnknown = true;
    return Expr;
  }

  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    // Only re-write AddRecExprs for this loop.
    if (Expr->getLoop() == L)
      return Expr->getPostIncExpr(SE);
    SeenOtherLoops = true;
    return Expr;
  }

  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }

  bool hasSeenOtherLoops() { return SeenOtherLoops; }

private:
  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}

  const Loop *L;
  bool SeenLoopVariantSCEVUnknown = false;
  bool SeenOtherLoops = false;
};

/// This class evaluates the compare condition by matching it against the
/// condition of the loop latch. If there is a match, we assume a true value
/// for the condition while building SCEV nodes.
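/// For example (illustrative): if the latch terminator is
///   br i1 %c, label %header, label %exit
/// then a loop-variant select such as
///   %s = select i1 %c, %x, %y
/// may be folded to the SCEV of %x, since %c must have been true whenever the
/// backedge was taken.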
class SCEVBackedgeConditionFolder
    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                             ScalarEvolution &SE) {
    bool IsPosBECond = false;
    Value *BECond = nullptr;
    if (BasicBlock *Latch = L->getLoopLatch()) {
      BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
      if (BI && BI->isConditional()) {
        assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
               "Both outgoing branches should not target the same header!");
        BECond = BI->getCondition();
        IsPosBECond = BI->getSuccessor(0) == L->getHeader();
      } else {
        return S;
      }
    }
    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
    return Rewriter.visit(S);
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    const SCEV *Result = Expr;
    bool InvariantF = SE.isLoopInvariant(Expr, L);

    if (!InvariantF) {
      Instruction *I = cast<Instruction>(Expr->getValue());
      switch (I->getOpcode()) {
      case Instruction::Select: {
        SelectInst *SI = cast<SelectInst>(I);
        std::optional<const SCEV *> Res =
            compareWithBackedgeCondition(SI->getCondition());
        if (Res) {
          bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
        }
        break;
      }
      default: {
        std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
        if (Res)
          Result = *Res;
        break;
      }
      }
    }
    return Result;
  }

private:
  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
                                       bool IsPosBECond, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
        IsPositiveBECond(IsPosBECond) {}

  std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);

  const Loop *L;
  /// Loop back condition.
  Value *BackedgeCond = nullptr;
  /// Set to true if loop back is on positive branch condition.
  bool IsPositiveBECond;
};

std::optional<const SCEV *>
SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
  // If the value matches the backedge condition for the loop latch,
  // then return a constant evolution node based on the loop-back
  // branch taken.
  if (BackedgeCond == IC)
    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
                            : SE.getZero(Type::getInt1Ty(SE.getContext()));
  return std::nullopt;
}

class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                             ScalarEvolution &SE) {
    SCEVShiftRewriter Rewriter(L, SE);
    const SCEV *Result = Rewriter.visit(S);
    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    // Only loop-invariant SCEVUnknowns are valid here; anything else blocks
    // the rewrite.
    if (!SE.isLoopInvariant(Expr, L))
      Valid = false;
    return Expr;
  }

  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    if (Expr->getLoop() == L && Expr->isAffine())
      return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
    Valid = false;
    return Expr;
  }

  bool isValid() { return Valid; }

private:
  explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}

  const Loop *L;
  bool Valid = true;
};

} // end anonymous namespace

SCEV::NoWrapFlags
ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
  if (!AR->isAffine())
    return SCEV::FlagAnyWrap;

  using OBO = OverflowingBinaryOperator;

  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;

  if (!AR->hasNoSelfWrap()) {
    const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
    if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
      ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
      const APInt &BECountAP = BECountMax->getAPInt();
      unsigned NoOverflowBitWidth =
          BECountAP.getActiveBits() + StepCR.getMinSignedBits();
      if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
        Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
    }
  }

  if (!AR->hasNoSignedWrap()) {
    ConstantRange AddRecRange = getSignedRange(AR);
    ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));

    auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
        Instruction::Add, IncRange, OBO::NoSignedWrap);
    if (NSWRegion.contains(AddRecRange))
      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
  }

  if (!AR->hasNoUnsignedWrap()) {
    ConstantRange AddRecRange = getUnsignedRange(AR);
    ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));

    auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
        Instruction::Add, IncRange, OBO::NoUnsignedWrap);
    if (NUWRegion.contains(AddRecRange))
      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
  }

  return Result;
}

SCEV::NoWrapFlags
ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();

  if (AR->hasNoSignedWrap())
    return Result;

  if (!AR->isAffine())
    return Result;

  // This function can be expensive; only try to prove NSW once per AddRec.
  if (!SignedWrapViaInductionTried.insert(AR).second)
    return Result;

  const SCEV *Step = AR->getStepRecurrence(*this);
  const Loop *L = AR->getLoop();

  // Check whether the backedge-taken count is SCEVCouldNotCompute.
  // Note that this serves two purposes: It filters out loops that are
  // simply not analyzable, and it covers the case where this code is
  // being called from within backedge-taken count analysis, such that
  // attempting to ask for the backedge-taken count would likely result
  // in infinite recursion. In the latter case, the analysis code will
  // cope with a conservative value, and it will take care to purge
  // that value once it has finished.
  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);

  // Normally, in the cases we can prove no-overflow via a
  // backedge guarding condition, we can also compute a backedge
  // taken count for the loop. The exceptions are assumptions and
  // guards present in the loop -- SCEV is not great at exploiting
  // these to compute max backedge taken counts, but can still use
  // these to prove lack of overflow. Use this fact to avoid
  // doing extra work that may not pay off.

  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
      AC.assumptions().empty())
    return Result;

  // If the backedge is guarded by a comparison with the pre-inc value the
  // addrec is safe. Also, if the entry is guarded by a comparison with the
  // start value and the backedge is guarded by a comparison with the post-inc
  // value, the addrec is safe.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      getSignedOverflowLimitForStep(Step, &Pred, this);
  if (OverflowLimit &&
      (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
       isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
    Result = setFlags(Result, SCEV::FlagNSW);
  }
  return Result;
}

SCEV::NoWrapFlags
ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();

  if (AR->hasNoUnsignedWrap())
    return Result;

  if (!AR->isAffine())
    return Result;

  // This function can be expensive; only try to prove NUW once per AddRec.
  if (!UnsignedWrapViaInductionTried.insert(AR).second)
    return Result;

  const SCEV *Step = AR->getStepRecurrence(*this);
  unsigned BitWidth = getTypeSizeInBits(AR->getType());
  const Loop *L = AR->getLoop();

  // Check whether the backedge-taken count is SCEVCouldNotCompute.
  // Note that this serves two purposes: It filters out loops that are
  // simply not analyzable, and it covers the case where this code is
  // being called from within backedge-taken count analysis, such that
  // attempting to ask for the backedge-taken count would likely result
  // in infinite recursion. In the latter case, the analysis code will
  // cope with a conservative value, and it will take care to purge
  // that value once it has finished.
  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);

  // Normally, in the cases we can prove no-overflow via a
  // backedge guarding condition, we can also compute a backedge
  // taken count for the loop. The exceptions are assumptions and
  // guards present in the loop -- SCEV is not great at exploiting
  // these to compute max backedge taken counts, but can still use
  // these to prove lack of overflow. Use this fact to avoid
  // doing extra work that may not pay off.

  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
      AC.assumptions().empty())
    return Result;

  // If the backedge is guarded by a comparison with the pre-inc value the
  // addrec is safe. Also, if the entry is guarded by a comparison with the
  // start value and the backedge is guarded by a comparison with the post-inc
  // value, the addrec is safe.
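  // Sketch of the bound used below for a positive step: N is computed as
  // 0 - umax(Step), i.e. 2^BitWidth - umax(Step). If AR u< N holds on every
  // iteration, then AR + Step <= 2^BitWidth - 1, so the increment cannot
  // unsigned-wrap. E.g. for i8 with Step in [1, 2]: N = 254, and AR u< 254
  // gives AR + Step <= 255.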
  if (isKnownPositive(Step)) {
    const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                getUnsignedRangeMax(Step));
    if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
        isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
      Result = setFlags(Result, SCEV::FlagNUW);
    }
  }

  return Result;
}

namespace {

/// Represents an abstract binary operation. This may exist as a
/// normal instruction or constant expression, or may have been
/// derived from an expression tree.
struct BinaryOp {
  unsigned Opcode;
  Value *LHS;
  Value *RHS;
  bool IsNSW = false;
  bool IsNUW = false;

  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
  /// constant expression.
  Operator *Op = nullptr;

  explicit BinaryOp(Operator *Op)
      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
        Op(Op) {
    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
      IsNSW = OBO->hasNoSignedWrap();
      IsNUW = OBO->hasNoUnsignedWrap();
    }
  }

  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
                    bool IsNUW = false)
      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
};

} // end anonymous namespace

/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
                                             AssumptionCache &AC,
                                             const DominatorTree &DT,
                                             const Instruction *CxtI) {
  auto *Op = dyn_cast<Operator>(V);
  if (!Op)
    return std::nullopt;

  // Implementation detail: all the cleverness here should happen without
  // creating new SCEV expressions -- our caller knows tricks to avoid creating
  // SCEV expressions when possible, and we should not break that.

  switch (Op->getOpcode()) {
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::And:
  case Instruction::AShr:
  case Instruction::Shl:
    return BinaryOp(Op);

  case Instruction::Or: {
    // Convert or disjoint into add nuw nsw.
    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
                      /*IsNSW=*/true, /*IsNUW=*/true);
    return BinaryOp(Op);
  }

  case Instruction::Xor:
    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
      // If the RHS of the xor is a signmask, then this is just an add.
      // Instcombine turns add of signmask into xor as a strength reduction
      // step.
      if (RHSC->getValue().isSignMask())
        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    // Binary `xor` is a bit-wise `add`.
    if (V->getType()->isIntegerTy(1))
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    return BinaryOp(Op);

  case Instruction::LShr:
    // Turn logical shift right of a constant into an unsigned divide.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();

      // If the shift count is not less than the bitwidth, the result of
      // the shift is undefined.
Don't try to analyze it, because the 5284 // resolution chosen here may differ from the resolution chosen in 5285 // other parts of the compiler. 5286 if (SA->getValue().ult(BitWidth)) { 5287 Constant *X = 5288 ConstantInt::get(SA->getContext(), 5289 APInt::getOneBitSet(BitWidth, SA->getZExtValue())); 5290 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X); 5291 } 5292 } 5293 return BinaryOp(Op); 5294 5295 case Instruction::ExtractValue: { 5296 auto *EVI = cast<ExtractValueInst>(Op); 5297 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0) 5298 break; 5299 5300 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()); 5301 if (!WO) 5302 break; 5303 5304 Instruction::BinaryOps BinOp = WO->getBinaryOp(); 5305 bool Signed = WO->isSigned(); 5306 // TODO: Should add nuw/nsw flags for mul as well. 5307 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT)) 5308 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS()); 5309 5310 // Now that we know that all uses of the arithmetic-result component of 5311 // CI are guarded by the overflow check, we can go ahead and pretend 5312 // that the arithmetic is non-overflowing. 5313 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(), 5314 /* IsNSW = */ Signed, /* IsNUW = */ !Signed); 5315 } 5316 5317 default: 5318 break; 5319 } 5320 5321 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same 5322 // semantics as a Sub, return a binary sub expression. 5323 if (auto *II = dyn_cast<IntrinsicInst>(V)) 5324 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg) 5325 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1)); 5326 5327 return std::nullopt; 5328 } 5329 5330 /// Helper function to createAddRecFromPHIWithCasts. We have a phi 5331 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via 5332 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the 5333 /// way. This function checks if \p Op, an operand of this SCEVAddExpr, 5334 /// follows one of the following patterns: 5335 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) 5336 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) 5337 /// If the SCEV expression of \p Op conforms with one of the expected patterns 5338 /// we return the type of the truncation operation, and indicate whether the 5339 /// truncated type should be treated as signed/unsigned by setting 5340 /// \p Signed to true/false, respectively. 5341 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, 5342 bool &Signed, ScalarEvolution &SE) { 5343 // The case where Op == SymbolicPHI (that is, with no type conversions on 5344 // the way) is handled by the regular add recurrence creating logic and 5345 // would have already been triggered in createAddRecForPHI. Reaching it here 5346 // means that createAddRecFromPHI had failed for this PHI before (e.g., 5347 // because one of the other operands of the SCEVAddExpr updating this PHI is 5348 // not invariant). 5349 // 5350 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in 5351 // this case predicates that allow us to prove that Op == SymbolicPHI will 5352 // be added. 
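  // For instance (illustrative): with an i64 %SymbolicPHI, an operand
  //   Op = (sext i32 (trunc i64 %SymbolicPHI to i32) to i64)
  // matches the first pattern; we would return the i32 truncation type and
  // set Signed to true.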
  if (Op == SymbolicPHI)
    return nullptr;

  unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
  unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
  if (SourceBits != NewBits)
    return nullptr;

  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
  if (!SExt && !ZExt)
    return nullptr;
  const SCEVTruncateExpr *Trunc =
      SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
           : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
  if (!Trunc)
    return nullptr;
  const SCEV *X = Trunc->getOperand();
  if (X != SymbolicPHI)
    return nullptr;
  Signed = SExt != nullptr;
  return Trunc->getType();
}

static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
  if (!PN->getType()->isIntegerTy())
    return nullptr;
  const Loop *L = LI.getLoopFor(PN->getParent());
  if (!L || L->getHeader() != PN->getParent())
    return nullptr;
  return L;
}

// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
// computation that updates the phi matches the following pattern:
//   (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
// which correspond to a phi->trunc->sext/zext->add->phi update chain.
// If so, try to rewrite it as an AddRecExpr under some Predicates. If
// successful, return the AddRec and the predicates as a pair. Also cache the
// results of the analysis.
//
// Example usage scenario:
//    Say the Rewriter is called for the following SCEV:
//         8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
//    where:
//         %X = phi i64 (%Start, %BEValue)
//    It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
//    and call this function with %SymbolicPHI = %X.
//
//    The analysis will find that the value coming around the backedge has
//    the following SCEV:
//         BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
//    Upon concluding that this matches the desired pattern, the function
//    will return the pair {NewAddRec, SmallPredsVec} where:
//         NewAddRec = {%Start,+,%Step}
//         SmallPredsVec = {P1, P2, P3} as follows:
//           P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
//           P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
//           P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
//    The returned pair means that SymbolicPHI can be rewritten into NewAddRec
//    under the predicates {P1,P2,P3}.
//    This predicated rewrite will be cached in PredicatedSCEVRewrites:
//         PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
//
// TODO's:
//
// 1) Extend the Induction descriptor to also support inductions that involve
//    casts: When needed (namely, when we are called in the context of the
//    vectorizer induction analysis), a Set of cast instructions will be
//    populated by this method, and provided back to isInductionPHI. This is
//    needed to allow the vectorizer to properly record them to be ignored by
//    the cost model and to avoid vectorizing them (otherwise these casts,
//    which are redundant under the runtime overflow checks, will be
//    vectorized, which can be costly).
5427 // 5428 // 2) Support additional induction/PHISCEV patterns: We also want to support 5429 // inductions where the sext-trunc / zext-trunc operations (partly) occur 5430 // after the induction update operation (the induction increment): 5431 // 5432 // (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix) 5433 // which correspond to a phi->add->trunc->sext/zext->phi update chain. 5434 // 5435 // (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix) 5436 // which correspond to a phi->trunc->add->sext/zext->phi update chain. 5437 // 5438 // 3) Outline common code with createAddRecFromPHI to avoid duplication. 5439 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> 5440 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) { 5441 SmallVector<const SCEVPredicate *, 3> Predicates; 5442 5443 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can 5444 // return an AddRec expression under some predicate. 5445 5446 auto *PN = cast<PHINode>(SymbolicPHI->getValue()); 5447 const Loop *L = isIntegerLoopHeaderPHI(PN, LI); 5448 assert(L && "Expecting an integer loop header phi"); 5449 5450 // The loop may have multiple entrances or multiple exits; we can analyze 5451 // this phi as an addrec if it has a unique entry value and a unique 5452 // backedge value. 5453 Value *BEValueV = nullptr, *StartValueV = nullptr; 5454 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { 5455 Value *V = PN->getIncomingValue(i); 5456 if (L->contains(PN->getIncomingBlock(i))) { 5457 if (!BEValueV) { 5458 BEValueV = V; 5459 } else if (BEValueV != V) { 5460 BEValueV = nullptr; 5461 break; 5462 } 5463 } else if (!StartValueV) { 5464 StartValueV = V; 5465 } else if (StartValueV != V) { 5466 StartValueV = nullptr; 5467 break; 5468 } 5469 } 5470 if (!BEValueV || !StartValueV) 5471 return std::nullopt; 5472 5473 const SCEV *BEValue = getSCEV(BEValueV); 5474 5475 // If the value coming around the backedge is an add with the symbolic 5476 // value we just inserted, possibly with casts that we can ignore under 5477 // an appropriate runtime guard, then we found a simple induction variable! 5478 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue); 5479 if (!Add) 5480 return std::nullopt; 5481 5482 // If there is a single occurrence of the symbolic value, possibly 5483 // casted, replace it with a recurrence. 5484 unsigned FoundIndex = Add->getNumOperands(); 5485 Type *TruncTy = nullptr; 5486 bool Signed; 5487 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) 5488 if ((TruncTy = 5489 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this))) 5490 if (FoundIndex == e) { 5491 FoundIndex = i; 5492 break; 5493 } 5494 5495 if (FoundIndex == Add->getNumOperands()) 5496 return std::nullopt; 5497 5498 // Create an add with everything but the specified operand. 5499 SmallVector<const SCEV *, 8> Ops; 5500 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) 5501 if (i != FoundIndex) 5502 Ops.push_back(Add->getOperand(i)); 5503 const SCEV *Accum = getAddExpr(Ops); 5504 5505 // The runtime checks will not be valid if the step amount is 5506 // varying inside the loop. 
5507 if (!isLoopInvariant(Accum, L)) 5508 return std::nullopt; 5509 5510 // *** Part2: Create the predicates 5511 5512 // Analysis was successful: we have a phi-with-cast pattern for which we 5513 // can return an AddRec expression under the following predicates: 5514 // 5515 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum) 5516 // fits within the truncated type (does not overflow) for i = 0 to n-1. 5517 // P2: An Equal predicate that guarantees that 5518 // Start = (Ext ix (Trunc iy (Start) to ix) to iy) 5519 // P3: An Equal predicate that guarantees that 5520 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy) 5521 // 5522 // As we next prove, the above predicates guarantee that: 5523 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy) 5524 // 5525 // 5526 // More formally, we want to prove that: 5527 // Expr(i+1) = Start + (i+1) * Accum 5528 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum 5529 // 5530 // Given that: 5531 // 1) Expr(0) = Start 5532 // 2) Expr(1) = Start + Accum 5533 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2 5534 // 3) Induction hypothesis (step i): 5535 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum 5536 // 5537 // Proof: 5538 // Expr(i+1) = 5539 // = Start + (i+1)*Accum 5540 // = (Start + i*Accum) + Accum 5541 // = Expr(i) + Accum 5542 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum 5543 // :: from step i 5544 // 5545 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum 5546 // 5547 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) 5548 // + (Ext ix (Trunc iy (Accum) to ix) to iy) 5549 // + Accum :: from P3 5550 // 5551 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) 5552 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y) 5553 // 5554 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum 5555 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum 5556 // 5557 // By induction, the same applies to all iterations 1<=i<n: 5558 // 5559 5560 // Create a truncated addrec for which we will add a no overflow check (P1). 5561 const SCEV *StartVal = getSCEV(StartValueV); 5562 const SCEV *PHISCEV = 5563 getAddRecExpr(getTruncateExpr(StartVal, TruncTy), 5564 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); 5565 5566 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr. 5567 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV 5568 // will be constant. 5569 // 5570 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't 5571 // add P1. 5572 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { 5573 SCEVWrapPredicate::IncrementWrapFlags AddedFlags = 5574 Signed ? SCEVWrapPredicate::IncrementNSSW 5575 : SCEVWrapPredicate::IncrementNUSW; 5576 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags); 5577 Predicates.push_back(AddRecPred); 5578 } 5579 5580 // Create the Equal Predicates P2,P3: 5581 5582 // It is possible that the predicates P2 and/or P3 are computable at 5583 // compile time due to StartVal and/or Accum being constants. 5584 // If either one is, then we can check that now and escape if either P2 5585 // or P3 is false. 
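  //
  // As a concrete illustration: if StartVal were the i64 constant 300 and
  // TruncTy i8, the signed pattern gives
  //   (sext i8 (trunc i64 300 to i8) to i64) = 44,
  // so P2 (300 == 44) would be known false at compile time and we would bail
  // out below.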

  // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
  // for each of StartVal and Accum.
  auto getExtendedExpr = [&](const SCEV *Expr,
                             bool CreateSignExtend) -> const SCEV * {
    assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
    const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
    const SCEV *ExtendedExpr =
        CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
                         : getZeroExtendExpr(TruncatedExpr, Expr->getType());
    return ExtendedExpr;
  };

  // Given:
  //   ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
  //                = getExtendedExpr(Expr)
  // Determine whether the predicate P: Expr == ExtendedExpr
  // is known to be false at compile time.
  auto PredIsKnownFalse = [&](const SCEV *Expr,
                              const SCEV *ExtendedExpr) -> bool {
    return Expr != ExtendedExpr &&
           isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
  };

  const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
  if (PredIsKnownFalse(StartVal, StartExtended)) {
    LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
    return std::nullopt;
  }

  // The Step is always Signed (because the overflow checks are either
  // NSSW or NUSW).
  const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
  if (PredIsKnownFalse(Accum, AccumExtended)) {
    LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
    return std::nullopt;
  }

  auto AppendPredicate = [&](const SCEV *Expr,
                             const SCEV *ExtendedExpr) -> void {
    if (Expr != ExtendedExpr &&
        !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
      const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
      LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
      Predicates.push_back(Pred);
    }
  };

  AppendPredicate(StartVal, StartExtended);
  AppendPredicate(Accum, AccumExtended);

  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
  // which the casts had been folded away. The caller can rewrite SymbolicPHI
  // into NewAR if it will also add the runtime overflow checks specified in
  // Predicates.
  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);

  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
      std::make_pair(NewAR, Predicates);
  // Remember the result of the analysis for this SCEV at this location.
  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
  return PredRewrite;
}

std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
  if (!L)
    return std::nullopt;

  // Check to see if we already analyzed this PHI.
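  // A cached rewrite equal to SymbolicPHI itself encodes a previously failed
  // analysis (see the failure path below); a cached AddRec together with a
  // non-empty predicate set encodes a previous success.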
5658 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L}); 5659 if (I != PredicatedSCEVRewrites.end()) { 5660 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite = 5661 I->second; 5662 // Analysis was done before and failed to create an AddRec: 5663 if (Rewrite.first == SymbolicPHI) 5664 return std::nullopt; 5665 // Analysis was done before and succeeded to create an AddRec under 5666 // a predicate: 5667 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec"); 5668 assert(!(Rewrite.second).empty() && "Expected to find Predicates"); 5669 return Rewrite; 5670 } 5671 5672 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> 5673 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI); 5674 5675 // Record in the cache that the analysis failed 5676 if (!Rewrite) { 5677 SmallVector<const SCEVPredicate *, 3> Predicates; 5678 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates}; 5679 return std::nullopt; 5680 } 5681 5682 return Rewrite; 5683 } 5684 5685 // FIXME: This utility is currently required because the Rewriter currently 5686 // does not rewrite this expression: 5687 // {0, +, (sext ix (trunc iy to ix) to iy)} 5688 // into {0, +, %step}, 5689 // even when the following Equal predicate exists: 5690 // "%step == (sext ix (trunc iy to ix) to iy)". 5691 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( 5692 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const { 5693 if (AR1 == AR2) 5694 return true; 5695 5696 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { 5697 if (Expr1 != Expr2 && 5698 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) && 5699 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE)) 5700 return false; 5701 return true; 5702 }; 5703 5704 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) || 5705 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE))) 5706 return false; 5707 return true; 5708 } 5709 5710 /// A helper function for createAddRecFromPHI to handle simple cases. 5711 /// 5712 /// This function tries to find an AddRec expression for the simplest (yet most 5713 /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)). 5714 /// If it fails, createAddRecFromPHI will use a more general, but slow, 5715 /// technique for finding the AddRec expression. 
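/// A typical (illustrative) match:
///   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nsw i32 %iv, %step   ; %step loop-invariant in %loop
/// which yields the AddRec {0,+,%step}<nsw> for %loop.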
const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
                                                      Value *BEValueV,
                                                      Value *StartValueV) {
  const Loop *L = LI.getLoopFor(PN->getParent());
  assert(L && L->getHeader() == PN->getParent());
  assert(BEValueV && StartValueV);

  auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
  if (!BO)
    return nullptr;

  if (BO->Opcode != Instruction::Add)
    return nullptr;

  const SCEV *Accum = nullptr;
  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
    Accum = getSCEV(BO->RHS);
  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
    Accum = getSCEV(BO->LHS);

  if (!Accum)
    return nullptr;

  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
  if (BO->IsNUW)
    Flags = setFlags(Flags, SCEV::FlagNUW);
  if (BO->IsNSW)
    Flags = setFlags(Flags, SCEV::FlagNSW);

  const SCEV *StartVal = getSCEV(StartValueV);
  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
  insertValueToMap(PN, PHISCEV);

  if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
    setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
                   (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
                                       proveNoWrapViaConstantRanges(AR)));
  }

  // We can add Flags to the post-inc expression only if we
  // know that it is *undefined behavior* for BEValueV to
  // overflow.
  if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
    assert(isLoopInvariant(Accum, L) &&
           "Accum is defined outside L, but is not invariant?");
    if (isAddRecNeverPoison(BEInst, L))
      (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
  }

  return PHISCEV;
}

const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
  const Loop *L = LI.getLoopFor(PN->getParent());
  if (!L || L->getHeader() != PN->getParent())
    return nullptr;

  // The loop may have multiple entrances or multiple exits; we can analyze
  // this phi as an addrec if it has a unique entry value and a unique
  // backedge value.
  Value *BEValueV = nullptr, *StartValueV = nullptr;
  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    Value *V = PN->getIncomingValue(i);
    if (L->contains(PN->getIncomingBlock(i))) {
      if (!BEValueV) {
        BEValueV = V;
      } else if (BEValueV != V) {
        BEValueV = nullptr;
        break;
      }
    } else if (!StartValueV) {
      StartValueV = V;
    } else if (StartValueV != V) {
      StartValueV = nullptr;
      break;
    }
  }
  if (!BEValueV || !StartValueV)
    return nullptr;

  assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
         "PHI node already processed?");

  // First, try to find an AddRec expression without creating a fictitious
  // symbolic value for PN.
  if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
    return S;

  // Handle PHI node value symbolically.
  const SCEV *SymbolicName = getUnknown(PN);
  insertValueToMap(PN, SymbolicName);

  // Using this symbolic name for the PHI, analyze the value coming around
  // the back-edge.
  const SCEV *BEValue = getSCEV(BEValueV);

  // NOTE: If BEValue is loop invariant, we know that the PHI node just
  // has a special value for the first iteration of the loop.

  // If the value coming around the backedge is an add with the symbolic
  // value we just inserted, then we found a simple induction variable!
  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
    // If there is a single occurrence of the symbolic value, replace it
    // with a recurrence.
    unsigned FoundIndex = Add->getNumOperands();
    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
      if (Add->getOperand(i) == SymbolicName)
        if (FoundIndex == e) {
          FoundIndex = i;
          break;
        }

    if (FoundIndex != Add->getNumOperands()) {
      // Create an add with everything but the specified operand.
      SmallVector<const SCEV *, 8> Ops;
      for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
        if (i != FoundIndex)
          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
                                                             L, *this));
      const SCEV *Accum = getAddExpr(Ops);

      // This is not a valid addrec if the step amount is varying each
      // loop iteration, but is not itself an addrec in this loop.
      if (isLoopInvariant(Accum, L) ||
          (isa<SCEVAddRecExpr>(Accum) &&
           cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
        SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;

        if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
          if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
            if (BO->IsNUW)
              Flags = setFlags(Flags, SCEV::FlagNUW);
            if (BO->IsNSW)
              Flags = setFlags(Flags, SCEV::FlagNSW);
          }
        } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
          if (GEP->getOperand(0) == PN) {
            GEPNoWrapFlags NW = GEP->getNoWrapFlags();
            // If the increment has any nowrap flags, then we know the address
            // space cannot be wrapped around.
            if (NW != GEPNoWrapFlags::none())
              Flags = setFlags(Flags, SCEV::FlagNW);
            // If the GEP is nuw or nusw with non-negative offset, we know that
            // no unsigned wrap occurs. We cannot set the nsw flag as only the
            // offset is treated as signed, while the base is unsigned.
            if (NW.hasNoUnsignedWrap() ||
                (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
              Flags = setFlags(Flags, SCEV::FlagNUW);
          }

          // We cannot transfer nuw and nsw flags from subtraction
          // operations -- sub nuw X, Y is not the same as add nuw X, -Y
          // for instance.
        }

        const SCEV *StartVal = getSCEV(StartValueV);
        const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);

        // Okay, for the entire analysis of this edge we assumed the PHI
        // to be symbolic. We now need to go back and purge all of the
        // entries for the scalars that use the symbolic expression.
        forgetMemoizedResults(SymbolicName);
        insertValueToMap(PN, PHISCEV);

        if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
          setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
                         (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
                                             proveNoWrapViaConstantRanges(AR)));
        }

        // We can add Flags to the post-inc expression only if we
        // know that it is *undefined behavior* for BEValueV to
        // overflow.
        if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
          if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
            (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);

        return PHISCEV;
      }
    }
  } else {
    // Otherwise, this could be a loop like this:
    //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
    // In this case, j = {1,+,1} and BEValue is j.
    // Because the other in-value of i (0) fits the evolution of BEValue,
    // i really is an addrec evolution.
5902 //
5903 // We can generalize this, saying that i is the shifted value of BEValue
5904 // by one iteration:
5905 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5906
5907 // Do not allow refinement in rewriting of BEValue.
5908 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5909 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5910 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5911 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5912 const SCEV *StartVal = getSCEV(StartValueV);
5913 if (Start == StartVal) {
5914 // Okay, for the entire analysis of this edge we assumed the PHI
5915 // to be symbolic. We now need to go back and purge all of the
5916 // entries for the scalars that use the symbolic expression.
5917 forgetMemoizedResults(SymbolicName);
5918 insertValueToMap(PN, Shifted);
5919 return Shifted;
5920 }
5921 }
5922 }
5923
5924 // Remove the temporary PHI node SCEV that has been inserted while intending
5925 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5926 // as it would prevent later (possibly simpler) SCEV expressions from being
5927 // added to the ValueExprMap.
5928 eraseValueFromMap(PN);
5929
5930 return nullptr;
5931 }
5932
5933 // Try to match a control flow sequence that branches out at BI and merges back
5934 // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5935 // match.
5936 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5937 Value *&C, Value *&LHS, Value *&RHS) {
5938 C = BI->getCondition();
5939
5940 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5941 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5942
5943 if (!LeftEdge.isSingleEdge())
5944 return false;
5945
5946 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5947
5948 Use &LeftUse = Merge->getOperandUse(0);
5949 Use &RightUse = Merge->getOperandUse(1);
5950
5951 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5952 LHS = LeftUse;
5953 RHS = RightUse;
5954 return true;
5955 }
5956
5957 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5958 LHS = RightUse;
5959 RHS = LeftUse;
5960 return true;
5961 }
5962
5963 return false;
5964 }
5965
5966 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5967 auto IsReachable =
5968 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5969 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5970 // Try to match
5971 //
5972 // br %cond, label %left, label %right
5973 // left:
5974 // br label %merge
5975 // right:
5976 // br label %merge
5977 // merge:
5978 // V = phi [ %x, %left ], [ %y, %right ]
5979 //
5980 // as "select %cond, %x, %y"
5981
5982 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5983 assert(IDom && "At least the entry block should dominate PN");
5984
5985 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5986 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5987
5988 if (BI && BI->isConditional() &&
5989 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5990 properlyDominates(getSCEV(LHS), PN->getParent()) &&
5991 properlyDominates(getSCEV(RHS), PN->getParent()))
5992 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5993 }
5994
5995 return nullptr;
5996 }
5997
5998 /// Returns the SCEV for the first operand of a phi if all phi operands have
5999 /// identical opcodes and operands,
6000 /// e.g.
6001 /// a: %add = %a + %b 6002 /// br %c 6003 /// b: %add1 = %a + %b 6004 /// br %c 6005 /// c: %phi = phi [%add, a], [%add1, b] 6006 /// scev(%phi) => scev(%add) 6007 const SCEV * 6008 ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) { 6009 BinaryOperator *CommonInst = nullptr; 6010 // Check if instructions are identical. 6011 for (Value *Incoming : PN->incoming_values()) { 6012 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming); 6013 if (!IncomingInst) 6014 return nullptr; 6015 if (CommonInst) { 6016 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst)) 6017 return nullptr; // Not identical, give up 6018 } else { 6019 // Remember binary operator 6020 CommonInst = IncomingInst; 6021 } 6022 } 6023 if (!CommonInst) 6024 return nullptr; 6025 6026 // Check if SCEV exprs for instructions are identical. 6027 const SCEV *CommonSCEV = getSCEV(CommonInst); 6028 bool SCEVExprsIdentical = 6029 all_of(drop_begin(PN->incoming_values()), 6030 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); }); 6031 return SCEVExprsIdentical ? CommonSCEV : nullptr; 6032 } 6033 6034 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { 6035 if (const SCEV *S = createAddRecFromPHI(PN)) 6036 return S; 6037 6038 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the 6039 // phi node for X. 6040 if (Value *V = simplifyInstruction( 6041 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr, 6042 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) 6043 return getSCEV(V); 6044 6045 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN)) 6046 return S; 6047 6048 if (const SCEV *S = createNodeFromSelectLikePHI(PN)) 6049 return S; 6050 6051 // If it's not a loop phi, we can't handle it yet. 6052 return getUnknown(PN); 6053 } 6054 6055 bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, 6056 SCEVTypes RootKind) { 6057 struct FindClosure { 6058 const SCEV *OperandToFind; 6059 const SCEVTypes RootKind; // Must be a sequential min/max expression. 6060 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind. 6061 6062 bool Found = false; 6063 6064 bool canRecurseInto(SCEVTypes Kind) const { 6065 // We can only recurse into the SCEV expression of the same effective type 6066 // as the type of our root SCEV expression, and into zero-extensions. 6067 return RootKind == Kind || NonSequentialRootKind == Kind || 6068 scZeroExtend == Kind; 6069 }; 6070 6071 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind) 6072 : OperandToFind(OperandToFind), RootKind(RootKind), 6073 NonSequentialRootKind( 6074 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( 6075 RootKind)) {} 6076 6077 bool follow(const SCEV *S) { 6078 Found = S == OperandToFind; 6079 6080 return !isDone() && canRecurseInto(S->getSCEVType()); 6081 } 6082 6083 bool isDone() const { return Found; } 6084 }; 6085 6086 FindClosure FC(OperandToFind, RootKind); 6087 visitAll(Root, FC); 6088 return FC.Found; 6089 } 6090 6091 std::optional<const SCEV *> 6092 ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, 6093 ICmpInst *Cond, 6094 Value *TrueVal, 6095 Value *FalseVal) { 6096 // Try to match some simple smax or umax patterns. 
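// For example (hypothetical IR feeding this path):
//   %cmp = icmp sgt i32 %a, %b
//   %sel = select i1 %cmp, i32 %a, i32 %b
// is recognized as smax(%a, %b); the cases below also allow a common
// addend on both arms, e.g. a > b ? a+x : b+x --> smax(a, b)+x.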
6097 auto *ICI = Cond; 6098 6099 Value *LHS = ICI->getOperand(0); 6100 Value *RHS = ICI->getOperand(1); 6101 6102 switch (ICI->getPredicate()) { 6103 case ICmpInst::ICMP_SLT: 6104 case ICmpInst::ICMP_SLE: 6105 case ICmpInst::ICMP_ULT: 6106 case ICmpInst::ICMP_ULE: 6107 std::swap(LHS, RHS); 6108 [[fallthrough]]; 6109 case ICmpInst::ICMP_SGT: 6110 case ICmpInst::ICMP_SGE: 6111 case ICmpInst::ICMP_UGT: 6112 case ICmpInst::ICMP_UGE: 6113 // a > b ? a+x : b+x -> max(a, b)+x 6114 // a > b ? b+x : a+x -> min(a, b)+x 6115 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) { 6116 bool Signed = ICI->isSigned(); 6117 const SCEV *LA = getSCEV(TrueVal); 6118 const SCEV *RA = getSCEV(FalseVal); 6119 const SCEV *LS = getSCEV(LHS); 6120 const SCEV *RS = getSCEV(RHS); 6121 if (LA->getType()->isPointerTy()) { 6122 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA. 6123 // Need to make sure we can't produce weird expressions involving 6124 // negated pointers. 6125 if (LA == LS && RA == RS) 6126 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS); 6127 if (LA == RS && RA == LS) 6128 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS); 6129 } 6130 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * { 6131 if (Op->getType()->isPointerTy()) { 6132 Op = getLosslessPtrToIntExpr(Op); 6133 if (isa<SCEVCouldNotCompute>(Op)) 6134 return Op; 6135 } 6136 if (Signed) 6137 Op = getNoopOrSignExtend(Op, Ty); 6138 else 6139 Op = getNoopOrZeroExtend(Op, Ty); 6140 return Op; 6141 }; 6142 LS = CoerceOperand(LS); 6143 RS = CoerceOperand(RS); 6144 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS)) 6145 break; 6146 const SCEV *LDiff = getMinusSCEV(LA, LS); 6147 const SCEV *RDiff = getMinusSCEV(RA, RS); 6148 if (LDiff == RDiff) 6149 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS), 6150 LDiff); 6151 LDiff = getMinusSCEV(LA, RS); 6152 RDiff = getMinusSCEV(RA, LS); 6153 if (LDiff == RDiff) 6154 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS), 6155 LDiff); 6156 } 6157 break; 6158 case ICmpInst::ICMP_NE: 6159 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y 6160 std::swap(TrueVal, FalseVal); 6161 [[fallthrough]]; 6162 case ICmpInst::ICMP_EQ: 6163 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1 6164 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) && 6165 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { 6166 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty); 6167 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y 6168 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y 6169 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x 6170 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y 6171 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1)) 6172 return getAddExpr(getUMaxExpr(X, C), Y); 6173 } 6174 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...)) 6175 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...)) 6176 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...) 
6177 // -> umin_seq(x, umin (..., umin_seq(...), ...)) 6178 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() && 6179 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) { 6180 const SCEV *X = getSCEV(LHS); 6181 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X)) 6182 X = ZExt->getOperand(); 6183 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) { 6184 const SCEV *FalseValExpr = getSCEV(FalseVal); 6185 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr)) 6186 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr, 6187 /*Sequential=*/true); 6188 } 6189 } 6190 break; 6191 default: 6192 break; 6193 } 6194 6195 return std::nullopt; 6196 } 6197 6198 static std::optional<const SCEV *> 6199 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, 6200 const SCEV *TrueExpr, const SCEV *FalseExpr) { 6201 assert(CondExpr->getType()->isIntegerTy(1) && 6202 TrueExpr->getType() == FalseExpr->getType() && 6203 TrueExpr->getType()->isIntegerTy(1) && 6204 "Unexpected operands of a select."); 6205 6206 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0) 6207 // --> C + (umin_seq cond, x - C) 6208 // 6209 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C)) 6210 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0) 6211 // --> C + (umin_seq ~cond, x - C) 6212 6213 // FIXME: while we can't legally model the case where both of the hands 6214 // are fully variable, we only require that the *difference* is constant. 6215 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr)) 6216 return std::nullopt; 6217 6218 const SCEV *X, *C; 6219 if (isa<SCEVConstant>(TrueExpr)) { 6220 CondExpr = SE->getNotSCEV(CondExpr); 6221 X = FalseExpr; 6222 C = TrueExpr; 6223 } else { 6224 X = TrueExpr; 6225 C = FalseExpr; 6226 } 6227 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C), 6228 /*Sequential=*/true)); 6229 } 6230 6231 static std::optional<const SCEV *> 6232 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal, 6233 Value *FalseVal) { 6234 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal)) 6235 return std::nullopt; 6236 6237 const auto *SECond = SE->getSCEV(Cond); 6238 const auto *SETrue = SE->getSCEV(TrueVal); 6239 const auto *SEFalse = SE->getSCEV(FalseVal); 6240 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse); 6241 } 6242 6243 const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq( 6244 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) { 6245 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?"); 6246 assert(TrueVal->getType() == FalseVal->getType() && 6247 V->getType() == TrueVal->getType() && 6248 "Types of select hands and of the result must match."); 6249 6250 // For now, only deal with i1-typed `select`s. 6251 if (!V->getType()->isIntegerTy(1)) 6252 return getUnknown(V); 6253 6254 if (std::optional<const SCEV *> S = 6255 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal)) 6256 return *S; 6257 6258 return getUnknown(V); 6259 } 6260 6261 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, 6262 Value *TrueVal, 6263 Value *FalseVal) { 6264 // Handle "constant" branch or select. This can occur for instance when a 6265 // loop pass transforms an inner loop and moves on to process the outer loop. 6266 if (auto *CI = dyn_cast<ConstantInt>(Cond)) 6267 return getSCEV(CI->isOne() ? 
TrueVal : FalseVal);
6268
6269 if (auto *I = dyn_cast<Instruction>(V)) {
6270 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6271 if (std::optional<const SCEV *> S =
6272 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6273 TrueVal, FalseVal))
6274 return *S;
6275 }
6276 }
6277
6278 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6279 }
6280
6281 /// Expand GEP instructions into add and multiply operations. This allows them
6282 /// to be analyzed by regular SCEV code.
6283 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6284 assert(GEP->getSourceElementType()->isSized() &&
6285 "GEP source element type must be sized");
6286
6287 SmallVector<const SCEV *, 4> IndexExprs;
6288 for (Value *Index : GEP->indices())
6289 IndexExprs.push_back(getSCEV(Index));
6290 return getGEPExpr(GEP, IndexExprs);
6291 }
6292
6293 APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6294 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6295 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6296 return TrailingZeros >= BitWidth
6297 ? APInt::getZero(BitWidth)
6298 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6299 };
6300 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6301 // The result is the GCD of all operands' results.
6302 APInt Res = getConstantMultiple(N->getOperand(0));
6303 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6304 Res = APIntOps::GreatestCommonDivisor(
6305 Res, getConstantMultiple(N->getOperand(I)));
6306 return Res;
6307 };
6308
6309 switch (S->getSCEVType()) {
6310 case scConstant:
6311 return cast<SCEVConstant>(S)->getAPInt();
6312 case scPtrToInt:
6313 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6314 case scUDivExpr:
6315 case scVScale:
6316 return APInt(BitWidth, 1);
6317 case scTruncate: {
6318 // Only multiples that are a power of 2 will hold after truncation.
6319 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6320 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6321 return GetShiftedByZeros(TZ);
6322 }
6323 case scZeroExtend: {
6324 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6325 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6326 }
6327 case scSignExtend: {
6328 // Only multiples that are a power of 2 will hold after sext.
6329 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6330 uint32_t TZ = getMinTrailingZeros(E->getOperand());
6331 return GetShiftedByZeros(TZ);
6332 }
6333 case scMulExpr: {
6334 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6335 if (M->hasNoUnsignedWrap()) {
6336 // The result is the product of all operand results.
6337 APInt Res = getConstantMultiple(M->getOperand(0));
6338 for (const SCEV *Operand : M->operands().drop_front())
6339 Res = Res * getConstantMultiple(Operand);
6340 return Res;
6341 }
6342
6343 // If there are no wrap guarantees, find the trailing zeros, which is the
6344 // sum of trailing zeros for all its operands.
6345 uint32_t TZ = 0;
6346 for (const SCEV *Operand : M->operands())
6347 TZ += getMinTrailingZeros(Operand);
6348 return GetShiftedByZeros(TZ);
6349 }
6350 case scAddExpr:
6351 case scAddRecExpr: {
6352 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6353 if (N->hasNoUnsignedWrap())
6354 return GetGCDMultiple(N);
6355 // Find the trailing zeros, which is the minimum over all operands.
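// E.g. (3 + (9 * %x))<nuw> gives gcd(3, 9) = 3 via GetGCDMultiple above,
// whereas without <nuw> only the trailing-zero bound below survives,
// which here proves only a multiple of 1 (an illustrative comparison).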
6356 uint32_t TZ = getMinTrailingZeros(N->getOperand(0)); 6357 for (const SCEV *Operand : N->operands().drop_front()) 6358 TZ = std::min(TZ, getMinTrailingZeros(Operand)); 6359 return GetShiftedByZeros(TZ); 6360 } 6361 case scUMaxExpr: 6362 case scSMaxExpr: 6363 case scUMinExpr: 6364 case scSMinExpr: 6365 case scSequentialUMinExpr: 6366 return GetGCDMultiple(cast<SCEVNAryExpr>(S)); 6367 case scUnknown: { 6368 // ask ValueTracking for known bits 6369 const SCEVUnknown *U = cast<SCEVUnknown>(S); 6370 unsigned Known = 6371 computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT) 6372 .countMinTrailingZeros(); 6373 return GetShiftedByZeros(Known); 6374 } 6375 case scCouldNotCompute: 6376 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 6377 } 6378 llvm_unreachable("Unknown SCEV kind!"); 6379 } 6380 6381 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { 6382 auto I = ConstantMultipleCache.find(S); 6383 if (I != ConstantMultipleCache.end()) 6384 return I->second; 6385 6386 APInt Result = getConstantMultipleImpl(S); 6387 auto InsertPair = ConstantMultipleCache.insert({S, Result}); 6388 assert(InsertPair.second && "Should insert a new key"); 6389 return InsertPair.first->second; 6390 } 6391 6392 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) { 6393 APInt Multiple = getConstantMultiple(S); 6394 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple; 6395 } 6396 6397 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) { 6398 return std::min(getConstantMultiple(S).countTrailingZeros(), 6399 (unsigned)getTypeSizeInBits(S->getType())); 6400 } 6401 6402 /// Helper method to assign a range to V from metadata present in the IR. 6403 static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) { 6404 if (Instruction *I = dyn_cast<Instruction>(V)) { 6405 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) 6406 return getConstantRangeFromMetadata(*MD); 6407 if (const auto *CB = dyn_cast<CallBase>(V)) 6408 if (std::optional<ConstantRange> Range = CB->getRange()) 6409 return Range; 6410 } 6411 if (auto *A = dyn_cast<Argument>(V)) 6412 if (std::optional<ConstantRange> Range = A->getRange()) 6413 return Range; 6414 6415 return std::nullopt; 6416 } 6417 6418 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec, 6419 SCEV::NoWrapFlags Flags) { 6420 if (AddRec->getNoWrapFlags(Flags) != Flags) { 6421 AddRec->setNoWrapFlags(Flags); 6422 UnsignedRanges.erase(AddRec); 6423 SignedRanges.erase(AddRec); 6424 ConstantMultipleCache.erase(AddRec); 6425 } 6426 } 6427 6428 ConstantRange ScalarEvolution:: 6429 getRangeForUnknownRecurrence(const SCEVUnknown *U) { 6430 const DataLayout &DL = getDataLayout(); 6431 6432 unsigned BitWidth = getTypeSizeInBits(U->getType()); 6433 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true); 6434 6435 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then 6436 // use information about the trip count to improve our available range. Note 6437 // that the trip count independent cases are already handled by known bits. 6438 // WARNING: The definition of recurrence used here is subtly different than 6439 // the one used by AddRec (and thus most of this file). Step is allowed to 6440 // be arbitrarily loop varying here, where AddRec allows only loop invariant 6441 // and other addrecs in the same loop (for non-affine addrecs). The code 6442 // below intentionally handles the case where step is not loop invariant. 
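// A shape this matches (hypothetical IR):
//   %p = phi i32 [ %start, %preheader ], [ %shifted, %loop ]
//   %shifted = lshr i32 %p, %amt         ; %amt may vary per iteration
// Even though %p is not an AddRec, bounding the total shift by roughly
// (max trip count - 1) * max(%amt) below still clamps its range.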
6443 auto *P = dyn_cast<PHINode>(U->getValue());
6444 if (!P)
6445 return FullSet;
6446
6447 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6448 // even the values that are not available in these blocks may come from them,
6449 // and this leads to a false-positive recurrence test.
6450 for (auto *Pred : predecessors(P->getParent()))
6451 if (!DT.isReachableFromEntry(Pred))
6452 return FullSet;
6453
6454 BinaryOperator *BO;
6455 Value *Start, *Step;
6456 if (!matchSimpleRecurrence(P, BO, Start, Step))
6457 return FullSet;
6458
6459 // If we found a recurrence in reachable code, we must be in a loop. Note
6460 // that BO might be in some subloop of L, and that's completely okay.
6461 auto *L = LI.getLoopFor(P->getParent());
6462 assert(L && L->getHeader() == P->getParent());
6463 if (!L->contains(BO->getParent()))
6464 // NOTE: This bailout should be an assert instead. However, asserting
6465 // the condition here exposes a case where LoopFusion is querying SCEV
6466 // with malformed loop information in the midst of the transform.
6467 // There doesn't appear to be an obvious fix, so for the moment bail out
6468 // until the caller issue can be fixed. PR49566 tracks the bug.
6469 return FullSet;
6470
6471 // TODO: Extend to other opcodes, such as mul and div.
6472 switch (BO->getOpcode()) {
6473 default:
6474 return FullSet;
6475 case Instruction::AShr:
6476 case Instruction::LShr:
6477 case Instruction::Shl:
6478 break;
6479 }
6480
6481 if (BO->getOperand(0) != P)
6482 // TODO: Handle the power function forms some day.
6483 return FullSet;
6484
6485 unsigned TC = getSmallConstantMaxTripCount(L);
6486 if (!TC || TC >= BitWidth)
6487 return FullSet;
6488
6489 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6490 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6491 assert(KnownStart.getBitWidth() == BitWidth &&
6492 KnownStep.getBitWidth() == BitWidth);
6493
6494 // Compute the total shift amount, being careful of overflow and bitwidths.
6495 auto MaxShiftAmt = KnownStep.getMaxValue();
6496 APInt TCAP(BitWidth, TC - 1);
6497 bool Overflow = false;
6498 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6499 if (Overflow)
6500 return FullSet;
6501
6502 switch (BO->getOpcode()) {
6503 default:
6504 llvm_unreachable("filtered out above");
6505 case Instruction::AShr: {
6506 // For each ashr, three cases:
6507 // shift = 0 => unchanged value
6508 // saturation => 0 or -1
6509 // other => a value closer to zero (of the same sign)
6510 // Thus, the end value is closer to zero than the start.
6511 auto KnownEnd = KnownBits::ashr(KnownStart,
6512 KnownBits::makeConstant(TotalShift));
6513 if (KnownStart.isNonNegative())
6514 // Analogous to lshr (simply not yet canonicalized)
6515 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6516 KnownStart.getMaxValue() + 1);
6517 if (KnownStart.isNegative())
6518 // End >=u Start && End <=s Start
6519 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6520 KnownEnd.getMaxValue() + 1);
6521 break;
6522 }
6523 case Instruction::LShr: {
6524 // For each lshr, three cases:
6525 // shift = 0 => unchanged value
6526 // saturation => 0
6527 // other => a smaller positive number
6528 // Thus, the low end of the unsigned range is the last value produced.
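// E.g. for i8 with KnownStart in [32, 63] and a total shift known to be
// exactly 2, KnownEnd.getMinValue() is 8, so the computed range below is
// [8, 64) (a worked example with assumed known bits).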
6529 auto KnownEnd = KnownBits::lshr(KnownStart, 6530 KnownBits::makeConstant(TotalShift)); 6531 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(), 6532 KnownStart.getMaxValue() + 1); 6533 } 6534 case Instruction::Shl: { 6535 // Iff no bits are shifted out, value increases on every shift. 6536 auto KnownEnd = KnownBits::shl(KnownStart, 6537 KnownBits::makeConstant(TotalShift)); 6538 if (TotalShift.ult(KnownStart.countMinLeadingZeros())) 6539 return ConstantRange(KnownStart.getMinValue(), 6540 KnownEnd.getMaxValue() + 1); 6541 break; 6542 } 6543 }; 6544 return FullSet; 6545 } 6546 6547 const ConstantRange & 6548 ScalarEvolution::getRangeRefIter(const SCEV *S, 6549 ScalarEvolution::RangeSignHint SignHint) { 6550 DenseMap<const SCEV *, ConstantRange> &Cache = 6551 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges 6552 : SignedRanges; 6553 SmallVector<const SCEV *> WorkList; 6554 SmallPtrSet<const SCEV *, 8> Seen; 6555 6556 // Add Expr to the worklist, if Expr is either an N-ary expression or a 6557 // SCEVUnknown PHI node. 6558 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) { 6559 if (!Seen.insert(Expr).second) 6560 return; 6561 if (Cache.contains(Expr)) 6562 return; 6563 switch (Expr->getSCEVType()) { 6564 case scUnknown: 6565 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue())) 6566 break; 6567 [[fallthrough]]; 6568 case scConstant: 6569 case scVScale: 6570 case scTruncate: 6571 case scZeroExtend: 6572 case scSignExtend: 6573 case scPtrToInt: 6574 case scAddExpr: 6575 case scMulExpr: 6576 case scUDivExpr: 6577 case scAddRecExpr: 6578 case scUMaxExpr: 6579 case scSMaxExpr: 6580 case scUMinExpr: 6581 case scSMinExpr: 6582 case scSequentialUMinExpr: 6583 WorkList.push_back(Expr); 6584 break; 6585 case scCouldNotCompute: 6586 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 6587 } 6588 }; 6589 AddToWorklist(S); 6590 6591 // Build worklist by queuing operands of N-ary expressions and phi nodes. 6592 for (unsigned I = 0; I != WorkList.size(); ++I) { 6593 const SCEV *P = WorkList[I]; 6594 auto *UnknownS = dyn_cast<SCEVUnknown>(P); 6595 // If it is not a `SCEVUnknown`, just recurse into operands. 6596 if (!UnknownS) { 6597 for (const SCEV *Op : P->operands()) 6598 AddToWorklist(Op); 6599 continue; 6600 } 6601 // `SCEVUnknown`'s require special treatment. 6602 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) { 6603 if (!PendingPhiRangesIter.insert(P).second) 6604 continue; 6605 for (auto &Op : reverse(P->operands())) 6606 AddToWorklist(getSCEV(Op)); 6607 } 6608 } 6609 6610 if (!WorkList.empty()) { 6611 // Use getRangeRef to compute ranges for items in the worklist in reverse 6612 // order. This will force ranges for earlier operands to be computed before 6613 // their users in most cases. 6614 for (const SCEV *P : reverse(drop_begin(WorkList))) { 6615 getRangeRef(P, SignHint); 6616 6617 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P)) 6618 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) 6619 PendingPhiRangesIter.erase(P); 6620 } 6621 } 6622 6623 return getRangeRef(S, SignHint, 0); 6624 } 6625 6626 /// Determine the range for a particular SCEV. If SignHint is 6627 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges 6628 /// with a "cleaner" unsigned (resp. signed) representation. 
6629 const ConstantRange &ScalarEvolution::getRangeRef( 6630 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) { 6631 DenseMap<const SCEV *, ConstantRange> &Cache = 6632 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges 6633 : SignedRanges; 6634 ConstantRange::PreferredRangeType RangeType = 6635 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned 6636 : ConstantRange::Signed; 6637 6638 // See if we've computed this range already. 6639 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); 6640 if (I != Cache.end()) 6641 return I->second; 6642 6643 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) 6644 return setRange(C, SignHint, ConstantRange(C->getAPInt())); 6645 6646 // Switch to iteratively computing the range for S, if it is part of a deeply 6647 // nested expression. 6648 if (Depth > RangeIterThreshold) 6649 return getRangeRefIter(S, SignHint); 6650 6651 unsigned BitWidth = getTypeSizeInBits(S->getType()); 6652 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); 6653 using OBO = OverflowingBinaryOperator; 6654 6655 // If the value has known zeros, the maximum value will have those known zeros 6656 // as well. 6657 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { 6658 APInt Multiple = getNonZeroConstantMultiple(S); 6659 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple); 6660 if (!Remainder.isZero()) 6661 ConservativeResult = 6662 ConstantRange(APInt::getMinValue(BitWidth), 6663 APInt::getMaxValue(BitWidth) - Remainder + 1); 6664 } 6665 else { 6666 uint32_t TZ = getMinTrailingZeros(S); 6667 if (TZ != 0) { 6668 ConservativeResult = ConstantRange( 6669 APInt::getSignedMinValue(BitWidth), 6670 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); 6671 } 6672 } 6673 6674 switch (S->getSCEVType()) { 6675 case scConstant: 6676 llvm_unreachable("Already handled above."); 6677 case scVScale: 6678 return setRange(S, SignHint, getVScaleRange(&F, BitWidth)); 6679 case scTruncate: { 6680 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S); 6681 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1); 6682 return setRange( 6683 Trunc, SignHint, 6684 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType)); 6685 } 6686 case scZeroExtend: { 6687 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S); 6688 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1); 6689 return setRange( 6690 ZExt, SignHint, 6691 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType)); 6692 } 6693 case scSignExtend: { 6694 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S); 6695 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1); 6696 return setRange( 6697 SExt, SignHint, 6698 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType)); 6699 } 6700 case scPtrToInt: { 6701 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S); 6702 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1); 6703 return setRange(PtrToInt, SignHint, X); 6704 } 6705 case scAddExpr: { 6706 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S); 6707 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1); 6708 unsigned WrapType = OBO::AnyWrap; 6709 if (Add->hasNoSignedWrap()) 6710 WrapType |= OBO::NoSignedWrap; 6711 if (Add->hasNoUnsignedWrap()) 6712 WrapType |= OBO::NoUnsignedWrap; 6713 for (const SCEV *Op : drop_begin(Add->operands())) 6714 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), 
WrapType,
6715 RangeType);
6716 return setRange(Add, SignHint,
6717 ConservativeResult.intersectWith(X, RangeType));
6718 }
6719 case scMulExpr: {
6720 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6721 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6722 for (const SCEV *Op : drop_begin(Mul->operands()))
6723 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6724 return setRange(Mul, SignHint,
6725 ConservativeResult.intersectWith(X, RangeType));
6726 }
6727 case scUDivExpr: {
6728 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6729 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6730 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6731 return setRange(UDiv, SignHint,
6732 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6733 }
6734 case scAddRecExpr: {
6735 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6736 // If there's no unsigned wrap, the value will never be less than its
6737 // initial value.
6738 if (AddRec->hasNoUnsignedWrap()) {
6739 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6740 if (!UnsignedMinValue.isZero())
6741 ConservativeResult = ConservativeResult.intersectWith(
6742 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6743 }
6744
6745 // If there's no signed wrap, and all the operands except the initial value
6746 // have the same sign or are zero, the value won't ever be:
6747 // 1: smaller than the initial value if the operands are non-negative,
6748 // 2: bigger than the initial value if the operands are non-positive.
6749 // In both cases, the value cannot cross the signed min/max boundary.
6750 if (AddRec->hasNoSignedWrap()) {
6751 bool AllNonNeg = true;
6752 bool AllNonPos = true;
6753 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6754 if (!isKnownNonNegative(AddRec->getOperand(i)))
6755 AllNonNeg = false;
6756 if (!isKnownNonPositive(AddRec->getOperand(i)))
6757 AllNonPos = false;
6758 }
6759 if (AllNonNeg)
6760 ConservativeResult = ConservativeResult.intersectWith(
6761 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6762 APInt::getSignedMinValue(BitWidth)),
6763 RangeType);
6764 else if (AllNonPos)
6765 ConservativeResult = ConservativeResult.intersectWith(
6766 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6767 getSignedRangeMax(AddRec->getStart()) +
6768 1),
6769 RangeType);
6770 }
6771
6772 // TODO: non-affine addrec
6773 if (AddRec->isAffine()) {
6774 const SCEV *MaxBEScev =
6775 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6776 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6777 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6778
6779 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6780 // MaxBECount's active bits are all <= AddRec's bit width.
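// E.g. an i64 max backedge count of 100 can be truncated for an i32
// AddRec, an i64 count of 2^40 cannot, and an i16 count is
// zero-extended instead (hypothetical widths for illustration).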
6781 if (MaxBECount.getBitWidth() > BitWidth && 6782 MaxBECount.getActiveBits() <= BitWidth) 6783 MaxBECount = MaxBECount.trunc(BitWidth); 6784 else if (MaxBECount.getBitWidth() < BitWidth) 6785 MaxBECount = MaxBECount.zext(BitWidth); 6786 6787 if (MaxBECount.getBitWidth() == BitWidth) { 6788 auto RangeFromAffine = getRangeForAffineAR( 6789 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount); 6790 ConservativeResult = 6791 ConservativeResult.intersectWith(RangeFromAffine, RangeType); 6792 6793 auto RangeFromFactoring = getRangeViaFactoring( 6794 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount); 6795 ConservativeResult = 6796 ConservativeResult.intersectWith(RangeFromFactoring, RangeType); 6797 } 6798 } 6799 6800 // Now try symbolic BE count and more powerful methods. 6801 if (UseExpensiveRangeSharpening) { 6802 const SCEV *SymbolicMaxBECount = 6803 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); 6804 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) && 6805 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth && 6806 AddRec->hasNoSelfWrap()) { 6807 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR( 6808 AddRec, SymbolicMaxBECount, BitWidth, SignHint); 6809 ConservativeResult = 6810 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType); 6811 } 6812 } 6813 } 6814 6815 return setRange(AddRec, SignHint, std::move(ConservativeResult)); 6816 } 6817 case scUMaxExpr: 6818 case scSMaxExpr: 6819 case scUMinExpr: 6820 case scSMinExpr: 6821 case scSequentialUMinExpr: { 6822 Intrinsic::ID ID; 6823 switch (S->getSCEVType()) { 6824 case scUMaxExpr: 6825 ID = Intrinsic::umax; 6826 break; 6827 case scSMaxExpr: 6828 ID = Intrinsic::smax; 6829 break; 6830 case scUMinExpr: 6831 case scSequentialUMinExpr: 6832 ID = Intrinsic::umin; 6833 break; 6834 case scSMinExpr: 6835 ID = Intrinsic::smin; 6836 break; 6837 default: 6838 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr."); 6839 } 6840 6841 const auto *NAry = cast<SCEVNAryExpr>(S); 6842 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1); 6843 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i) 6844 X = X.intrinsic( 6845 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)}); 6846 return setRange(S, SignHint, 6847 ConservativeResult.intersectWith(X, RangeType)); 6848 } 6849 case scUnknown: { 6850 const SCEVUnknown *U = cast<SCEVUnknown>(S); 6851 Value *V = U->getValue(); 6852 6853 // Check if the IR explicitly contains !range metadata. 6854 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V); 6855 if (MDRange) 6856 ConservativeResult = 6857 ConservativeResult.intersectWith(*MDRange, RangeType); 6858 6859 // Use facts about recurrences in the underlying IR. Note that add 6860 // recurrences are AddRecExprs and thus don't hit this path. This 6861 // primarily handles shift recurrences. 6862 auto CR = getRangeForUnknownRecurrence(U); 6863 ConservativeResult = ConservativeResult.intersectWith(CR); 6864 6865 // See if ValueTracking can give us a useful range. 6866 const DataLayout &DL = getDataLayout(); 6867 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT); 6868 if (Known.getBitWidth() != BitWidth) 6869 Known = Known.zextOrTrunc(BitWidth); 6870 6871 // ValueTracking may be able to compute a tighter result for the number of 6872 // sign bits than for the value of those sign bits. 
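// E.g. for %v = sext i8 %x to i32, ComputeNumSignBits reports at least 25
// copies of the sign bit even when nothing is known about the individual
// bit values themselves.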
6873 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT); 6874 if (U->getType()->isPointerTy()) { 6875 // If the pointer size is larger than the index size type, this can cause 6876 // NS to be larger than BitWidth. So compensate for this. 6877 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType()); 6878 int ptrIdxDiff = ptrSize - BitWidth; 6879 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff) 6880 NS -= ptrIdxDiff; 6881 } 6882 6883 if (NS > 1) { 6884 // If we know any of the sign bits, we know all of the sign bits. 6885 if (!Known.Zero.getHiBits(NS).isZero()) 6886 Known.Zero.setHighBits(NS); 6887 if (!Known.One.getHiBits(NS).isZero()) 6888 Known.One.setHighBits(NS); 6889 } 6890 6891 if (Known.getMinValue() != Known.getMaxValue() + 1) 6892 ConservativeResult = ConservativeResult.intersectWith( 6893 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1), 6894 RangeType); 6895 if (NS > 1) 6896 ConservativeResult = ConservativeResult.intersectWith( 6897 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), 6898 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1), 6899 RangeType); 6900 6901 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) { 6902 // Strengthen the range if the underlying IR value is a 6903 // global/alloca/heap allocation using the size of the object. 6904 bool CanBeNull, CanBeFreed; 6905 uint64_t DerefBytes = 6906 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed); 6907 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) { 6908 // The highest address the object can start is DerefBytes bytes before 6909 // the end (unsigned max value). If this value is not a multiple of the 6910 // alignment, the last possible start value is the next lowest multiple 6911 // of the alignment. Note: The computations below cannot overflow, 6912 // because if they would there's no possible start address for the 6913 // object. 6914 APInt MaxVal = 6915 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes); 6916 uint64_t Align = U->getValue()->getPointerAlignment(DL).value(); 6917 uint64_t Rem = MaxVal.urem(Align); 6918 MaxVal -= APInt(BitWidth, Rem); 6919 APInt MinVal = APInt::getZero(BitWidth); 6920 if (llvm::isKnownNonZero(V, DL)) 6921 MinVal = Align; 6922 ConservativeResult = ConservativeResult.intersectWith( 6923 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType); 6924 } 6925 } 6926 6927 // A range of Phi is a subset of union of all ranges of its input. 6928 if (PHINode *Phi = dyn_cast<PHINode>(V)) { 6929 // Make sure that we do not run over cycled Phis. 6930 if (PendingPhiRanges.insert(Phi).second) { 6931 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false); 6932 6933 for (const auto &Op : Phi->operands()) { 6934 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1); 6935 RangeFromOps = RangeFromOps.unionWith(OpRange); 6936 // No point to continue if we already have a full set. 
6937 if (RangeFromOps.isFullSet())
6938 break;
6939 }
6940 ConservativeResult =
6941 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6942 bool Erased = PendingPhiRanges.erase(Phi);
6943 assert(Erased && "Failed to erase Phi properly?");
6944 (void)Erased;
6945 }
6946 }
6947
6948 // vscale can't be equal to zero
6949 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6950 if (II->getIntrinsicID() == Intrinsic::vscale) {
6951 ConstantRange Disallowed = APInt::getZero(BitWidth);
6952 ConservativeResult = ConservativeResult.difference(Disallowed);
6953 }
6954
6955 return setRange(U, SignHint, std::move(ConservativeResult));
6956 }
6957 case scCouldNotCompute:
6958 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6959 }
6960
6961 return setRange(S, SignHint, std::move(ConservativeResult));
6962 }
6963
6964 // Given a StartRange, Step and MaxBECount for an expression, compute a range
6965 // of values that the expression can take. Initially, the expression has a
6966 // value from StartRange and then is changed by Step up to MaxBECount times.
6967 // The Signed argument defines whether we treat Step as signed or unsigned.
6968 static ConstantRange getRangeForAffineARHelper(APInt Step,
6969 const ConstantRange &StartRange,
6970 const APInt &MaxBECount,
6971 bool Signed) {
6972 unsigned BitWidth = Step.getBitWidth();
6973 assert(BitWidth == StartRange.getBitWidth() &&
6974 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6975 // If either Step or MaxBECount is 0, then the expression won't change, and
6976 // we just need to return the initial range.
6977 if (Step == 0 || MaxBECount == 0)
6978 return StartRange;
6979
6980 // If we don't know anything about the initial value (i.e. StartRange is
6981 // FullRange), then we don't know anything about the final range either.
6982 // Return FullRange.
6983 if (StartRange.isFullSet())
6984 return ConstantRange::getFull(BitWidth);
6985
6986 // If Step is signed and negative, then we use its absolute value, but we
6987 // also note that we're moving in the opposite direction.
6988 bool Descending = Signed && Step.isNegative();
6989
6990 if (Signed)
6991 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6992 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6993 // These equations hold true due to the well-defined wrap-around behavior
6994 // of APInt.
6995 Step = Step.abs();
6996
6997 // Check if Offset is more than the full span of BitWidth. If it is, the
6998 // expression is guaranteed to overflow.
6999 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7000 return ConstantRange::getFull(BitWidth);
7001
7002 // Offset is by how much the expression can change. The checks above
7003 // guarantee no overflow here.
7004 APInt Offset = Step * MaxBECount;
7005
7006 // Minimum value of the final range will match the minimal value of StartRange
7007 // if the expression is increasing and will be decreased by Offset otherwise.
7008 // Maximum value of the final range will match the maximal value of StartRange
7009 // if the expression is decreasing and will be increased by Offset otherwise.
7010 APInt StartLower = StartRange.getLower();
7011 APInt StartUpper = StartRange.getUpper() - 1;
7012 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7013 : (StartUpper + std::move(Offset));
7014
7015 // It's possible that the new minimum/maximum value will fall into the initial
7016 // range (due to wrap around).
// This means that the expression can take any
7017 // value in this bitwidth, and we have to return the full range.
7018 if (StartRange.contains(MovedBoundary))
7019 return ConstantRange::getFull(BitWidth);
7020
7021 APInt NewLower =
7022 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7023 APInt NewUpper =
7024 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7025 NewUpper += 1;
7026
7027 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7028 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7029 }
7030
7031 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7032 const SCEV *Step,
7033 const APInt &MaxBECount) {
7034 assert(getTypeSizeInBits(Start->getType()) ==
7035 getTypeSizeInBits(Step->getType()) &&
7036 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7037 "mismatched bit widths");
7038
7039 // First, consider the step signed.
7040 ConstantRange StartSRange = getSignedRange(Start);
7041 ConstantRange StepSRange = getSignedRange(Step);
7042
7043 // If Step can be both positive and negative, we need to find ranges for the
7044 // maximum absolute step values in both directions and union them.
7045 ConstantRange SR = getRangeForAffineARHelper(
7046 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7047 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7048 StartSRange, MaxBECount,
7049 /* Signed = */ true));
7050
7051 // Next, consider the step unsigned.
7052 ConstantRange UR = getRangeForAffineARHelper(
7053 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7054 /* Signed = */ false);
7055
7056 // Finally, intersect the signed and unsigned ranges.
7057 return SR.intersectWith(UR, ConstantRange::Smallest);
7058 }
7059
7060 ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7061 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7062 ScalarEvolution::RangeSignHint SignHint) {
7063 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7064 assert(AddRec->hasNoSelfWrap() &&
7065 "This only works for non-self-wrapping AddRecs!");
7066 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7067 const SCEV *Step = AddRec->getStepRecurrence(*this);
7068 // Only deal with a constant step to save compile time.
7069 if (!isa<SCEVConstant>(Step))
7070 return ConstantRange::getFull(BitWidth);
7071 // Let's make sure that we can prove that we do not self-wrap during
7072 // MaxBECount iterations. We need this because MaxBECount is a maximum
7073 // iteration count estimate, and we might infer nw from some exit for which
7074 // we do not know the max exit count (or any other side reasoning).
7075 // TODO: Turn into assert at some point.
7076 if (getTypeSizeInBits(MaxBECount->getType()) >
7077 getTypeSizeInBits(AddRec->getType()))
7078 return ConstantRange::getFull(BitWidth);
7079 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7080 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7081 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7082 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7083 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7084 MaxItersWithoutWrap))
7085 return ConstantRange::getFull(BitWidth);
7086
7087 ICmpInst::Predicate LEPred =
7088 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7089 ICmpInst::Predicate GEPred =
7090 IsSigned ?
ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; 7091 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); 7092 7093 // We know that there is no self-wrap. Let's take Start and End values and 7094 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during 7095 // the iteration. They either lie inside the range [Min(Start, End), 7096 // Max(Start, End)] or outside it: 7097 // 7098 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax; 7099 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax; 7100 // 7101 // No self wrap flag guarantees that the intermediate values cannot be BOTH 7102 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that 7103 // knowledge, let's try to prove that we are dealing with Case 1. It is so if 7104 // Start <= End and step is positive, or Start >= End and step is negative. 7105 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop()); 7106 ConstantRange StartRange = getRangeRef(Start, SignHint); 7107 ConstantRange EndRange = getRangeRef(End, SignHint); 7108 ConstantRange RangeBetween = StartRange.unionWith(EndRange); 7109 // If they already cover full iteration space, we will know nothing useful 7110 // even if we prove what we want to prove. 7111 if (RangeBetween.isFullSet()) 7112 return RangeBetween; 7113 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax). 7114 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet() 7115 : RangeBetween.isWrappedSet(); 7116 if (IsWrappedSet) 7117 return ConstantRange::getFull(BitWidth); 7118 7119 if (isKnownPositive(Step) && 7120 isKnownPredicateViaConstantRanges(LEPred, Start, End)) 7121 return RangeBetween; 7122 if (isKnownNegative(Step) && 7123 isKnownPredicateViaConstantRanges(GEPred, Start, End)) 7124 return RangeBetween; 7125 return ConstantRange::getFull(BitWidth); 7126 } 7127 7128 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, 7129 const SCEV *Step, 7130 const APInt &MaxBECount) { 7131 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) 7132 // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) 7133 7134 unsigned BitWidth = MaxBECount.getBitWidth(); 7135 assert(getTypeSizeInBits(Start->getType()) == BitWidth && 7136 getTypeSizeInBits(Step->getType()) == BitWidth && 7137 "mismatched bit widths"); 7138 7139 struct SelectPattern { 7140 Value *Condition = nullptr; 7141 APInt TrueValue; 7142 APInt FalseValue; 7143 7144 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, 7145 const SCEV *S) { 7146 std::optional<unsigned> CastOp; 7147 APInt Offset(BitWidth, 0); 7148 7149 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && 7150 "Should be!"); 7151 7152 // Peel off a constant offset. In the future we could consider being 7153 // smarter here and handle {Start+Step,+,Step} too. 
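// E.g. a start expression (4 + %sel) is split into Offset = 4 and
// S = %sel, so a select feeding the start or step can still be
// recognized below (illustrative values).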
7154 const APInt *Off; 7155 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S)))) 7156 Offset = *Off; 7157 7158 // Peel off a cast operation 7159 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) { 7160 CastOp = SCast->getSCEVType(); 7161 S = SCast->getOperand(); 7162 } 7163 7164 using namespace llvm::PatternMatch; 7165 7166 auto *SU = dyn_cast<SCEVUnknown>(S); 7167 const APInt *TrueVal, *FalseVal; 7168 if (!SU || 7169 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), 7170 m_APInt(FalseVal)))) { 7171 Condition = nullptr; 7172 return; 7173 } 7174 7175 TrueValue = *TrueVal; 7176 FalseValue = *FalseVal; 7177 7178 // Re-apply the cast we peeled off earlier 7179 if (CastOp) 7180 switch (*CastOp) { 7181 default: 7182 llvm_unreachable("Unknown SCEV cast type!"); 7183 7184 case scTruncate: 7185 TrueValue = TrueValue.trunc(BitWidth); 7186 FalseValue = FalseValue.trunc(BitWidth); 7187 break; 7188 case scZeroExtend: 7189 TrueValue = TrueValue.zext(BitWidth); 7190 FalseValue = FalseValue.zext(BitWidth); 7191 break; 7192 case scSignExtend: 7193 TrueValue = TrueValue.sext(BitWidth); 7194 FalseValue = FalseValue.sext(BitWidth); 7195 break; 7196 } 7197 7198 // Re-apply the constant offset we peeled off earlier 7199 TrueValue += Offset; 7200 FalseValue += Offset; 7201 } 7202 7203 bool isRecognized() { return Condition != nullptr; } 7204 }; 7205 7206 SelectPattern StartPattern(*this, BitWidth, Start); 7207 if (!StartPattern.isRecognized()) 7208 return ConstantRange::getFull(BitWidth); 7209 7210 SelectPattern StepPattern(*this, BitWidth, Step); 7211 if (!StepPattern.isRecognized()) 7212 return ConstantRange::getFull(BitWidth); 7213 7214 if (StartPattern.Condition != StepPattern.Condition) { 7215 // We don't handle this case today; but we could, by considering four 7216 // possibilities below instead of two. I'm not sure if there are cases where 7217 // that will help over what getRange already does, though. 7218 return ConstantRange::getFull(BitWidth); 7219 } 7220 7221 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to 7222 // construct arbitrary general SCEV expressions here. This function is called 7223 // from deep in the call stack, and calling getSCEV (on a sext instruction, 7224 // say) can end up caching a suboptimal value. 7225 7226 // FIXME: without the explicit `this` receiver below, MSVC errors out with 7227 // C2352 and C2512 (otherwise it isn't needed). 7228 7229 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); 7230 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); 7231 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); 7232 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); 7233 7234 ConstantRange TrueRange = 7235 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount); 7236 ConstantRange FalseRange = 7237 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount); 7238 7239 return TrueRange.unionWith(FalseRange); 7240 } 7241 7242 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { 7243 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap; 7244 const BinaryOperator *BinOp = cast<BinaryOperator>(V); 7245 7246 // Return early if there are no flags to propagate to the SCEV. 
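// E.g. for "%r = add nuw nsw i32 %x, %y" this collects NUW and NSW, and
// isSCEVExprNeverPoison below decides whether it is sound to keep them on
// the shared SCEV (a sketch of the flow).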
7247 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; 7248 if (BinOp->hasNoUnsignedWrap()) 7249 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 7250 if (BinOp->hasNoSignedWrap()) 7251 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); 7252 if (Flags == SCEV::FlagAnyWrap) 7253 return SCEV::FlagAnyWrap; 7254 7255 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; 7256 } 7257 7258 const Instruction * 7259 ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) { 7260 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 7261 return &*AddRec->getLoop()->getHeader()->begin(); 7262 if (auto *U = dyn_cast<SCEVUnknown>(S)) 7263 if (auto *I = dyn_cast<Instruction>(U->getValue())) 7264 return I; 7265 return nullptr; 7266 } 7267 7268 const Instruction * 7269 ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops, 7270 bool &Precise) { 7271 Precise = true; 7272 // Do a bounded search of the def relation of the requested SCEVs. 7273 SmallSet<const SCEV *, 16> Visited; 7274 SmallVector<const SCEV *> Worklist; 7275 auto pushOp = [&](const SCEV *S) { 7276 if (!Visited.insert(S).second) 7277 return; 7278 // Threshold of 30 here is arbitrary. 7279 if (Visited.size() > 30) { 7280 Precise = false; 7281 return; 7282 } 7283 Worklist.push_back(S); 7284 }; 7285 7286 for (const auto *S : Ops) 7287 pushOp(S); 7288 7289 const Instruction *Bound = nullptr; 7290 while (!Worklist.empty()) { 7291 auto *S = Worklist.pop_back_val(); 7292 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) { 7293 if (!Bound || DT.dominates(Bound, DefI)) 7294 Bound = DefI; 7295 } else { 7296 for (const auto *Op : S->operands()) 7297 pushOp(Op); 7298 } 7299 } 7300 return Bound ? Bound : &*F.getEntryBlock().begin(); 7301 } 7302 7303 const Instruction * 7304 ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) { 7305 bool Discard; 7306 return getDefiningScopeBound(Ops, Discard); 7307 } 7308 7309 bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A, 7310 const Instruction *B) { 7311 if (A->getParent() == B->getParent() && 7312 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(), 7313 B->getIterator())) 7314 return true; 7315 7316 auto *BLoop = LI.getLoopFor(B->getParent()); 7317 if (BLoop && BLoop->getHeader() == B->getParent() && 7318 BLoop->getLoopPreheader() == A->getParent() && 7319 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(), 7320 A->getParent()->end()) && 7321 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(), 7322 B->getIterator())) 7323 return true; 7324 return false; 7325 } 7326 7327 bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) { 7328 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true); 7329 visitAll(Op, PC); 7330 return PC.MaybePoison.empty(); 7331 } 7332 7333 bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) { 7334 return !SCEVExprContains(Op, [this](const SCEV *S) { 7335 const SCEV *Op1; 7336 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1))); 7337 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor 7338 // is a non-zero constant, we have to assume the UDiv may be UB. 7339 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1)); 7340 }); 7341 } 7342 7343 bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { 7344 // Only proceed if we can prove that I does not yield poison. 
7345 if (!programUndefinedIfPoison(I))
7346 return false;
7347
7348 // At this point we know that if I is executed, then it does not wrap
7349 // according to at least one of NSW or NUW. If I is not executed, then we do
7350 // not know if the calculation that I represents would wrap. Multiple
7351 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7352 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7353 // derived from other instructions that map to the same SCEV. We cannot make
7354 // that guarantee for cases where I is not executed. So we need to find an
7355 // upper bound on the defining scope for the SCEV, and prove that I is
7356 // executed every time we enter that scope. When the bounding scope is a
7357 // loop (the common case), this is equivalent to proving I executes on every
7358 // iteration of that loop.
7359 SmallVector<const SCEV *> SCEVOps;
7360 for (const Use &Op : I->operands()) {
7361 // I could be an extractvalue from a call to an overflow intrinsic.
7362 // TODO: We can do better here in some cases.
7363 if (isSCEVable(Op->getType()))
7364 SCEVOps.push_back(getSCEV(Op));
7365 }
7366 auto *DefI = getDefiningScopeBound(SCEVOps);
7367 return isGuaranteedToTransferExecutionTo(DefI, I);
7368 }
7369
7370 bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7371 // If we know that \c I can never be poison, period, then that's enough.
7372 if (isSCEVExprNeverPoison(I))
7373 return true;
7374
7375 // If the loop only has one exit, then we know that, if the loop is entered,
7376 // any instruction dominating that exit will be executed. If any such
7377 // instruction would result in UB, the addrec cannot be poison.
7378 //
7379 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7380 // also handles uses outside the loop header (they just need to dominate the
7381 // single exit).
7382
7383 auto *ExitingBB = L->getExitingBlock();
7384 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7385 return false;
7386
7387 SmallPtrSet<const Value *, 16> KnownPoison;
7388 SmallVector<const Instruction *, 8> Worklist;
7389
7390 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7391 // things that are known to be poison under that assumption go on the
7392 // Worklist.
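// E.g. if the post-inc value feeds a GEP whose result is dereferenced by a
// load dominating the single exit, poison would be guaranteed to trigger
// UB, so the addrec cannot be poison (one assumed scenario).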
7393 KnownPoison.insert(I); 7394 Worklist.push_back(I); 7395 7396 while (!Worklist.empty()) { 7397 const Instruction *Poison = Worklist.pop_back_val(); 7398 7399 for (const Use &U : Poison->uses()) { 7400 const Instruction *PoisonUser = cast<Instruction>(U.getUser()); 7401 if (mustTriggerUB(PoisonUser, KnownPoison) && 7402 DT.dominates(PoisonUser->getParent(), ExitingBB)) 7403 return true; 7404 7405 if (propagatesPoison(U) && L->contains(PoisonUser)) 7406 if (KnownPoison.insert(PoisonUser).second) 7407 Worklist.push_back(PoisonUser); 7408 } 7409 } 7410 7411 return false; 7412 } 7413 7414 ScalarEvolution::LoopProperties 7415 ScalarEvolution::getLoopProperties(const Loop *L) { 7416 using LoopProperties = ScalarEvolution::LoopProperties; 7417 7418 auto Itr = LoopPropertiesCache.find(L); 7419 if (Itr == LoopPropertiesCache.end()) { 7420 auto HasSideEffects = [](Instruction *I) { 7421 if (auto *SI = dyn_cast<StoreInst>(I)) 7422 return !SI->isSimple(); 7423 7424 return I->mayThrow() || I->mayWriteToMemory(); 7425 }; 7426 7427 LoopProperties LP = {/* HasNoAbnormalExits */ true, 7428 /*HasNoSideEffects*/ true}; 7429 7430 for (auto *BB : L->getBlocks()) 7431 for (auto &I : *BB) { 7432 if (!isGuaranteedToTransferExecutionToSuccessor(&I)) 7433 LP.HasNoAbnormalExits = false; 7434 if (HasSideEffects(&I)) 7435 LP.HasNoSideEffects = false; 7436 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects) 7437 break; // We're already as pessimistic as we can get. 7438 } 7439 7440 auto InsertPair = LoopPropertiesCache.insert({L, LP}); 7441 assert(InsertPair.second && "We just checked!"); 7442 Itr = InsertPair.first; 7443 } 7444 7445 return Itr->second; 7446 } 7447 7448 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { 7449 // A mustprogress loop without side effects must be finite. 7450 // TODO: The check used here is very conservative. It's only *specific* 7451 // side effects which are well defined in infinite loops. 7452 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); 7453 } 7454 7455 const SCEV *ScalarEvolution::createSCEVIter(Value *V) { 7456 // Worklist item with a Value and a bool indicating whether all operands have 7457 // been visited already. 7458 using PointerTy = PointerIntPair<Value *, 1, bool>; 7459 SmallVector<PointerTy> Stack; 7460 7461 Stack.emplace_back(V, true); 7462 Stack.emplace_back(V, false); 7463 while (!Stack.empty()) { 7464 auto E = Stack.pop_back_val(); 7465 Value *CurV = E.getPointer(); 7466 7467 if (getExistingSCEV(CurV)) 7468 continue; 7469 7470 SmallVector<Value *> Ops; 7471 const SCEV *CreatedSCEV = nullptr; 7472 // If all operands have been visited already, create the SCEV. 7473 if (E.getInt()) { 7474 CreatedSCEV = createSCEV(CurV); 7475 } else { 7476 // Otherwise get the operands we need to create SCEVs for before creating 7477 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially, 7478 // just use it. 7479 CreatedSCEV = getOperandsToCreate(CurV, Ops); 7480 } 7481 7482 if (CreatedSCEV) { 7483 insertValueToMap(CurV, CreatedSCEV); 7484 } else { 7485 // Queue CurV for SCEV creation, followed by its operands which need to 7486 // be constructed first.
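      // e.g. for %c = add i32 %a, %b with nothing cached yet: popping
      // (%c, false) records Ops = {%b, %a} and re-queues (%c, true) behind
      // them, so the SCEVs for %a and %b exist by the time createSCEV(%c)
      // runs. This keeps SCEV construction iterative rather than deeply
      // recursive.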
7487 Stack.emplace_back(CurV, true); 7488 for (Value *Op : Ops) 7489 Stack.emplace_back(Op, false); 7490 } 7491 } 7492 7493 return getExistingSCEV(V); 7494 } 7495 7496 const SCEV * 7497 ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) { 7498 if (!isSCEVable(V->getType())) 7499 return getUnknown(V); 7500 7501 if (Instruction *I = dyn_cast<Instruction>(V)) { 7502 // Don't attempt to analyze instructions in blocks that aren't 7503 // reachable. Such instructions don't matter, and they aren't required 7504 // to obey basic rules for definitions dominating uses which this 7505 // analysis depends on. 7506 if (!DT.isReachableFromEntry(I->getParent())) 7507 return getUnknown(PoisonValue::get(V->getType())); 7508 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) 7509 return getConstant(CI); 7510 else if (isa<GlobalAlias>(V)) 7511 return getUnknown(V); 7512 else if (!isa<ConstantExpr>(V)) 7513 return getUnknown(V); 7514 7515 Operator *U = cast<Operator>(V); 7516 if (auto BO = 7517 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) { 7518 bool IsConstArg = isa<ConstantInt>(BO->RHS); 7519 switch (BO->Opcode) { 7520 case Instruction::Add: 7521 case Instruction::Mul: { 7522 // For additions and multiplications, traverse add/mul chains for which we 7523 // can potentially create a single SCEV, to reduce the number of 7524 // get{Add,Mul}Expr calls. 7525 do { 7526 if (BO->Op) { 7527 if (BO->Op != V && getExistingSCEV(BO->Op)) { 7528 Ops.push_back(BO->Op); 7529 break; 7530 } 7531 } 7532 Ops.push_back(BO->RHS); 7533 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT, 7534 dyn_cast<Instruction>(V)); 7535 if (!NewBO || 7536 (BO->Opcode == Instruction::Add && 7537 (NewBO->Opcode != Instruction::Add && 7538 NewBO->Opcode != Instruction::Sub)) || 7539 (BO->Opcode == Instruction::Mul && 7540 NewBO->Opcode != Instruction::Mul)) { 7541 Ops.push_back(BO->LHS); 7542 break; 7543 } 7544 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions 7545 // requires a SCEV for the LHS. 
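        // That is: if wrap flags might transfer to the SCEV, stop growing the
        // chain here so that BO->LHS is queued as an ordinary operand and its
        // SCEV already exists when getNoWrapFlagsFromUB is consulted.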
7546 if (BO->Op && (BO->IsNSW || BO->IsNUW)) { 7547 auto *I = dyn_cast<Instruction>(BO->Op); 7548 if (I && programUndefinedIfPoison(I)) { 7549 Ops.push_back(BO->LHS); 7550 break; 7551 } 7552 } 7553 BO = NewBO; 7554 } while (true); 7555 return nullptr; 7556 } 7557 case Instruction::Sub: 7558 case Instruction::UDiv: 7559 case Instruction::URem: 7560 break; 7561 case Instruction::AShr: 7562 case Instruction::Shl: 7563 case Instruction::Xor: 7564 if (!IsConstArg) 7565 return nullptr; 7566 break; 7567 case Instruction::And: 7568 case Instruction::Or: 7569 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1)) 7570 return nullptr; 7571 break; 7572 case Instruction::LShr: 7573 return getUnknown(V); 7574 default: 7575 llvm_unreachable("Unhandled binop"); 7576 break; 7577 } 7578 7579 Ops.push_back(BO->LHS); 7580 Ops.push_back(BO->RHS); 7581 return nullptr; 7582 } 7583 7584 switch (U->getOpcode()) { 7585 case Instruction::Trunc: 7586 case Instruction::ZExt: 7587 case Instruction::SExt: 7588 case Instruction::PtrToInt: 7589 Ops.push_back(U->getOperand(0)); 7590 return nullptr; 7591 7592 case Instruction::BitCast: 7593 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) { 7594 Ops.push_back(U->getOperand(0)); 7595 return nullptr; 7596 } 7597 return getUnknown(V); 7598 7599 case Instruction::SDiv: 7600 case Instruction::SRem: 7601 Ops.push_back(U->getOperand(0)); 7602 Ops.push_back(U->getOperand(1)); 7603 return nullptr; 7604 7605 case Instruction::GetElementPtr: 7606 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() && 7607 "GEP source element type must be sized"); 7608 llvm::append_range(Ops, U->operands()); 7609 return nullptr; 7610 7611 case Instruction::IntToPtr: 7612 return getUnknown(V); 7613 7614 case Instruction::PHI: 7615 // Keep constructing SCEVs for phis recursively for now. 7616 return nullptr; 7617 7618 case Instruction::Select: { 7619 // Check if U is a select that can be simplified to a SCEVUnknown.
7620 auto CanSimplifyToUnknown = [this, U]() { 7621 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0))) 7622 return false; 7623 7624 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0)); 7625 if (!ICI) 7626 return false; 7627 Value *LHS = ICI->getOperand(0); 7628 Value *RHS = ICI->getOperand(1); 7629 if (ICI->getPredicate() == CmpInst::ICMP_EQ || 7630 ICI->getPredicate() == CmpInst::ICMP_NE) { 7631 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero())) 7632 return true; 7633 } else if (getTypeSizeInBits(LHS->getType()) > 7634 getTypeSizeInBits(U->getType())) 7635 return true; 7636 return false; 7637 }; 7638 if (CanSimplifyToUnknown()) 7639 return getUnknown(U); 7640 7641 llvm::append_range(Ops, U->operands()); 7642 return nullptr; 7643 break; 7644 } 7645 case Instruction::Call: 7646 case Instruction::Invoke: 7647 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) { 7648 Ops.push_back(RV); 7649 return nullptr; 7650 } 7651 7652 if (auto *II = dyn_cast<IntrinsicInst>(U)) { 7653 switch (II->getIntrinsicID()) { 7654 case Intrinsic::abs: 7655 Ops.push_back(II->getArgOperand(0)); 7656 return nullptr; 7657 case Intrinsic::umax: 7658 case Intrinsic::umin: 7659 case Intrinsic::smax: 7660 case Intrinsic::smin: 7661 case Intrinsic::usub_sat: 7662 case Intrinsic::uadd_sat: 7663 Ops.push_back(II->getArgOperand(0)); 7664 Ops.push_back(II->getArgOperand(1)); 7665 return nullptr; 7666 case Intrinsic::start_loop_iterations: 7667 case Intrinsic::annotation: 7668 case Intrinsic::ptr_annotation: 7669 Ops.push_back(II->getArgOperand(0)); 7670 return nullptr; 7671 default: 7672 break; 7673 } 7674 } 7675 break; 7676 } 7677 7678 return nullptr; 7679 } 7680 7681 const SCEV *ScalarEvolution::createSCEV(Value *V) { 7682 if (!isSCEVable(V->getType())) 7683 return getUnknown(V); 7684 7685 if (Instruction *I = dyn_cast<Instruction>(V)) { 7686 // Don't attempt to analyze instructions in blocks that aren't 7687 // reachable. Such instructions don't matter, and they aren't required 7688 // to obey basic rules for definitions dominating uses which this 7689 // analysis depends on. 7690 if (!DT.isReachableFromEntry(I->getParent())) 7691 return getUnknown(PoisonValue::get(V->getType())); 7692 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) 7693 return getConstant(CI); 7694 else if (isa<GlobalAlias>(V)) 7695 return getUnknown(V); 7696 else if (!isa<ConstantExpr>(V)) 7697 return getUnknown(V); 7698 7699 const SCEV *LHS; 7700 const SCEV *RHS; 7701 7702 Operator *U = cast<Operator>(V); 7703 if (auto BO = 7704 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) { 7705 switch (BO->Opcode) { 7706 case Instruction::Add: { 7707 // The simple thing to do would be to just call getSCEV on both operands 7708 // and call getAddExpr with the result. However if we're looking at a 7709 // bunch of things all added together, this can be quite inefficient, 7710 // because it leads to N-1 getAddExpr calls for N ultimate operands. 7711 // Instead, gather up all the operands and make a single getAddExpr call. 7712 // LLVM IR canonical form means we need only traverse the left operands. 7713 SmallVector<const SCEV *, 4> AddOps; 7714 do { 7715 if (BO->Op) { 7716 if (auto *OpSCEV = getExistingSCEV(BO->Op)) { 7717 AddOps.push_back(OpSCEV); 7718 break; 7719 } 7720 7721 // If a NUW or NSW flag can be applied to the SCEV for this 7722 // addition, then compute the SCEV for this addition by itself 7723 // with a separate call to getAddExpr. 
We need to do that 7724 // instead of pushing the operands of the addition onto AddOps, 7725 // since the flags are only known to apply to this particular 7726 // addition - they may not apply to other additions that can be 7727 // formed with operands from AddOps. 7728 const SCEV *RHS = getSCEV(BO->RHS); 7729 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); 7730 if (Flags != SCEV::FlagAnyWrap) { 7731 const SCEV *LHS = getSCEV(BO->LHS); 7732 if (BO->Opcode == Instruction::Sub) 7733 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); 7734 else 7735 AddOps.push_back(getAddExpr(LHS, RHS, Flags)); 7736 break; 7737 } 7738 } 7739 7740 if (BO->Opcode == Instruction::Sub) 7741 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); 7742 else 7743 AddOps.push_back(getSCEV(BO->RHS)); 7744 7745 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT, 7746 dyn_cast<Instruction>(V)); 7747 if (!NewBO || (NewBO->Opcode != Instruction::Add && 7748 NewBO->Opcode != Instruction::Sub)) { 7749 AddOps.push_back(getSCEV(BO->LHS)); 7750 break; 7751 } 7752 BO = NewBO; 7753 } while (true); 7754 7755 return getAddExpr(AddOps); 7756 } 7757 7758 case Instruction::Mul: { 7759 SmallVector<const SCEV *, 4> MulOps; 7760 do { 7761 if (BO->Op) { 7762 if (auto *OpSCEV = getExistingSCEV(BO->Op)) { 7763 MulOps.push_back(OpSCEV); 7764 break; 7765 } 7766 7767 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); 7768 if (Flags != SCEV::FlagAnyWrap) { 7769 LHS = getSCEV(BO->LHS); 7770 RHS = getSCEV(BO->RHS); 7771 MulOps.push_back(getMulExpr(LHS, RHS, Flags)); 7772 break; 7773 } 7774 } 7775 7776 MulOps.push_back(getSCEV(BO->RHS)); 7777 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT, 7778 dyn_cast<Instruction>(V)); 7779 if (!NewBO || NewBO->Opcode != Instruction::Mul) { 7780 MulOps.push_back(getSCEV(BO->LHS)); 7781 break; 7782 } 7783 BO = NewBO; 7784 } while (true); 7785 7786 return getMulExpr(MulOps); 7787 } 7788 case Instruction::UDiv: 7789 LHS = getSCEV(BO->LHS); 7790 RHS = getSCEV(BO->RHS); 7791 return getUDivExpr(LHS, RHS); 7792 case Instruction::URem: 7793 LHS = getSCEV(BO->LHS); 7794 RHS = getSCEV(BO->RHS); 7795 return getURemExpr(LHS, RHS); 7796 case Instruction::Sub: { 7797 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; 7798 if (BO->Op) 7799 Flags = getNoWrapFlagsFromUB(BO->Op); 7800 LHS = getSCEV(BO->LHS); 7801 RHS = getSCEV(BO->RHS); 7802 return getMinusSCEV(LHS, RHS, Flags); 7803 } 7804 case Instruction::And: 7805 // For an expression like x&255 that merely masks off the high bits, 7806 // use zext(trunc(x)) as the SCEV expression. 7807 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { 7808 if (CI->isZero()) 7809 return getSCEV(BO->RHS); 7810 if (CI->isMinusOne()) 7811 return getSCEV(BO->LHS); 7812 const APInt &A = CI->getValue(); 7813 7814 // Instcombine's ShrinkDemandedConstant may strip bits out of 7815 // constants, obscuring what would otherwise be a low-bits mask. 7816 // Use computeKnownBits to compute what ShrinkDemandedConstant 7817 // knew about to reconstruct a low-bits mask value. 
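      // Worked example (illustrative): for i32 `%x & 248` (0xF8), LZ = 24 and
      // TZ = 3, so EffectiveMask covers bits 3..7 and the expression below is
      // rebuilt as (zext (trunc (%x /u 8) to i5) to i32) * 8, i.e. MulCount
      // is 1 << TZ and only the five masked bits survive.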
7818 unsigned LZ = A.countl_zero(); 7819 unsigned TZ = A.countr_zero(); 7820 unsigned BitWidth = A.getBitWidth(); 7821 KnownBits Known(BitWidth); 7822 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT); 7823 7824 APInt EffectiveMask = 7825 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); 7826 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { 7827 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); 7828 const SCEV *LHS = getSCEV(BO->LHS); 7829 const SCEV *ShiftedLHS = nullptr; 7830 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { 7831 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { 7832 // For an expression like (x * 8) & 8, simplify the multiply. 7833 unsigned MulZeros = OpC->getAPInt().countr_zero(); 7834 unsigned GCD = std::min(MulZeros, TZ); 7835 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); 7836 SmallVector<const SCEV*, 4> MulOps; 7837 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD))); 7838 append_range(MulOps, LHSMul->operands().drop_front()); 7839 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); 7840 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); 7841 } 7842 } 7843 if (!ShiftedLHS) 7844 ShiftedLHS = getUDivExpr(LHS, MulCount); 7845 return getMulExpr( 7846 getZeroExtendExpr( 7847 getTruncateExpr(ShiftedLHS, 7848 IntegerType::get(getContext(), BitWidth - LZ - TZ)), 7849 BO->LHS->getType()), 7850 MulCount); 7851 } 7852 } 7853 // Binary `and` is a bit-wise `umin`. 7854 if (BO->LHS->getType()->isIntegerTy(1)) { 7855 LHS = getSCEV(BO->LHS); 7856 RHS = getSCEV(BO->RHS); 7857 return getUMinExpr(LHS, RHS); 7858 } 7859 break; 7860 7861 case Instruction::Or: 7862 // Binary `or` is a bit-wise `umax`. 7863 if (BO->LHS->getType()->isIntegerTy(1)) { 7864 LHS = getSCEV(BO->LHS); 7865 RHS = getSCEV(BO->RHS); 7866 return getUMaxExpr(LHS, RHS); 7867 } 7868 break; 7869 7870 case Instruction::Xor: 7871 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { 7872 // If the RHS of xor is -1, then this is a not operation. 7873 if (CI->isMinusOne()) 7874 return getNotSCEV(getSCEV(BO->LHS)); 7875 7876 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. 7877 // This is a variant of the check for xor with -1, and it handles 7878 // the case where instcombine has trimmed non-demanded bits out 7879 // of an xor with -1. 7880 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) 7881 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) 7882 if (LBO->getOpcode() == Instruction::And && 7883 LCI->getValue() == CI->getValue()) 7884 if (const SCEVZeroExtendExpr *Z = 7885 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { 7886 Type *UTy = BO->LHS->getType(); 7887 const SCEV *Z0 = Z->getOperand(); 7888 Type *Z0Ty = Z0->getType(); 7889 unsigned Z0TySize = getTypeSizeInBits(Z0Ty); 7890 7891 // If C is a low-bits mask, the zero extend is serving to 7892 // mask off the high bits. Complement the operand and 7893 // re-apply the zext. 7894 if (CI->getValue().isMask(Z0TySize)) 7895 return getZeroExtendExpr(getNotSCEV(Z0), UTy); 7896 7897 // If C is a single bit, it may be in the sign-bit position 7898 // before the zero-extend. In this case, represent the xor 7899 // using an add, which is equivalent, and re-apply the zext. 
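          // Worked instance (illustrative): with %z = zext i8 %a to i32 and
          // C = 128, Trunc is the i8 sign mask 0x80, and in i8 arithmetic
          // x ^ 0x80 == x + 0x80 (toggling the top bit cannot carry), so the
          // xor becomes zext((%a + 0x80) to i32).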
7900 APInt Trunc = CI->getValue().trunc(Z0TySize); 7901 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && 7902 Trunc.isSignMask()) 7903 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), 7904 UTy); 7905 } 7906 } 7907 break; 7908 7909 case Instruction::Shl: 7910 // Turn shift left of a constant amount into a multiply. 7911 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) { 7912 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); 7913 7914 // If the shift count is not less than the bitwidth, the result of 7915 // the shift is undefined. Don't try to analyze it, because the 7916 // resolution chosen here may differ from the resolution chosen in 7917 // other parts of the compiler. 7918 if (SA->getValue().uge(BitWidth)) 7919 break; 7920 7921 // We can safely preserve the nuw flag in all cases. It's also safe to 7922 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation 7923 // requires special handling. It can be preserved as long as we're not 7924 // left shifting by bitwidth - 1. 7925 auto Flags = SCEV::FlagAnyWrap; 7926 if (BO->Op) { 7927 auto MulFlags = getNoWrapFlagsFromUB(BO->Op); 7928 if ((MulFlags & SCEV::FlagNSW) && 7929 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1))) 7930 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW); 7931 if (MulFlags & SCEV::FlagNUW) 7932 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW); 7933 } 7934 7935 ConstantInt *X = ConstantInt::get( 7936 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); 7937 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags); 7938 } 7939 break; 7940 7941 case Instruction::AShr: 7942 // AShr X, C, where C is a constant. 7943 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); 7944 if (!CI) 7945 break; 7946 7947 Type *OuterTy = BO->LHS->getType(); 7948 uint64_t BitWidth = getTypeSizeInBits(OuterTy); 7949 // If the shift count is not less than the bitwidth, the result of 7950 // the shift is undefined. Don't try to analyze it, because the 7951 // resolution chosen here may differ from the resolution chosen in 7952 // other parts of the compiler. 7953 if (CI->getValue().uge(BitWidth)) 7954 break; 7955 7956 if (CI->isZero()) 7957 return getSCEV(BO->LHS); // shift by zero --> noop 7958 7959 uint64_t AShrAmt = CI->getZExtValue(); 7960 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); 7961 7962 Operator *L = dyn_cast<Operator>(BO->LHS); 7963 const SCEV *AddTruncateExpr = nullptr; 7964 ConstantInt *ShlAmtCI = nullptr; 7965 const SCEV *AddConstant = nullptr; 7966 7967 if (L && L->getOpcode() == Instruction::Add) { 7968 // X = Shl A, n 7969 // Y = Add X, c 7970 // Z = AShr Y, m 7971 // n, c and m are constants. 7972 7973 Operator *LShift = dyn_cast<Operator>(L->getOperand(0)); 7974 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1)); 7975 if (LShift && LShift->getOpcode() == Instruction::Shl) { 7976 if (AddOperandCI) { 7977 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0)); 7978 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1)); 7979 // Since we truncate to TruncTy, the AddConstant should be of the 7980 // same type, so create a new Constant with the same type as TruncTy. 7981 // Also, the Add constant should be shifted right by AShr amount.
7982 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt); 7983 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt)); 7984 // We model the expression as sext(add(trunc(A), c << n)); since the 7985 // sext(trunc) part is already handled below, we create an 7986 // AddExpr(TruncExp) which will be used later. 7987 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy); 7988 } 7989 } 7990 } else if (L && L->getOpcode() == Instruction::Shl) { 7991 // X = Shl A, n 7992 // Y = AShr X, m 7993 // Both n and m are constant. 7994 7995 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); 7996 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); 7997 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy); 7998 } 7999 8000 if (AddTruncateExpr && ShlAmtCI) { 8001 // We can merge the two given cases into a single SCEV statement; 8002 // in case n = m, the mul expression will be 2^0, so it gets resolved to 8003 // a simpler case. The following code handles the two cases: 8004 // 8005 // 1) For a two-shift sext-inreg, i.e. n = m, 8006 // use sext(trunc(x)) as the SCEV expression. 8007 // 8008 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV 8009 // expression. We already checked that ShlAmt < BitWidth, so 8010 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as 8011 // ShlAmt - AShrAmt < Amt. 8012 const APInt &ShlAmt = ShlAmtCI->getValue(); 8013 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) { 8014 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, 8015 ShlAmtCI->getZExtValue() - AShrAmt); 8016 const SCEV *CompositeExpr = 8017 getMulExpr(AddTruncateExpr, getConstant(Mul)); 8018 if (L->getOpcode() != Instruction::Shl) 8019 CompositeExpr = getAddExpr(CompositeExpr, AddConstant); 8020 8021 return getSignExtendExpr(CompositeExpr, OuterTy); 8022 } 8023 } 8024 break; 8025 } 8026 } 8027 8028 switch (U->getOpcode()) { 8029 case Instruction::Trunc: 8030 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); 8031 8032 case Instruction::ZExt: 8033 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); 8034 8035 case Instruction::SExt: 8036 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT, 8037 dyn_cast<Instruction>(V))) { 8038 // The NSW flag of a subtract does not always survive the conversion to 8039 // A + (-1)*B. By pushing sign extension onto its operands we are much 8040 // more likely to preserve NSW and allow later AddRec optimisations. 8041 // 8042 // NOTE: This is effectively duplicating this logic from getSignExtend: 8043 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> 8044 // but by that point the NSW information has potentially been lost. 8045 if (BO->Opcode == Instruction::Sub && BO->IsNSW) { 8046 Type *Ty = U->getType(); 8047 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); 8048 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); 8049 return getMinusSCEV(V1, V2, SCEV::FlagNSW); 8050 } 8051 } 8052 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); 8053 8054 case Instruction::BitCast: 8055 // BitCasts are no-op casts so we just eliminate the cast. 8056 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) 8057 return getSCEV(U->getOperand(0)); 8058 break; 8059 8060 case Instruction::PtrToInt: { 8061 // Pointer to integer cast is straightforward, so do model it. 8062 const SCEV *Op = getSCEV(U->getOperand(0)); 8063 Type *DstIntTy = U->getType(); 8064 // But only if effective SCEV (integer) type is wide enough to represent 8065 // all possible pointer values.
8066 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy); 8067 if (isa<SCEVCouldNotCompute>(IntOp)) 8068 return getUnknown(V); 8069 return IntOp; 8070 } 8071 case Instruction::IntToPtr: 8072 // Just don't deal with inttoptr casts. 8073 return getUnknown(V); 8074 8075 case Instruction::SDiv: 8076 // If both operands are non-negative, this is just a udiv. 8077 if (isKnownNonNegative(getSCEV(U->getOperand(0))) && 8078 isKnownNonNegative(getSCEV(U->getOperand(1)))) 8079 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); 8080 break; 8081 8082 case Instruction::SRem: 8083 // If both operands are non-negative, this is just a urem. 8084 if (isKnownNonNegative(getSCEV(U->getOperand(0))) && 8085 isKnownNonNegative(getSCEV(U->getOperand(1)))) 8086 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); 8087 break; 8088 8089 case Instruction::GetElementPtr: 8090 return createNodeForGEP(cast<GEPOperator>(U)); 8091 8092 case Instruction::PHI: 8093 return createNodeForPHI(cast<PHINode>(U)); 8094 8095 case Instruction::Select: 8096 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1), 8097 U->getOperand(2)); 8098 8099 case Instruction::Call: 8100 case Instruction::Invoke: 8101 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) 8102 return getSCEV(RV); 8103 8104 if (auto *II = dyn_cast<IntrinsicInst>(U)) { 8105 switch (II->getIntrinsicID()) { 8106 case Intrinsic::abs: 8107 return getAbsExpr( 8108 getSCEV(II->getArgOperand(0)), 8109 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne()); 8110 case Intrinsic::umax: 8111 LHS = getSCEV(II->getArgOperand(0)); 8112 RHS = getSCEV(II->getArgOperand(1)); 8113 return getUMaxExpr(LHS, RHS); 8114 case Intrinsic::umin: 8115 LHS = getSCEV(II->getArgOperand(0)); 8116 RHS = getSCEV(II->getArgOperand(1)); 8117 return getUMinExpr(LHS, RHS); 8118 case Intrinsic::smax: 8119 LHS = getSCEV(II->getArgOperand(0)); 8120 RHS = getSCEV(II->getArgOperand(1)); 8121 return getSMaxExpr(LHS, RHS); 8122 case Intrinsic::smin: 8123 LHS = getSCEV(II->getArgOperand(0)); 8124 RHS = getSCEV(II->getArgOperand(1)); 8125 return getSMinExpr(LHS, RHS); 8126 case Intrinsic::usub_sat: { 8127 const SCEV *X = getSCEV(II->getArgOperand(0)); 8128 const SCEV *Y = getSCEV(II->getArgOperand(1)); 8129 const SCEV *ClampedY = getUMinExpr(X, Y); 8130 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW); 8131 } 8132 case Intrinsic::uadd_sat: { 8133 const SCEV *X = getSCEV(II->getArgOperand(0)); 8134 const SCEV *Y = getSCEV(II->getArgOperand(1)); 8135 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y)); 8136 return getAddExpr(ClampedX, Y, SCEV::FlagNUW); 8137 } 8138 case Intrinsic::start_loop_iterations: 8139 case Intrinsic::annotation: 8140 case Intrinsic::ptr_annotation: 8141 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is 8142 // just equivalent to the first operand for SCEV purposes.
8143 return getSCEV(II->getArgOperand(0)); 8144 case Intrinsic::vscale: 8145 return getVScale(II->getType()); 8146 default: 8147 break; 8148 } 8149 } 8150 break; 8151 } 8152 8153 return getUnknown(V); 8154 } 8155 8156 //===----------------------------------------------------------------------===// 8157 // Iteration Count Computation Code 8158 // 8159 8160 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { 8161 if (isa<SCEVCouldNotCompute>(ExitCount)) 8162 return getCouldNotCompute(); 8163 8164 auto *ExitCountType = ExitCount->getType(); 8165 assert(ExitCountType->isIntegerTy()); 8166 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(), 8167 1 + ExitCountType->getScalarSizeInBits()); 8168 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr); 8169 } 8170 8171 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, 8172 Type *EvalTy, 8173 const Loop *L) { 8174 if (isa<SCEVCouldNotCompute>(ExitCount)) 8175 return getCouldNotCompute(); 8176 8177 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType()); 8178 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits(); 8179 8180 auto CanAddOneWithoutOverflow = [&]() { 8181 ConstantRange ExitCountRange = 8182 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED); 8183 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize))) 8184 return true; 8185 8186 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount, 8187 getMinusOne(ExitCount->getType())); 8188 }; 8189 8190 // If we need to zero extend the backedge count, check if we can add one to 8191 // it prior to zero extending without overflow. Provided this is safe, it 8192 // allows better simplification of the +1. 8193 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow()) 8194 return getZeroExtendExpr( 8195 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy); 8196 8197 // Get the total trip count from the count by adding 1. This may wrap. 8198 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy)); 8199 } 8200 8201 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { 8202 if (!ExitCount) 8203 return 0; 8204 8205 ConstantInt *ExitConst = ExitCount->getValue(); 8206 8207 // Guard against huge trip counts. 8208 if (ExitConst->getValue().getActiveBits() > 32) 8209 return 0; 8210 8211 // In case of integer overflow, this returns 0, which is correct. 8212 return ((unsigned)ExitConst->getZExtValue()) + 1; 8213 } 8214 8215 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) { 8216 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact)); 8217 return getConstantTripCount(ExitCount); 8218 } 8219 8220 unsigned 8221 ScalarEvolution::getSmallConstantTripCount(const Loop *L, 8222 const BasicBlock *ExitingBlock) { 8223 assert(ExitingBlock && "Must pass a non-null exiting block!"); 8224 assert(L->isLoopExiting(ExitingBlock) && 8225 "Exiting block must actually branch out of the loop!"); 8226 const SCEVConstant *ExitCount = 8227 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock)); 8228 return getConstantTripCount(ExitCount); 8229 } 8230 8231 unsigned ScalarEvolution::getSmallConstantMaxTripCount( 8232 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) { 8233 8234 const auto *MaxExitCount = 8235 Predicates ? 
getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates) 8236 : getConstantMaxBackedgeTakenCount(L); 8237 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount)); 8238 } 8239 8240 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { 8241 SmallVector<BasicBlock *, 8> ExitingBlocks; 8242 L->getExitingBlocks(ExitingBlocks); 8243 8244 std::optional<unsigned> Res; 8245 for (auto *ExitingBB : ExitingBlocks) { 8246 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB); 8247 if (!Res) 8248 Res = Multiple; 8249 Res = std::gcd(*Res, Multiple); 8250 } 8251 return Res.value_or(1); 8252 } 8253 8254 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, 8255 const SCEV *ExitCount) { 8256 if (isa<SCEVCouldNotCompute>(ExitCount)) 8257 return 1; 8258 8259 // Get the trip count. 8260 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); 8261 8262 APInt Multiple = getNonZeroConstantMultiple(TCExpr); 8263 // If a trip multiple is huge (>=2^32), the trip count is still divisible by 8264 // the greatest power of 2 divisor less than 2^32. 8265 return Multiple.getActiveBits() > 32 8266 ? 1U << std::min(31U, Multiple.countTrailingZeros()) 8267 : (unsigned)Multiple.getZExtValue(); 8268 } 8269 8270 /// Returns the largest constant divisor of the trip count of this loop as a 8271 /// normal unsigned value, if possible. This means that the actual trip count is 8272 /// always a multiple of the returned value (don't forget the trip count could 8273 /// very well be zero as well!). 8274 /// 8275 /// Returns 1 if the trip count is unknown or not guaranteed to be a 8276 /// multiple of a constant (which is also the case if the trip count is simply 8277 /// constant; use getSmallConstantTripCount for that case). It will also return 8278 /// 1 if the trip count is very large (>= 2^32). 8279 /// 8280 /// As explained in the comments for getSmallConstantTripCount, this assumes 8281 /// that control exits the loop via ExitingBlock.
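/// For instance (illustrative): with an exit count of 4 * %n + 3, the trip
/// count expression is 4 * %n + 4 = 4 * (%n + 1), so the returned multiple
/// is 4.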
8282 unsigned 8283 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, 8284 const BasicBlock *ExitingBlock) { 8285 assert(ExitingBlock && "Must pass a non-null exiting block!"); 8286 assert(L->isLoopExiting(ExitingBlock) && 8287 "Exiting block must actually branch out of the loop!"); 8288 const SCEV *ExitCount = getExitCount(L, ExitingBlock); 8289 return getSmallConstantTripMultiple(L, ExitCount); 8290 } 8291 8292 const SCEV *ScalarEvolution::getExitCount(const Loop *L, 8293 const BasicBlock *ExitingBlock, 8294 ExitCountKind Kind) { 8295 switch (Kind) { 8296 case Exact: 8297 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); 8298 case SymbolicMaximum: 8299 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this); 8300 case ConstantMaximum: 8301 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this); 8302 }; 8303 llvm_unreachable("Invalid ExitCountKind!"); 8304 } 8305 8306 const SCEV *ScalarEvolution::getPredicatedExitCount( 8307 const Loop *L, const BasicBlock *ExitingBlock, 8308 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) { 8309 switch (Kind) { 8310 case Exact: 8311 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this, 8312 Predicates); 8313 case SymbolicMaximum: 8314 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this, 8315 Predicates); 8316 case ConstantMaximum: 8317 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this, 8318 Predicates); 8319 }; 8320 llvm_unreachable("Invalid ExitCountKind!"); 8321 } 8322 8323 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount( 8324 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) { 8325 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds); 8326 } 8327 8328 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, 8329 ExitCountKind Kind) { 8330 switch (Kind) { 8331 case Exact: 8332 return getBackedgeTakenInfo(L).getExact(L, this); 8333 case ConstantMaximum: 8334 return getBackedgeTakenInfo(L).getConstantMax(this); 8335 case SymbolicMaximum: 8336 return getBackedgeTakenInfo(L).getSymbolicMax(L, this); 8337 }; 8338 llvm_unreachable("Invalid ExitCountKind!"); 8339 } 8340 8341 const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount( 8342 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) { 8343 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds); 8344 } 8345 8346 const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount( 8347 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) { 8348 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds); 8349 } 8350 8351 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) { 8352 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this); 8353 } 8354 8355 /// Push PHI nodes in the header of the given loop onto the given Worklist. 8356 static void PushLoopPHIs(const Loop *L, 8357 SmallVectorImpl<Instruction *> &Worklist, 8358 SmallPtrSetImpl<Instruction *> &Visited) { 8359 BasicBlock *Header = L->getHeader(); 8360 8361 // Push all Loop-header PHIs onto the Worklist stack. 
8362 for (PHINode &PN : Header->phis()) 8363 if (Visited.insert(&PN).second) 8364 Worklist.push_back(&PN); 8365 } 8366 8367 ScalarEvolution::BackedgeTakenInfo & 8368 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { 8369 auto &BTI = getBackedgeTakenInfo(L); 8370 if (BTI.hasFullInfo()) 8371 return BTI; 8372 8373 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L); 8374 8375 if (!Pair.second) 8376 return Pair.first->second; 8377 8378 BackedgeTakenInfo Result = 8379 computeBackedgeTakenCount(L, /*AllowPredicates=*/true); 8380 8381 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); 8382 } 8383 8384 ScalarEvolution::BackedgeTakenInfo & 8385 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { 8386 // Initially insert an invalid entry for this loop. If the insertion 8387 // succeeds, proceed to actually compute a backedge-taken count and 8388 // update the value. The temporary CouldNotCompute value tells SCEV 8389 // code elsewhere that it shouldn't attempt to request a new 8390 // backedge-taken count, which could result in infinite recursion. 8391 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair = 8392 BackedgeTakenCounts.try_emplace(L); 8393 if (!Pair.second) 8394 return Pair.first->second; 8395 8396 // computeBackedgeTakenCount may allocate memory for its result. Inserting it 8397 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result 8398 // must be cleared in this scope. 8399 BackedgeTakenInfo Result = computeBackedgeTakenCount(L); 8400 8401 // Now that we know more about the trip count for this loop, forget any 8402 // existing SCEV values for PHI nodes in this loop since they are only 8403 // conservative estimates made without the benefit of trip count 8404 // information. This invalidation is not necessary for correctness, and is 8405 // only done to produce more precise results. 8406 if (Result.hasAnyInfo()) { 8407 // Invalidate any expression using an addrec in this loop. 8408 SmallVector<const SCEV *, 8> ToForget; 8409 auto LoopUsersIt = LoopUsers.find(L); 8410 if (LoopUsersIt != LoopUsers.end()) 8411 append_range(ToForget, LoopUsersIt->second); 8412 forgetMemoizedResults(ToForget); 8413 8414 // Invalidate constant-evolved loop header phis. 8415 for (PHINode &PN : L->getHeader()->phis()) 8416 ConstantEvolutionLoopExitValue.erase(&PN); 8417 } 8418 8419 // Re-lookup the insert position, since the call to 8420 // computeBackedgeTakenCount above could result in a 8421 // recursive call to getBackedgeTakenInfo (on a different 8422 // loop), which would invalidate the iterator computed 8423 // earlier. 8424 return BackedgeTakenCounts.find(L)->second = std::move(Result); 8425 } 8426 8427 void ScalarEvolution::forgetAllLoops() { 8428 // This method is intended to forget all info about loops. It should 8429 // invalidate caches as if the following happened: 8430 // - The trip counts of all loops have changed arbitrarily 8431 // - Every llvm::Value has been updated in place to produce a different 8432 // result.
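  // A minimal usage sketch (hypothetical caller): a transform that rewrites
  // loop structure wholesale, with no finer-grained invalidation available,
  // would conservatively do:
  //   SE.forgetAllLoops(); // drop all cached trip counts and value SCEVs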
8433 BackedgeTakenCounts.clear(); 8434 PredicatedBackedgeTakenCounts.clear(); 8435 BECountUsers.clear(); 8436 LoopPropertiesCache.clear(); 8437 ConstantEvolutionLoopExitValue.clear(); 8438 ValueExprMap.clear(); 8439 ValuesAtScopes.clear(); 8440 ValuesAtScopesUsers.clear(); 8441 LoopDispositions.clear(); 8442 BlockDispositions.clear(); 8443 UnsignedRanges.clear(); 8444 SignedRanges.clear(); 8445 ExprValueMap.clear(); 8446 HasRecMap.clear(); 8447 ConstantMultipleCache.clear(); 8448 PredicatedSCEVRewrites.clear(); 8449 FoldCache.clear(); 8450 FoldCacheUser.clear(); 8451 } 8452 void ScalarEvolution::visitAndClearUsers( 8453 SmallVectorImpl<Instruction *> &Worklist, 8454 SmallPtrSetImpl<Instruction *> &Visited, 8455 SmallVectorImpl<const SCEV *> &ToForget) { 8456 while (!Worklist.empty()) { 8457 Instruction *I = Worklist.pop_back_val(); 8458 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I)) 8459 continue; 8460 8461 ValueExprMapType::iterator It = 8462 ValueExprMap.find_as(static_cast<Value *>(I)); 8463 if (It != ValueExprMap.end()) { 8464 eraseValueFromMap(It->first); 8465 ToForget.push_back(It->second); 8466 if (PHINode *PN = dyn_cast<PHINode>(I)) 8467 ConstantEvolutionLoopExitValue.erase(PN); 8468 } 8469 8470 PushDefUseChildren(I, Worklist, Visited); 8471 } 8472 } 8473 8474 void ScalarEvolution::forgetLoop(const Loop *L) { 8475 SmallVector<const Loop *, 16> LoopWorklist(1, L); 8476 SmallVector<Instruction *, 32> Worklist; 8477 SmallPtrSet<Instruction *, 16> Visited; 8478 SmallVector<const SCEV *, 16> ToForget; 8479 8480 // Iterate over all the loops and sub-loops to drop SCEV information. 8481 while (!LoopWorklist.empty()) { 8482 auto *CurrL = LoopWorklist.pop_back_val(); 8483 8484 // Drop any stored trip count value. 8485 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false); 8486 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true); 8487 8488 // Drop information about predicated SCEV rewrites for this loop. 8489 for (auto I = PredicatedSCEVRewrites.begin(); 8490 I != PredicatedSCEVRewrites.end();) { 8491 std::pair<const SCEV *, const Loop *> Entry = I->first; 8492 if (Entry.second == CurrL) 8493 PredicatedSCEVRewrites.erase(I++); 8494 else 8495 ++I; 8496 } 8497 8498 auto LoopUsersItr = LoopUsers.find(CurrL); 8499 if (LoopUsersItr != LoopUsers.end()) 8500 llvm::append_range(ToForget, LoopUsersItr->second); 8501 8502 // Drop information about expressions based on loop-header PHIs. 8503 PushLoopPHIs(CurrL, Worklist, Visited); 8504 visitAndClearUsers(Worklist, Visited, ToForget); 8505 8506 LoopPropertiesCache.erase(CurrL); 8507 // Forget all contained loops too, to avoid dangling entries in the 8508 // ValuesAtScopes map. 8509 LoopWorklist.append(CurrL->begin(), CurrL->end()); 8510 } 8511 forgetMemoizedResults(ToForget); 8512 } 8513 8514 void ScalarEvolution::forgetTopmostLoop(const Loop *L) { 8515 forgetLoop(L->getOutermostLoop()); 8516 } 8517 8518 void ScalarEvolution::forgetValue(Value *V) { 8519 Instruction *I = dyn_cast<Instruction>(V); 8520 if (!I) return; 8521 8522 // Drop information about expressions based on loop-header PHIs. 
8523 SmallVector<Instruction *, 16> Worklist; 8524 SmallPtrSet<Instruction *, 8> Visited; 8525 SmallVector<const SCEV *, 8> ToForget; 8526 Worklist.push_back(I); 8527 Visited.insert(I); 8528 visitAndClearUsers(Worklist, Visited, ToForget); 8529 8530 forgetMemoizedResults(ToForget); 8531 } 8532 8533 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) { 8534 if (!isSCEVable(V->getType())) 8535 return; 8536 8537 // If SCEV looked through a trivial LCSSA phi node, we might have SCEVs 8538 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an 8539 // extra predecessor is added, this is no longer valid. Find all Unknowns and 8540 // AddRecs defined in the loop and invalidate any SCEVs making use of them. 8541 if (const SCEV *S = getExistingSCEV(V)) { 8542 struct InvalidationRootCollector { 8543 Loop *L; 8544 SmallVector<const SCEV *, 8> Roots; 8545 8546 InvalidationRootCollector(Loop *L) : L(L) {} 8547 8548 bool follow(const SCEV *S) { 8549 if (auto *SU = dyn_cast<SCEVUnknown>(S)) { 8550 if (auto *I = dyn_cast<Instruction>(SU->getValue())) 8551 if (L->contains(I)) 8552 Roots.push_back(S); 8553 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 8554 if (L->contains(AddRec->getLoop())) 8555 Roots.push_back(S); 8556 } 8557 return true; 8558 } 8559 bool isDone() const { return false; } 8560 }; 8561 8562 InvalidationRootCollector C(L); 8563 visitAll(S, C); 8564 forgetMemoizedResults(C.Roots); 8565 } 8566 8567 // Also perform the normal invalidation. 8568 forgetValue(V); 8569 } 8570 8571 void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); } 8572 8573 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { 8574 // Unless a specific value is passed to invalidation, completely clear both 8575 // caches. 8576 if (!V) { 8577 BlockDispositions.clear(); 8578 LoopDispositions.clear(); 8579 return; 8580 } 8581 8582 if (!isSCEVable(V->getType())) 8583 return; 8584 8585 const SCEV *S = getExistingSCEV(V); 8586 if (!S) 8587 return; 8588 8589 // Invalidate the block and loop dispositions cached for S. Dispositions of 8590 // S's users may change if S's disposition changes (i.e. a user may become 8591 // loop-invariant if S becomes loop-invariant), so also invalidate 8592 // dispositions of S's users recursively. 8593 SmallVector<const SCEV *, 8> Worklist = {S}; 8594 SmallPtrSet<const SCEV *, 8> Seen = {S}; 8595 while (!Worklist.empty()) { 8596 const SCEV *Curr = Worklist.pop_back_val(); 8597 bool LoopDispoRemoved = LoopDispositions.erase(Curr); 8598 bool BlockDispoRemoved = BlockDispositions.erase(Curr); 8599 if (!LoopDispoRemoved && !BlockDispoRemoved) 8600 continue; 8601 auto Users = SCEVUsers.find(Curr); 8602 if (Users != SCEVUsers.end()) 8603 for (const auto *User : Users->second) 8604 if (Seen.insert(User).second) 8605 Worklist.push_back(User); 8606 } 8607 } 8608 8609 /// Get the exact loop backedge taken count considering all loop exits. A 8610 /// computable result can only be returned for loops with all exiting blocks 8611 /// dominating the latch. howFarToZero assumes that the limit of each loop test 8612 /// is never skipped. This is a valid assumption as long as the loop exits via 8613 /// that test. For precise results, it is the caller's responsibility to specify 8614 /// the relevant loop exiting block using getExact(ExitingBlock, SE).
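/// For example (illustrative): in a two-exit loop whose exits are taken after
/// %n and %m iterations respectively, the exact count is umin_seq(%n, %m);
/// the sequential umin keeps a zero count from an earlier exit from being
/// contaminated by a poisonous later exit count.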
8615 const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact( 8616 const Loop *L, ScalarEvolution *SE, 8617 SmallVectorImpl<const SCEVPredicate *> *Preds) const { 8618 // If any exits were not computable, the loop is not computable. 8619 if (!isComplete() || ExitNotTaken.empty()) 8620 return SE->getCouldNotCompute(); 8621 8622 const BasicBlock *Latch = L->getLoopLatch(); 8623 // All exiting blocks we have collected must dominate the only backedge. 8624 if (!Latch) 8625 return SE->getCouldNotCompute(); 8626 8627 // All exiting blocks we have gathered dominate loop's latch, so exact trip 8628 // count is simply a minimum out of all these calculated exit counts. 8629 SmallVector<const SCEV *, 2> Ops; 8630 for (const auto &ENT : ExitNotTaken) { 8631 const SCEV *BECount = ENT.ExactNotTaken; 8632 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!"); 8633 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) && 8634 "We should only have known counts for exiting blocks that dominate " 8635 "latch!"); 8636 8637 Ops.push_back(BECount); 8638 8639 if (Preds) 8640 append_range(*Preds, ENT.Predicates); 8641 8642 assert((Preds || ENT.hasAlwaysTruePredicate()) && 8643 "Predicate should be always true!"); 8644 } 8645 8646 // If an earlier exit exits on the first iteration (exit count zero), then 8647 // a later poison exit count should not propagate into the result. These are 8648 // exactly the semantics provided by umin_seq. 8649 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true); 8650 } 8651 8652 const ScalarEvolution::ExitNotTakenInfo * 8653 ScalarEvolution::BackedgeTakenInfo::getExitNotTaken( 8654 const BasicBlock *ExitingBlock, 8655 SmallVectorImpl<const SCEVPredicate *> *Predicates) const { 8656 for (const auto &ENT : ExitNotTaken) 8657 if (ENT.ExitingBlock == ExitingBlock) { 8658 if (ENT.hasAlwaysTruePredicate()) 8659 return &ENT; 8660 else if (Predicates) { 8661 append_range(*Predicates, ENT.Predicates); 8662 return &ENT; 8663 } 8664 } 8665 8666 return nullptr; 8667 } 8668 8669 /// getConstantMax - Get the constant max backedge taken count for the loop. 8670 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( 8671 ScalarEvolution *SE, 8672 SmallVectorImpl<const SCEVPredicate *> *Predicates) const { 8673 if (!getConstantMax()) 8674 return SE->getCouldNotCompute(); 8675 8676 for (const auto &ENT : ExitNotTaken) 8677 if (!ENT.hasAlwaysTruePredicate()) { 8678 if (!Predicates) 8679 return SE->getCouldNotCompute(); 8680 append_range(*Predicates, ENT.Predicates); 8681 } 8682 8683 assert((isa<SCEVCouldNotCompute>(getConstantMax()) || 8684 isa<SCEVConstant>(getConstantMax())) && 8685 "No point in having a non-constant max backedge taken count!"); 8686 return getConstantMax(); 8687 } 8688 8689 const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax( 8690 const Loop *L, ScalarEvolution *SE, 8691 SmallVectorImpl<const SCEVPredicate *> *Predicates) { 8692 if (!SymbolicMax) { 8693 // Form an expression for the maximum exit count possible for this loop. We 8694 // merge the max and exact information to approximate a version of 8695 // getConstantMaxBackedgeTakenCount which isn't restricted to just 8696 // constants.
8697 SmallVector<const SCEV *, 4> ExitCounts; 8698 8699 for (const auto &ENT : ExitNotTaken) { 8700 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken; 8701 if (!isa<SCEVCouldNotCompute>(ExitCount)) { 8702 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) && 8703 "We should only have known counts for exiting blocks that " 8704 "dominate latch!"); 8705 ExitCounts.push_back(ExitCount); 8706 if (Predicates) 8707 append_range(*Predicates, ENT.Predicates); 8708 8709 assert((Predicates || ENT.hasAlwaysTruePredicate()) && 8710 "Predicate should be always true!"); 8711 } 8712 } 8713 if (ExitCounts.empty()) 8714 SymbolicMax = SE->getCouldNotCompute(); 8715 else 8716 SymbolicMax = 8717 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true); 8718 } 8719 return SymbolicMax; 8720 } 8721 8722 bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( 8723 ScalarEvolution *SE) const { 8724 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { 8725 return !ENT.hasAlwaysTruePredicate(); 8726 }; 8727 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); 8728 } 8729 8730 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) 8731 : ExitLimit(E, E, E, false) {} 8732 8733 ScalarEvolution::ExitLimit::ExitLimit( 8734 const SCEV *E, const SCEV *ConstantMaxNotTaken, 8735 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, 8736 ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists) 8737 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken), 8738 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) { 8739 // If we prove the max count is zero, so is the symbolic bound. This happens 8740 // in practice due to differences in a) how context sensitive we've chosen 8741 // to be and b) how we reason about bounds implied by UB. 8742 if (ConstantMaxNotTaken->isZero()) { 8743 this->ExactNotTaken = E = ConstantMaxNotTaken; 8744 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken; 8745 } 8746 8747 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || 8748 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) && 8749 "Exact is not allowed to be less precise than Constant Max"); 8750 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || 8751 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) && 8752 "Exact is not allowed to be less precise than Symbolic Max"); 8753 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) || 8754 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) && 8755 "Symbolic Max is not allowed to be less precise than Constant Max"); 8756 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) || 8757 isa<SCEVConstant>(ConstantMaxNotTaken)) && 8758 "No point in having a non-constant max backedge taken count!"); 8759 SmallPtrSet<const SCEVPredicate *, 4> SeenPreds; 8760 for (const auto PredList : PredLists) 8761 for (const auto *P : PredList) { 8762 if (SeenPreds.contains(P)) 8763 continue; 8764 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!"); 8765 SeenPreds.insert(P); 8766 Predicates.push_back(P); 8767 } 8768 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) && 8769 "Backedge count should be int"); 8770 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) || 8771 !ConstantMaxNotTaken->getType()->isPointerTy()) && 8772 "Max backedge count should be int"); 8773 } 8774 8775 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, 8776 const SCEV *ConstantMaxNotTaken, 8777 const SCEV *SymbolicMaxNotTaken, 8778 bool MaxOrZero, 8779 ArrayRef<const SCEVPredicate *> PredList) 8780 : ExitLimit(E, ConstantMaxNotTaken, 
SymbolicMaxNotTaken, MaxOrZero, 8781 ArrayRef({PredList})) {} 8782 8783 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each 8784 /// computable exit into a persistent ExitNotTakenInfo array. 8785 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( 8786 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts, 8787 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero) 8788 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) { 8789 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; 8790 8791 ExitNotTaken.reserve(ExitCounts.size()); 8792 std::transform(ExitCounts.begin(), ExitCounts.end(), 8793 std::back_inserter(ExitNotTaken), 8794 [&](const EdgeExitInfo &EEI) { 8795 BasicBlock *ExitBB = EEI.first; 8796 const ExitLimit &EL = EEI.second; 8797 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, 8798 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken, 8799 EL.Predicates); 8800 }); 8801 assert((isa<SCEVCouldNotCompute>(ConstantMax) || 8802 isa<SCEVConstant>(ConstantMax)) && 8803 "No point in having a non-constant max backedge taken count!"); 8804 } 8805 8806 /// Compute the number of times the backedge of the specified loop will execute. 8807 ScalarEvolution::BackedgeTakenInfo 8808 ScalarEvolution::computeBackedgeTakenCount(const Loop *L, 8809 bool AllowPredicates) { 8810 SmallVector<BasicBlock *, 8> ExitingBlocks; 8811 L->getExitingBlocks(ExitingBlocks); 8812 8813 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; 8814 8815 SmallVector<EdgeExitInfo, 4> ExitCounts; 8816 bool CouldComputeBECount = true; 8817 BasicBlock *Latch = L->getLoopLatch(); // may be NULL. 8818 const SCEV *MustExitMaxBECount = nullptr; 8819 const SCEV *MayExitMaxBECount = nullptr; 8820 bool MustExitMaxOrZero = false; 8821 bool IsOnlyExit = ExitingBlocks.size() == 1; 8822 8823 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts 8824 // and compute maxBECount. 8825 // Do a union of all the predicates here. 8826 for (BasicBlock *ExitBB : ExitingBlocks) { 8827 // We canonicalize untaken exits to br (constant); ignore them so that 8828 // proving an exit untaken doesn't negatively impact our ability to reason 8829 // about the loop as a whole. 8830 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator())) 8831 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) { 8832 bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); 8833 if (ExitIfTrue == CI->isZero()) 8834 continue; 8835 } 8836 8837 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates); 8838 8839 assert((AllowPredicates || EL.Predicates.empty()) && 8840 "Predicated exit limit when predicates are not allowed!"); 8841 8842 // 1. For each exit that can be computed, add an entry to ExitCounts. 8843 // CouldComputeBECount is true only if all exits can be computed. 8844 if (EL.ExactNotTaken != getCouldNotCompute()) 8845 ++NumExitCountsComputed; 8846 else 8847 // We couldn't compute an exact value for this exit, so 8848 // we won't be able to compute an exact value for the loop. 8849 CouldComputeBECount = false; 8850 // Remember exit count if either exact or symbolic is known. Because 8851 // Exact always implies symbolic, only check symbolic. 8852 if (EL.SymbolicMaxNotTaken != getCouldNotCompute()) 8853 ExitCounts.emplace_back(ExitBB, EL); 8854 else { 8855 assert(EL.ExactNotTaken == getCouldNotCompute() && 8856 "Exact is known but symbolic isn't?"); 8857 ++NumExitCountsNotComputed; 8858 } 8859 8860 // 2.
Derive the loop's MaxBECount from each exit's max number of 8861 // non-exiting iterations. Partition the loop exits into two kinds: 8862 // LoopMustExits and LoopMayExits. 8863 // 8864 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it 8865 // is a LoopMayExit. If any computable LoopMustExit is found, then 8866 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable 8867 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum 8868 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than 8869 // any 8870 // computable EL.ConstantMaxNotTaken. 8871 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch && 8872 DT.dominates(ExitBB, Latch)) { 8873 if (!MustExitMaxBECount) { 8874 MustExitMaxBECount = EL.ConstantMaxNotTaken; 8875 MustExitMaxOrZero = EL.MaxOrZero; 8876 } else { 8877 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount, 8878 EL.ConstantMaxNotTaken); 8879 } 8880 } else if (MayExitMaxBECount != getCouldNotCompute()) { 8881 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute()) 8882 MayExitMaxBECount = EL.ConstantMaxNotTaken; 8883 else { 8884 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount, 8885 EL.ConstantMaxNotTaken); 8886 } 8887 } 8888 } 8889 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : 8890 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); 8891 // The loop backedge will be taken the maximum or zero times if there's 8892 // a single exit that must be taken the maximum or zero times. 8893 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); 8894 8895 // Remember which SCEVs are used in exit limits for invalidation purposes. 8896 // We only care about non-constant SCEVs here, so we can ignore 8897 // EL.ConstantMaxNotTaken 8898 // and MaxBECount, which must be SCEVConstant. 8899 for (const auto &Pair : ExitCounts) { 8900 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken)) 8901 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates}); 8902 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken)) 8903 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert( 8904 {L, AllowPredicates}); 8905 } 8906 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, 8907 MaxBECount, MaxOrZero); 8908 } 8909 8910 ScalarEvolution::ExitLimit 8911 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, 8912 bool IsOnlyExit, bool AllowPredicates) { 8913 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); 8914 // If our exiting block does not dominate the latch, then its connection with 8915 // loop's exit limit may be far from trivial. 8916 const BasicBlock *Latch = L->getLoopLatch(); 8917 if (!Latch || !DT.dominates(ExitingBlock, Latch)) 8918 return getCouldNotCompute(); 8919 8920 Instruction *Term = ExitingBlock->getTerminator(); 8921 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { 8922 assert(BI->isConditional() && "If unconditional, it can't be in loop!"); 8923 bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); 8924 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) && 8925 "It should have one successor in loop and one exit block!"); 8926 // Proceed to the next level to examine the exit condition expression. 
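    // For instance, given `br i1 %cmp, label %exit, label %body` where %exit
    // is outside the loop, ExitIfTrue is true and the limit below is computed
    // for %cmp evaluating to true.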
    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
                                    /*ControlsOnlyExit=*/IsOnlyExit,
                                    AllowPredicates);
  }

  if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
    // For a switch, make sure that there is a single exit from the loop.
    BasicBlock *Exit = nullptr;
    for (auto *SBB : successors(ExitingBlock))
      if (!L->contains(SBB)) {
        if (Exit) // Multiple exit successors.
          return getCouldNotCompute();
        Exit = SBB;
      }
    assert(Exit && "Exiting block must have at least one exit");
    return computeExitLimitFromSingleExitSwitch(
        L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
  }

  return getCouldNotCompute();
}

ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
    const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
    bool AllowPredicates) {
  ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
  return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
                                        ControlsOnlyExit, AllowPredicates);
}

std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                      bool ExitIfTrue, bool ControlsOnlyExit,
                                      bool AllowPredicates) {
  (void)this->L;
  (void)this->ExitIfTrue;
  (void)this->AllowPredicates;

  assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
         this->AllowPredicates == AllowPredicates &&
         "Variance in assumed invariant key components!");
  auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
  if (Itr == TripCountMap.end())
    return std::nullopt;
  return Itr->second;
}

void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
                                             bool ExitIfTrue,
                                             bool ControlsOnlyExit,
                                             bool AllowPredicates,
                                             const ExitLimit &EL) {
  assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
         this->AllowPredicates == AllowPredicates &&
         "Variance in assumed invariant key components!");

  auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
  assert(InsertResult.second && "Expected successful insertion!");
  (void)InsertResult;
  (void)ExitIfTrue;
}

ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
    bool ControlsOnlyExit, bool AllowPredicates) {

  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
                                AllowPredicates))
    return *MaybeEL;

  ExitLimit EL = computeExitLimitFromCondImpl(
      Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
  return EL;
}

ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
    bool ControlsOnlyExit, bool AllowPredicates) {
  // Handle BinOp conditions (And, Or).
  if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
          Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
    return *LimitFromBinOp;

  // With an icmp, it may be feasible to compute an exact backedge-taken count.
  // Proceed to the next level to examine the icmp.
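  // For example (hypothetical): for a loop that keeps looping while
  // "%c = icmp ne i32 %iv, 100" is true, with %iv being the addrec {0,+,1},
  // the icmp path below yields an exact backedge-taken count of 100.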
  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
    ExitLimit EL =
        computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
    if (EL.hasFullInfo() || !AllowPredicates)
      return EL;

    // Try again, but use SCEV predicates this time.
    return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
                                    ControlsOnlyExit,
                                    /*AllowPredicates=*/true);
  }

  // Check for a constant condition. These are normally stripped out by
  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
  // preserve the CFG and is temporarily leaving constant conditions
  // in place.
  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
    if (ExitIfTrue == !CI->getZExtValue())
      // The backedge is always taken.
      return getCouldNotCompute();
    // The backedge is never taken.
    return getZero(CI->getType());
  }

  // If we're exiting based on the overflow flag of an x.with.overflow
  // intrinsic with a constant step, we can form an equivalent icmp predicate
  // and figure out how many iterations will be taken before we exit.
  const WithOverflowInst *WO;
  const APInt *C;
  if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
      match(WO->getRHS(), m_APInt(C))) {
    ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
        WO->getBinaryOp(), *C, WO->getNoWrapKind());
    CmpInst::Predicate Pred;
    APInt NewRHSC, Offset;
    NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
    if (!ExitIfTrue)
      Pred = ICmpInst::getInversePredicate(Pred);
    auto *LHS = getSCEV(WO->getLHS());
    if (Offset != 0)
      LHS = getAddExpr(LHS, getConstant(Offset));
    auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
                                       ControlsOnlyExit, AllowPredicates);
    if (EL.hasAnyInfo())
      return EL;
  }

  // If it's not an integer or pointer comparison then compute it the hard way.
  return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
}

std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::computeExitLimitFromCondFromBinOp(
    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
    bool ControlsOnlyExit, bool AllowPredicates) {
  // Check if the controlling expression for this loop is an And or Or.
  Value *Op0, *Op1;
  bool IsAnd = false;
  if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
    IsAnd = true;
  else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
    IsAnd = false;
  else
    return std::nullopt;

  // EitherMayExit is true in these two cases:
  //   br (and Op0 Op1), loop, exit
  //   br (or  Op0 Op1), exit, loop
  bool EitherMayExit = IsAnd ^ ExitIfTrue;
  ExitLimit EL0 = computeExitLimitFromCondCached(
      Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
      AllowPredicates);
  ExitLimit EL1 = computeExitLimitFromCondCached(
      Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
      AllowPredicates);

  // Be robust against unsimplified IR for the form "op i1 X, NeutralElement".
  const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
  if (isa<ConstantInt>(Op1))
    return Op1 == NeutralElement ? EL0 : EL1;
  if (isa<ConstantInt>(Op0))
    return Op0 == NeutralElement ? EL1 : EL0;

  const SCEV *BECount = getCouldNotCompute();
  const SCEV *ConstantMaxBECount = getCouldNotCompute();
  const SCEV *SymbolicMaxBECount = getCouldNotCompute();
  if (EitherMayExit) {
    bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
    // The loop continues executing only while neither condition triggers an
    // exit. Choose the less conservative count.
    if (EL0.ExactNotTaken != getCouldNotCompute() &&
        EL1.ExactNotTaken != getCouldNotCompute()) {
      BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
                                           UseSequentialUMin);
    }
    if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
      ConstantMaxBECount = EL1.ConstantMaxNotTaken;
    else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
      ConstantMaxBECount = EL0.ConstantMaxNotTaken;
    else
      ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
                                                      EL1.ConstantMaxNotTaken);
    if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
      SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
    else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
      SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
    else
      SymbolicMaxBECount = getUMinFromMismatchedTypes(
          EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
  } else {
    // Both conditions must trigger the exit at the same time for the loop to
    // exit. For now, be conservative.
    if (EL0.ExactNotTaken == EL1.ExactNotTaken)
      BECount = EL0.ExactNotTaken;
  }

  // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
  // to be more aggressive when computing BECount than when computing
  // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
  // and EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
  // EL1.ConstantMaxNotTaken not to.
  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
  if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
    SymbolicMaxBECount =
        isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                   {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
}

ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
    const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
    bool AllowPredicates) {
  // If the condition was exit on true, convert the condition to exit on false.
  CmpPredicate Pred;
  if (!ExitIfTrue)
    Pred = ExitCond->getCmpPredicate();
  else
    Pred = ExitCond->getInverseCmpPredicate();
  const ICmpInst::Predicate OriginalPred = Pred;

  const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
  const SCEV *RHS = getSCEV(ExitCond->getOperand(1));

  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
                                          AllowPredicates);
  if (EL.hasAnyInfo())
    return EL;

  auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue);

  if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
    return ExhaustiveCount;

  return computeShiftCompareExitLimit(ExitCond->getOperand(0),
                                      ExitCond->getOperand(1), L, OriginalPred);
}

ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
    const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
    bool ControlsOnlyExit, bool AllowPredicates) {

  // Try to evaluate any dependencies out of the loop.
  LHS = getSCEVAtScope(LHS, L);
  RHS = getSCEVAtScope(RHS, L);

  // At this point, we would like to compute how many iterations of the
  // loop the predicate will return true for these inputs.
  if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
    // If there is a loop-invariant operand, force it into the RHS.
    std::swap(LHS, RHS);
    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
  }

  bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
                               loopIsFiniteByAssumption(L);
  // Simplify the operands before analyzing them.
  (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);

  // If we have a comparison of a chrec against a constant, try to use value
  // ranges to answer this query.
  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
      if (AddRec->getLoop() == L) {
        // Form the constant range.
        ConstantRange CompRange =
            ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());

        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
        if (!isa<SCEVCouldNotCompute>(Ret))
          return Ret;
      }

  // If this loop must exit based on this condition (or execute undefined
  // behaviour), see if we can improve wrap flags. This is essentially
  // a must-execute style proof.
  if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
    // If we can prove that the test sequence produced must repeat the same
    // values on self-wrap of the IV, then we can infer that the IV doesn't
    // self-wrap, because if it did, we'd have an infinite (undefined) loop.
    // TODO: We can peel off any functions which are invertible *in L*. Loop
    // invariant terms are effectively constants for our purposes here.
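    // Illustrative sketch (hypothetical): for %iv = {0,+,4} guarding the only
    // exit of a finite loop with "%iv ult %n", the step 4 is a power of two,
    // so a self-wrapping %iv would revisit the exact same sequence of values
    // and the loop could never terminate; hence FlagNW may be set below.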
    auto *InnerLHS = LHS;
    if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
      InnerLHS = ZExt->getOperand();
    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
        AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
        isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
                               /*OrNegative=*/true)) {
      auto Flags = AR->getNoWrapFlags();
      Flags = setFlags(Flags, SCEV::FlagNW);
      SmallVector<const SCEV *> Operands{AR->operands()};
      Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
    }

    // For a slt/ult condition with a positive step, can we prove nsw/nuw?
    // From no-self-wrap, this follows trivially from the fact that every
    // (un)signed-wrapped, but not self-wrapped, value must be less than the
    // last value before the (un)signed wrap. Since we know that last value
    // didn't exit, neither will any smaller one.
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
      auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
          AR && AR->getLoop() == L && AR->isAffine() &&
          !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
          isKnownPositive(AR->getStepRecurrence(*this))) {
        auto Flags = AR->getNoWrapFlags();
        Flags = setFlags(Flags, WrapType);
        SmallVector<const SCEV *> Operands{AR->operands()};
        Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
      }
    }
  }

  switch (Pred) {
  case ICmpInst::ICMP_NE: { // while (X != Y)
    // Convert to: while (X-Y != 0)
    if (LHS->getType()->isPointerTy()) {
      LHS = getLosslessPtrToIntExpr(LHS);
      if (isa<SCEVCouldNotCompute>(LHS))
        return LHS;
    }
    if (RHS->getType()->isPointerTy()) {
      RHS = getLosslessPtrToIntExpr(RHS);
      if (isa<SCEVCouldNotCompute>(RHS))
        return RHS;
    }
    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
                                AllowPredicates);
    if (EL.hasAnyInfo())
      return EL;
    break;
  }
  case ICmpInst::ICMP_EQ: { // while (X == Y)
    // Convert to: while (X-Y == 0)
    if (LHS->getType()->isPointerTy()) {
      LHS = getLosslessPtrToIntExpr(LHS);
      if (isa<SCEVCouldNotCompute>(LHS))
        return LHS;
    }
    if (RHS->getType()->isPointerTy()) {
      RHS = getLosslessPtrToIntExpr(RHS);
      if (isa<SCEVCouldNotCompute>(RHS))
        return RHS;
    }
    ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
    if (EL.hasAnyInfo())
      return EL;
    break;
  }
  case ICmpInst::ICMP_SLE:
  case ICmpInst::ICMP_ULE:
    // Since the loop is finite, an invariant RHS cannot include the boundary
    // value, otherwise it would loop forever.
    if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
        !isLoopInvariant(RHS, L)) {
      // Otherwise, perform the addition in a wider type, to avoid overflow.
      // If the LHS is an addrec with the appropriate nowrap flag, the
      // extension will be sunk into it and the exit count can be analyzed.
      auto *OldType = dyn_cast<IntegerType>(LHS->getType());
      if (!OldType)
        break;
      // Prefer doubling the bitwidth over adding a single bit to make it more
      // likely that we use a legal type.
      auto *NewType =
          Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
      if (ICmpInst::isSigned(Pred)) {
        LHS = getSignExtendExpr(LHS, NewType);
        RHS = getSignExtendExpr(RHS, NewType);
      } else {
        LHS = getZeroExtendExpr(LHS, NewType);
        RHS = getZeroExtendExpr(RHS, NewType);
      }
    }
    RHS = getAddExpr(getOne(RHS->getType()), RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_ULT: { // while (X < Y)
    bool IsSigned = ICmpInst::isSigned(Pred);
    ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
                                    AllowPredicates);
    if (EL.hasAnyInfo())
      return EL;
    break;
  }
  case ICmpInst::ICMP_SGE:
  case ICmpInst::ICMP_UGE:
    // Since the loop is finite, an invariant RHS cannot include the boundary
    // value, otherwise it would loop forever.
    if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
        !isLoopInvariant(RHS, L))
      break;
    RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_UGT: { // while (X > Y)
    bool IsSigned = ICmpInst::isSigned(Pred);
    ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
                                       AllowPredicates);
    if (EL.hasAnyInfo())
      return EL;
    break;
  }
  default:
    break;
  }

  return getCouldNotCompute();
}

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
                                                      SwitchInst *Switch,
                                                      BasicBlock *ExitingBlock,
                                                      bool ControlsOnlyExit) {
  assert(!L->contains(ExitingBlock) && "Not an exiting block!");

  // Give up if the exit is the default dest of a switch.
  if (Switch->getDefaultDest() == ExitingBlock)
    return getCouldNotCompute();

  assert(L->contains(Switch->getDefaultDest()) &&
         "Default case must not exit the loop!");
  const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
  const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));

  // while (X != Y) --> while (X-Y != 0)
  ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
  if (EL.hasAnyInfo())
    return EL;

  return getCouldNotCompute();
}

static ConstantInt *
EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
                                ScalarEvolution &SE) {
  const SCEV *InVal = SE.getConstant(C);
  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
  assert(isa<SCEVConstant>(Val) &&
         "Evaluation of SCEV at constant didn't fold correctly?");
  return cast<SCEVConstant>(Val)->getValue();
}

ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
    Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
  ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
  if (!RHS)
    return getCouldNotCompute();

  const BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return getCouldNotCompute();

  const BasicBlock *Predecessor = L->getLoopPredecessor();
  if (!Predecessor)
    return getCouldNotCompute();

  // Return true if V is of the form "LHS `shift_op` <positive constant>".
  // Return LHS in OutLHS and shift_op in OutOpCode.
  auto MatchPositiveShift =
      [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
        using namespace PatternMatch;

        ConstantInt *ShiftAmt;
        if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
          OutOpCode = Instruction::LShr;
        else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
          OutOpCode = Instruction::AShr;
        else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
          OutOpCode = Instruction::Shl;
        else
          return false;

        return ShiftAmt->getValue().isStrictlyPositive();
      };

  // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
  //
  // loop:
  //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
  //   %iv.shifted = lshr i32 %iv, <positive constant>
  //
  // Return true on a successful match. Return the corresponding PHI node (%iv
  // above) in PNOut and the opcode of the shift operation in OpCodeOut.
  auto MatchShiftRecurrence =
      [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
        std::optional<Instruction::BinaryOps> PostShiftOpCode;

        {
          Instruction::BinaryOps OpC;
          Value *V;

          // If we encounter a shift instruction, "peel off" the shift
          // operation, and remember that we did so. Later when we inspect
          // %iv's backedge value, we will make sure that the backedge value
          // uses the same operation.
          //
          // Note: the peeled shift operation does not have to be the same
          // instruction as the one feeding into the PHI's backedge value.
          // We only really care about it being the same *kind* of shift
          // instruction -- that's all that is required for our later
          // inferences to hold.
          if (MatchPositiveShift(LHS, V, OpC)) {
            PostShiftOpCode = OpC;
            LHS = V;
          }
        }

        PNOut = dyn_cast<PHINode>(LHS);
        if (!PNOut || PNOut->getParent() != L->getHeader())
          return false;

        Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
        Value *OpLHS;

        return
            // The backedge value for the PHI node must be a shift by a
            // positive amount
            MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&

            // of the PHI node itself
            OpLHS == PNOut &&

            // and the kind of shift should match the kind of shift we
            // peeled off, if any.
            (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
      };

  PHINode *PN;
  Instruction::BinaryOps OpCode;
  if (!MatchShiftRecurrence(LHS, PN, OpCode))
    return getCouldNotCompute();

  const DataLayout &DL = getDataLayout();

  // The key rationale for this optimization is that for some kinds of shift
  // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
  // within a finite number of iterations. If the condition guarding the
  // backedge (in the sense that the backedge is taken if the condition is
  // true) is false for the value the shift recurrence stabilizes to, then we
  // know that the backedge is taken only a finite number of times.

  ConstantInt *StableValue = nullptr;
  switch (OpCode) {
  default:
    llvm_unreachable("Impossible case!");

  case Instruction::AShr: {
    // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
    // bitwidth(K) iterations.
    Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
    KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
                                       Predecessor->getTerminator(), &DT);
    auto *Ty = cast<IntegerType>(RHS->getType());
    if (Known.isNonNegative())
      StableValue = ConstantInt::get(Ty, 0);
    else if (Known.isNegative())
      StableValue = ConstantInt::get(Ty, -1, true);
    else
      return getCouldNotCompute();

    break;
  }
  case Instruction::LShr:
  case Instruction::Shl:
    // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
    // stabilize to 0 in at most bitwidth(K) iterations.
    StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
    break;
  }

  auto *Result =
      ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
  assert(Result->getType()->isIntegerTy(1) &&
         "Otherwise cannot be an operand to a branch instruction");

  if (Result->isZeroValue()) {
    unsigned BitWidth = getTypeSizeInBits(RHS->getType());
    const SCEV *UpperBound =
        getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
    return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
  }

  return getCouldNotCompute();
}

/// Return true if we can constant fold an instruction of the specified type,
/// assuming that all operands were constants.
static bool CanConstantFold(const Instruction *I) {
  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
      isa<CastInst>(I) || isa<GetElementPtrInst>(I) || isa<LoadInst>(I) ||
      isa<ExtractValueInst>(I))
    return true;

  if (const CallInst *CI = dyn_cast<CallInst>(I))
    if (const Function *F = CI->getCalledFunction())
      return canConstantFoldCallTo(CI, F);
  return false;
}

/// Determine whether this instruction can constant evolve within this loop
/// assuming its operands can all constant evolve.
static bool canConstantEvolve(Instruction *I, const Loop *L) {
  // An instruction outside of the loop can't be derived from a loop PHI.
  if (!L->contains(I))
    return false;

  if (isa<PHINode>(I)) {
    // We don't currently keep track of the control flow needed to evaluate
    // PHIs, so we cannot handle PHIs inside of loops.
    return L->getHeader() == I->getParent();
  }

  // If we won't be able to constant fold this expression even if the operands
  // are constants, bail early.
  return CanConstantFold(I);
}

/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
/// recursing through each instruction operand until reaching a loop header
/// phi.
static PHINode *
getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
                               DenseMap<Instruction *, PHINode *> &PHIMap,
                               unsigned Depth) {
  if (Depth > MaxConstantEvolvingDepth)
    return nullptr;

  // Otherwise, we can evaluate this instruction if all of its operands are
  // constant or derived from a PHI node themselves.
  PHINode *PHI = nullptr;
  for (Value *Op : UseInst->operands()) {
    if (isa<Constant>(Op))
      continue;

    Instruction *OpInst = dyn_cast<Instruction>(Op);
    if (!OpInst || !canConstantEvolve(OpInst, L))
      return nullptr;

    PHINode *P = dyn_cast<PHINode>(OpInst);
    if (!P)
      // If this operand is already visited, reuse the prior result.
      // We may have P != PHI if this is the deepest point at which the
      // inconsistent paths meet.
      P = PHIMap.lookup(OpInst);
    if (!P) {
      // Recurse and memoize the results, whether a phi is found or not.
      // This recursive call invalidates pointers into PHIMap.
      P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
      PHIMap[OpInst] = P;
    }
    if (!P)
      return nullptr; // Not evolving from PHI
    if (PHI && PHI != P)
      return nullptr; // Evolving from multiple different PHIs.
    PHI = P;
  }
  // This is an expression evolving from a constant PHI!
  return PHI;
}

/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
/// in the loop that V is derived from. We allow arbitrary operations along the
/// way, but the operands of an operation must either be constants or a value
/// derived from a constant PHI. If this expression does not fit with these
/// constraints, return null.
static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I || !canConstantEvolve(I, L))
    return nullptr;

  if (PHINode *PN = dyn_cast<PHINode>(I))
    return PN;

  // Record non-constant instructions contained by the loop.
  DenseMap<Instruction *, PHINode *> PHIMap;
  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
}

/// EvaluateExpression - Given an expression that passes the
/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
/// in the loop has the value PHIVal. If we can't fold this expression for some
/// reason, return null.
static Constant *EvaluateExpression(Value *V, const Loop *L,
                                    DenseMap<Instruction *, Constant *> &Vals,
                                    const DataLayout &DL,
                                    const TargetLibraryInfo *TLI) {
  // Convenient constant check, but redundant for recursive calls.
  if (Constant *C = dyn_cast<Constant>(V))
    return C;
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return nullptr;

  if (Constant *C = Vals.lookup(I))
    return C;

  // An instruction inside the loop depends on a value outside the loop that we
  // weren't given a mapping for, or a value such as a call inside the loop.
  if (!canConstantEvolve(I, L))
    return nullptr;

  // An unmapped PHI can be due to a branch or another loop inside this loop,
  // or due to this not being the initial iteration through a loop where we
  // couldn't compute the evolution of this particular PHI last time.
  if (isa<PHINode>(I))
    return nullptr;

  std::vector<Constant *> Operands(I->getNumOperands());

  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
    Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
    if (!Operand) {
      Operands[i] = dyn_cast<Constant>(I->getOperand(i));
      if (!Operands[i])
        return nullptr;
      continue;
    }
    Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
    Vals[Operand] = C;
    if (!C)
      return nullptr;
    Operands[i] = C;
  }

  return ConstantFoldInstOperands(I, Operands, DL, TLI,
                                  /*AllowNonDeterministic=*/false);
}

// If every incoming value to PN except the one for BB is a specific Constant,
// return that, else return nullptr.
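// For example (hypothetical IR), for
//   %p = phi i32 [ 7, %pre.a ], [ 7, %pre.b ], [ %x, %latch ]
// with BB = %latch, this returns the i32 constant 7; if the two non-latch
// values differed, it would return nullptr.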
static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
  Constant *IncomingVal = nullptr;

  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    if (PN->getIncomingBlock(i) == BB)
      continue;

    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
    if (!CurrentVal)
      return nullptr;

    if (IncomingVal != CurrentVal) {
      if (IncomingVal)
        return nullptr;
      IncomingVal = CurrentVal;
    }
  }

  return IncomingVal;
}

/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
/// involving constants, fold it.
Constant *
ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
                                                   const APInt &BEs,
                                                   const Loop *L) {
  auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
  if (!Inserted)
    return I->second;

  if (BEs.ugt(MaxBruteForceIterations))
    return nullptr; // Not going to evaluate it.

  Constant *&RetVal = I->second;

  DenseMap<Instruction *, Constant *> CurrentIterVals;
  BasicBlock *Header = L->getHeader();
  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");

  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return nullptr;

  for (PHINode &PHI : Header->phis()) {
    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
      CurrentIterVals[&PHI] = StartCST;
  }
  if (!CurrentIterVals.count(PN))
    return RetVal = nullptr;

  Value *BEValue = PN->getIncomingValueForBlock(Latch);

  // Execute the loop symbolically to determine the exit value.
  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");

  unsigned NumIterations = BEs.getZExtValue(); // must be in range
  unsigned IterationNum = 0;
  const DataLayout &DL = getDataLayout();
  for (;; ++IterationNum) {
    if (IterationNum == NumIterations)
      return RetVal = CurrentIterVals[PN]; // Got exit value!

    // Compute the value of the PHIs for the next iteration.
    // EvaluateExpression adds non-phi values to the CurrentIterVals map.
    DenseMap<Instruction *, Constant *> NextIterVals;
    Constant *NextPHI =
        EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
    if (!NextPHI)
      return nullptr; // Couldn't evaluate!
    NextIterVals[PN] = NextPHI;

    bool StoppedEvolving = NextPHI == CurrentIterVals[PN];

    // Also evaluate the other PHI nodes. However, we don't get to stop if we
    // cease to be able to evaluate one of them or if they stop evolving,
    // because that doesn't necessarily prevent us from computing PN.
    SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
    for (const auto &I : CurrentIterVals) {
      PHINode *PHI = dyn_cast<PHINode>(I.first);
      if (!PHI || PHI == PN || PHI->getParent() != Header)
        continue;
      PHIsToCompute.emplace_back(PHI, I.second);
    }
    // We use two distinct loops because EvaluateExpression may invalidate any
    // iterators into CurrentIterVals.
    for (const auto &I : PHIsToCompute) {
      PHINode *PHI = I.first;
      Constant *&NextPHI = NextIterVals[PHI];
      if (!NextPHI) { // Not already computed.
        Value *BEValue = PHI->getIncomingValueForBlock(Latch);
        NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
      }
      if (NextPHI != I.second)
        StoppedEvolving = false;
    }

    // If all entries in CurrentIterVals == NextIterVals then we can stop
    // iterating, the loop can't continue to change.
    if (StoppedEvolving)
      return RetVal = CurrentIterVals[PN];

    CurrentIterVals.swap(NextIterVals);
  }
}

const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
                                                          Value *Cond,
                                                          bool ExitWhen) {
  PHINode *PN = getConstantEvolvingPHI(Cond, L);
  if (!PN)
    return getCouldNotCompute();

  // If the loop is canonicalized, the PHI will have exactly two entries.
  // That's the only form we support here.
  if (PN->getNumIncomingValues() != 2)
    return getCouldNotCompute();

  DenseMap<Instruction *, Constant *> CurrentIterVals;
  BasicBlock *Header = L->getHeader();
  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");

  BasicBlock *Latch = L->getLoopLatch();
  assert(Latch && "Should follow from NumIncomingValues == 2!");

  for (PHINode &PHI : Header->phis()) {
    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
      CurrentIterVals[&PHI] = StartCST;
  }
  if (!CurrentIterVals.count(PN))
    return getCouldNotCompute();

  // Okay, we found a PHI node that defines the trip count of this loop.
  // Execute the loop symbolically to determine when the condition gets a
  // value of "ExitWhen".
  unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
  const DataLayout &DL = getDataLayout();
  for (unsigned IterationNum = 0; IterationNum != MaxIterations;
       ++IterationNum) {
    auto *CondVal = dyn_cast_or_null<ConstantInt>(
        EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));

    // Couldn't symbolically evaluate.
    if (!CondVal)
      return getCouldNotCompute();

    if (CondVal->getValue() == uint64_t(ExitWhen)) {
      ++NumBruteForceTripCountsComputed;
      return getConstant(Type::getInt32Ty(getContext()), IterationNum);
    }

    // Update all the PHI nodes for the next iteration.
    DenseMap<Instruction *, Constant *> NextIterVals;

    // Create a list of which PHIs we need to compute. We want to do this
    // before calling EvaluateExpression on them because that may invalidate
    // iterators into CurrentIterVals.
    SmallVector<PHINode *, 8> PHIsToCompute;
    for (const auto &I : CurrentIterVals) {
      PHINode *PHI = dyn_cast<PHINode>(I.first);
      if (!PHI || PHI->getParent() != Header)
        continue;
      PHIsToCompute.push_back(PHI);
    }
    for (PHINode *PHI : PHIsToCompute) {
      Constant *&NextPHI = NextIterVals[PHI];
      if (NextPHI)
        continue; // Already computed!

      Value *BEValue = PHI->getIncomingValueForBlock(Latch);
      NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
    }
    CurrentIterVals.swap(NextIterVals);
  }

  // Too many iterations were needed to evaluate.
  return getCouldNotCompute();
}

const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
      ValuesAtScopes[V];
  // Check to see if we've folded this expression at this loop before.
  for (auto &LS : Values)
    if (LS.first == L)
      return LS.second ? LS.second : V;

  Values.emplace_back(L, nullptr);

  // Otherwise compute it.
  const SCEV *C = computeSCEVAtScope(V, L);
  for (auto &LS : reverse(ValuesAtScopes[V]))
    if (LS.first == L) {
      LS.second = C;
      if (!isa<SCEVConstant>(C))
        ValuesAtScopesUsers[C].push_back({L, V});
      break;
    }
  return C;
}

/// This builds up a Constant using the ConstantExpr interface. That way, we
/// will return Constants for objects which aren't represented by a
/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
/// Returns NULL if the SCEV isn't representable as a Constant.
static Constant *BuildConstantFromSCEV(const SCEV *V) {
  switch (V->getSCEVType()) {
  case scCouldNotCompute:
  case scAddRecExpr:
  case scVScale:
    return nullptr;
  case scConstant:
    return cast<SCEVConstant>(V)->getValue();
  case scUnknown:
    return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
  case scPtrToInt: {
    const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
    if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
      return ConstantExpr::getPtrToInt(CastOp, P2I->getType());

    return nullptr;
  }
  case scTruncate: {
    const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
    if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
      return ConstantExpr::getTrunc(CastOp, ST->getType());
    return nullptr;
  }
  case scAddExpr: {
    const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
    Constant *C = nullptr;
    for (const SCEV *Op : SA->operands()) {
      Constant *OpC = BuildConstantFromSCEV(Op);
      if (!OpC)
        return nullptr;
      if (!C) {
        C = OpC;
        continue;
      }
      assert(!C->getType()->isPointerTy() &&
             "Can only have one pointer, and it must be last");
      if (OpC->getType()->isPointerTy()) {
        // The offsets have been converted to bytes. We can add bytes using
        // an i8 GEP.
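        // E.g. building the SCEV (4 + @g) for a hypothetical global @g
        // produces, in effect, "getelementptr i8, ptr @g, i64 4".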
        C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
                                           OpC, C);
      } else {
        C = ConstantExpr::getAdd(C, OpC);
      }
    }
    return C;
  }
  case scMulExpr:
  case scSignExtend:
  case scZeroExtend:
  case scUDivExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
  case scSequentialUMinExpr:
    return nullptr;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

const SCEV *
ScalarEvolution::getWithOperands(const SCEV *S,
                                 SmallVectorImpl<const SCEV *> &NewOps) {
  switch (S->getSCEVType()) {
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToInt:
    return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
  case scAddRecExpr: {
    auto *AddRec = cast<SCEVAddRecExpr>(S);
    return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
  }
  case scAddExpr:
    return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
  case scMulExpr:
    return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
  case scUDivExpr:
    return getUDivExpr(NewOps[0], NewOps[1]);
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return getMinMaxExpr(S->getSCEVType(), NewOps);
  case scSequentialUMinExpr:
    return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
  case scConstant:
  case scVScale:
  case scUnknown:
    return S;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
  switch (V->getSCEVType()) {
  case scConstant:
  case scVScale:
    return V;
  case scAddRecExpr: {
    // If this is a loop recurrence for a loop that does not contain L, then we
    // are dealing with the final value computed by the loop.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
    // First, attempt to evaluate each operand.
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
      if (OpAtScope == AddRec->getOperand(i))
        continue;

      // Okay, at least one of these operands is loop variant but might be
      // foldable. Build a new instance of the folded commutative expression.
      SmallVector<const SCEV *, 8> NewOps;
      NewOps.reserve(AddRec->getNumOperands());
      append_range(NewOps, AddRec->operands().take_front(i));
      NewOps.push_back(OpAtScope);
      for (++i; i != e; ++i)
        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));

      const SCEV *FoldedRec = getAddRecExpr(
          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
      // The addrec may be folded to a nonrecurrence, for example, if the
      // induction variable is multiplied by zero after constant folding. Go
      // ahead and return the folded value.
      if (!AddRec)
        return FoldedRec;
      break;
    }

    // If the scope is outside the addrec's loop, evaluate it by using the
    // loop exit value of the addrec.
    if (!AddRec->getLoop()->contains(L)) {
      // To evaluate this recurrence, we need to know how many times the AddRec
      // loop iterates. Compute this now.
      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
      if (BackedgeTakenCount == getCouldNotCompute())
        return AddRec;

      // Then, evaluate the AddRec.
      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
    }

    return AddRec;
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    ArrayRef<const SCEV *> Ops = V->operands();
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
      if (OpAtScope != Ops[i]) {
        // Okay, at least one of these operands is loop variant but might be
        // foldable. Build a new instance of the folded commutative expression.
        SmallVector<const SCEV *, 8> NewOps;
        NewOps.reserve(Ops.size());
        append_range(NewOps, Ops.take_front(i));
        NewOps.push_back(OpAtScope);

        for (++i; i != e; ++i) {
          OpAtScope = getSCEVAtScope(Ops[i], L);
          NewOps.push_back(OpAtScope);
        }

        return getWithOperands(V, NewOps);
      }
    }
    // If we got here, all operands are loop invariant.
    return V;
  }
  case scUnknown: {
    // If this instruction is evolved from a constant-evolving PHI, compute the
    // exit value from the loop without using SCEVs.
    const SCEVUnknown *SU = cast<SCEVUnknown>(V);
    Instruction *I = dyn_cast<Instruction>(SU->getValue());
    if (!I)
      return V; // This is some other type of SCEVUnknown, just return it.

    if (PHINode *PN = dyn_cast<PHINode>(I)) {
      const Loop *CurrLoop = this->LI[I->getParent()];
      // Looking for loop exit value.
      if (CurrLoop && CurrLoop->getParentLoop() == L &&
          PN->getParent() == CurrLoop->getHeader()) {
        // Okay, there is no closed form solution for the PHI node. Check
        // to see if the loop that contains it has a known backedge-taken
        // count. If so, we may be able to force computation of the exit
        // value.
        const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
        // This trivial case can show up in some degenerate cases where
        // the incoming IR has not yet been fully simplified.
        if (BackedgeTakenCount->isZero()) {
          Value *InitValue = nullptr;
          bool MultipleInitValues = false;
          for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
            if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
              if (!InitValue)
                InitValue = PN->getIncomingValue(i);
              else if (InitValue != PN->getIncomingValue(i)) {
                MultipleInitValues = true;
                break;
              }
            }
          }
          if (!MultipleInitValues && InitValue)
            return getSCEV(InitValue);
        }
        // Do we have a loop invariant value flowing around the backedge
        // for a loop which must execute the backedge?
        if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
            isKnownNonZero(BackedgeTakenCount) &&
            PN->getNumIncomingValues() == 2) {

          unsigned InLoopPred =
              CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
          Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
          if (CurrLoop->isLoopInvariant(BackedgeVal))
            return getSCEV(BackedgeVal);
        }
        if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
          // Okay, we know how many times the containing loop executes. If
          // this is a constant evolving PHI node, get the final value at
          // the specified iteration number.
          Constant *RV =
              getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
          if (RV)
            return getSCEV(RV);
        }
      }
    }

    // Okay, this is an expression that we cannot symbolically evaluate
    // into a SCEV. Check to see if it's possible to symbolically evaluate
    // the arguments into constants, and if so, try to constant propagate the
    // result. This is particularly useful for computing loop exit values.
    if (!CanConstantFold(I))
      return V; // This is some other type of SCEVUnknown, just return it.

    SmallVector<Constant *, 4> Operands;
    Operands.reserve(I->getNumOperands());
    bool MadeImprovement = false;
    for (Value *Op : I->operands()) {
      if (Constant *C = dyn_cast<Constant>(Op)) {
        Operands.push_back(C);
        continue;
      }

      // If any operand is non-constant and not of integer or pointer type,
      // don't even try to analyze it with SCEV techniques.
      if (!isSCEVable(Op->getType()))
        return V;

      const SCEV *OrigV = getSCEV(Op);
      const SCEV *OpV = getSCEVAtScope(OrigV, L);
      MadeImprovement |= OrigV != OpV;

      Constant *C = BuildConstantFromSCEV(OpV);
      if (!C)
        return V;
      assert(C->getType() == Op->getType() && "Type mismatch");
      Operands.push_back(C);
    }

    // Check to see if getSCEVAtScope actually made an improvement.
    if (!MadeImprovement)
      return V; // This is some other type of SCEVUnknown, just return it.

    Constant *C = nullptr;
    const DataLayout &DL = getDataLayout();
    C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
                                 /*AllowNonDeterministic=*/false);
    if (!C)
      return V;
    return getSCEV(C);
  }
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV type!");
}

const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
  return getSCEVAtScope(getSCEV(V), L);
}

const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
    return stripInjectiveFunctions(ZExt->getOperand());
  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
    return stripInjectiveFunctions(SExt->getOperand());
  return S;
}

/// Finds the minimum unsigned root of the following equation:
///
///     A * X = B (mod N)
///
/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
/// A and B isn't important.
///
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
static const SCEV *
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
                             SmallVectorImpl<const SCEVPredicate *> *Predicates,
                             ScalarEvolution &SE) {
  uint32_t BW = A.getBitWidth();
  assert(BW == SE.getTypeSizeInBits(B->getType()));
  assert(A != 0 && "A must be non-zero.");

  // 1. D = gcd(A, N)
  //
  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity.
  uint32_t Mult2 = A.countr_zero();
  // D = 2^Mult2

  // 2. Check if B is divisible by D.
  //
  // B is divisible by D if and only if the multiplicity of prime factor 2 for
  // B is not less than the multiplicity of this prime factor for D.
  if (SE.getMinTrailingZeros(B) < Mult2) {
    // Check if we can prove there's no remainder using URem.
    const SCEV *URem =
        SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
    const SCEV *Zero = SE.getZero(B->getType());
    if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
      // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
      if (!Predicates)
        return SE.getCouldNotCompute();

      // Avoid adding a predicate that is known to be false.
      if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
        return SE.getCouldNotCompute();
      Predicates->push_back(SE.getEqualPredicate(URem, Zero));
    }
  }

  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
  // modulo (N / D).
  //
  // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
  // (N / D) in general. The inverse itself always fits into BW bits, though,
  // so we immediately truncate it.
  APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
  APInt I = AD.multiplicativeInverse().zext(BW);

  // 4. Compute the minimum unsigned root of the equation:
  //    I * (B / D) mod (N / D)
  // To simplify the computation, we factor out the divide by D:
  //    (I * B mod N) / D
  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}

/// For a given quadratic addrec, generate coefficients of the corresponding
/// quadratic equation, multiplied by a common value to ensure that they are
/// integers.
/// The returned value is a tuple { A, B, C, M, BitWidth }, where
/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
/// were multiplied by, and BitWidth is the bit width of the original addrec
/// coefficients.
/// This function returns std::nullopt if the addrec coefficients are not
/// compile-time constants.
static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
  LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
                    << *AddRec << '\n');

  // We currently can only solve this if the coefficients are constants.
  if (!LC || !MC || !NC) {
    LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
    return std::nullopt;
  }

  APInt L = LC->getAPInt();
  APInt M = MC->getAPInt();
  APInt N = NC->getAPInt();
  assert(!N.isZero() && "This is not a quadratic addrec");

  unsigned BitWidth = LC->getAPInt().getBitWidth();
  unsigned NewWidth = BitWidth + 1;
  LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: " << BitWidth << '\n');
  // The sign-extension (as opposed to a zero-extension) here matches the
  // extension used in SolveQuadraticEquationWrap (with the same motivation).
  N = N.sext(NewWidth);
  M = M.sext(NewWidth);
  L = L.sext(NewWidth);

  // The increments are M, M+N, M+2N, ..., so the accumulated values are
  //   L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
  //   L+M, L+2M+N, L+3M+3N, ...
  // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
  //
  // The equation Acc = 0 is then
  //   L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
  // In a quadratic form it becomes:
  //   N n^2 + (2M-N) n + 2L = 0.
  APInt A = N;
  APInt B = 2 * M - A;
  APInt C = 2 * L;
  APInt T = APInt(NewWidth, 2);
  LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
                    << "x + " << C << ", coeff bw: " << NewWidth
                    << ", multiplied by " << T << '\n');
  return std::make_tuple(A, B, C, T, BitWidth);
}

/// Helper function to compare optional APInts:
/// (a) if X and Y both exist, return min(X, Y),
/// (b) if neither X nor Y exist, return std::nullopt,
/// (c) if exactly one of X and Y exists, return that value.
static std::optional<APInt> MinOptional(std::optional<APInt> X,
                                        std::optional<APInt> Y) {
  if (X && Y) {
    unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
    APInt XW = X->sext(W);
    APInt YW = Y->sext(W);
    return XW.slt(YW) ? *X : *Y;
  }
  if (!X && !Y)
    return std::nullopt;
  return X ? *X : *Y;
}

/// Helper function to truncate an optional APInt to a given BitWidth.
/// When solving addrec-related equations, it is preferable to return a value
/// that has the same bit width as the original addrec's coefficients. If the
/// solution fits in the original bit width, truncate it (except for i1).
/// Returning a value of a different bit width may inhibit some optimizations.
///
/// In general, a solution to a quadratic equation generated from an addrec
/// may require BW+1 bits, where BW is the bit width of the addrec's
/// coefficients. The reason is that the coefficients of the quadratic
/// equation are BW+1 bits wide (to avoid truncation when converting from
/// the addrec to the equation).
static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
                                            unsigned BitWidth) {
  if (!X)
    return std::nullopt;
  unsigned W = X->getBitWidth();
  if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
    return X->trunc(BitWidth);
  return X;
}

/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
/// iterations. The values L, M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
/// where BW is the bit width of the addrec's coefficients.
/// If the calculated value is a BW-bit integer (for BW > 1), it will be
/// returned as such, otherwise the bit width of the returned value may
/// be greater than BW.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
///     like x^2 = 5, no integer solutions exist; in other cases an integer
///     solution may exist, but SolveQuadraticEquationWrap may fail to find it.
static std::optional<APInt>
SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  std::tie(A, B, C, M, BitWidth) = *T;
  LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
  std::optional<APInt> X =
      APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
  if (!X)
    return std::nullopt;

  ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
  ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
  if (!V->isZero())
    return std::nullopt;

  return TruncIfPossible(X, BitWidth);
}

/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
/// iterations. The values M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n such that c(n) does not belong to the given range,
/// while c(n-1) does.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
///     bounds of the range.
static std::optional<APInt>
SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
                          const ConstantRange &Range, ScalarEvolution &SE) {
  assert(AddRec->getOperand(0)->isZero() &&
         "Starting value of addrec should be 0");
  LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
                    << Range << ", addrec " << *AddRec << '\n');
  // This case is handled in getNumIterationsInRange. Here we can assume that
  // we start in the range.
  assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
         "Addrec's initial value should be in range");

  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  // Be careful about the return value: there can be two reasons for not
  // returning an actual number. First, if no solutions to the equations
  // were found, and second, if the solutions don't leave the given range.
  // The first case means that the actual solution is "unknown", the second
  // means that it's known, but not valid. If the solution is unknown, we
  // cannot make any conclusions.
  // Return a pair: the optional solution and a flag indicating if the
  // solution was found.
  auto SolveForBoundary =
      [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
    // Solve for signed overflow and unsigned overflow, pick the lower
    // solution.
10402     LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10403                       << Bound << " (before multiplying by " << M << ")\n");
10404     Bound *= M; // The quadratic equation multiplier.
10405 
10406     std::optional<APInt> SO;
10407     if (BitWidth > 1) {
10408       LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10409                            "signed overflow\n");
10410       SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10411     }
10412     LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10413                          "unsigned overflow\n");
10414     std::optional<APInt> UO =
10415         APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10416 
10417     auto LeavesRange = [&] (const APInt &X) {
10418       ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10419       ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10420       if (Range.contains(V0->getValue()))
10421         return false;
10422       // X should be at least 1, so X-1 is non-negative.
10423       ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10424       ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10425       if (Range.contains(V1->getValue()))
10426         return true;
10427       return false;
10428     };
10429 
10430     // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10431     // can be a solution, but the function failed to find it. We cannot treat it
10432     // as "no solution".
10433     if (!SO || !UO)
10434       return {std::nullopt, false};
10435 
10436     // Check the smaller value first to see if it leaves the range.
10437     // At this point, both SO and UO must have values.
10438     std::optional<APInt> Min = MinOptional(SO, UO);
10439     if (LeavesRange(*Min))
10440       return { Min, true };
10441     std::optional<APInt> Max = Min == SO ? UO : SO;
10442     if (LeavesRange(*Max))
10443       return { Max, true };
10444 
10445     // Solutions were found, but were eliminated, hence the "true".
10446     return {std::nullopt, true};
10447   };
10448 
10449   std::tie(A, B, C, M, BitWidth) = *T;
10450   // Lower bound is inclusive, subtract 1 to represent the exiting value.
10451   APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10452   APInt Upper = Range.getUpper().sext(A.getBitWidth());
10453   auto SL = SolveForBoundary(Lower);
10454   auto SU = SolveForBoundary(Upper);
10455   // If any of the solutions was unknown, no meaningful conclusions can
10456   // be made.
10457   if (!SL.second || !SU.second)
10458     return std::nullopt;
10459 
10460   // Claim: The correct solution is not some value between Min and Max.
10461   //
10462   // Justification: Assuming that Min and Max are different values, one of
10463   // them is when the first signed overflow happens, the other is when the
10464   // first unsigned overflow happens. Crossing the range boundary is only
10465   // possible via an overflow (treating 0 as a special case of it, modeling
10466   // an overflow as crossing k*2^W for some k).
10467   //
10468   // The interesting case here is when Min was eliminated as an invalid
10469   // solution, but Max was not. The argument is that if there was another
10470   // overflow between Min and Max, it would also have been eliminated if
10471   // it was considered.
10472   //
10473   // For a given boundary, it is possible to have two overflows of the same
10474   // type (signed/unsigned) without having the other type in between: this
10475   // can happen when the vertex of the parabola is between the iterations
10476   // corresponding to the overflows. This is only possible when the two
10477   // overflows cross k*2^W for the same k. In such a case, if the second one
10478   // left the range (and was the first one to do so), the first overflow
10479   // would have to enter the range, which would mean that either we had left
10480   // the range before or that we started outside of it. Both of these cases
10481   // are contradictions.
10482   //
10483   // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10484   // solution is not some value between the Max for this boundary and the
10485   // Min of the other boundary.
10486   //
10487   // Justification: Assume that we had such Max_A and Min_B corresponding
10488   // to range boundaries A and B and such that Max_A < Min_B. If there was
10489   // a solution between Max_A and Min_B, it would have to be caused by an
10490   // overflow corresponding to either A or B. It cannot correspond to B,
10491   // since Min_B is the first occurrence of such an overflow. If it
10492   // corresponded to A, it would have to be either a signed or an unsigned
10493   // overflow that is larger than both eliminated overflows for A. But
10494   // between the eliminated overflows and this overflow, the values would
10495   // cover the entire value space, thus crossing the other boundary, which
10496   // is a contradiction.
10497 
10498   return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10499 }
10500 
10501 ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10502                                                          const Loop *L,
10503                                                          bool ControlsOnlyExit,
10504                                                          bool AllowPredicates) {
10505 
10506   // This is only used for loops with a "x != y" exit test. The exit condition
10507   // is now expressed as a single expression, V = x-y. So the exit test is
10508   // effectively V != 0. We know, and take advantage of, the fact that this
10509   // expression is only used in a comparison-with-zero context.
10510 
10511   SmallVector<const SCEVPredicate *> Predicates;
10512   // If the value is a constant
10513   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10514     // If the value is already zero, the branch will execute zero times.
10515     if (C->getValue()->isZero()) return C;
10516     return getCouldNotCompute();  // Otherwise it will loop infinitely.
10517   }
10518 
10519   const SCEVAddRecExpr *AddRec =
10520       dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10521 
10522   if (!AddRec && AllowPredicates)
10523     // Try to make this an AddRec using runtime tests, in the first X
10524     // iterations of this loop, where X is the SCEV expression found by the
10525     // algorithm below.
10526     AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10527 
10528   if (!AddRec || AddRec->getLoop() != L)
10529     return getCouldNotCompute();
10530 
10531   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10532   // the quadratic equation to solve it.
10533   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10534     // We can only use this value if the chrec ends up with an exact zero
10535     // value at this index. When solving for "X*X != 5", for example, we
10536     // should not accept a root of 2.
10537     if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10538       const auto *R = cast<SCEVConstant>(getConstant(*S));
10539       return ExitLimit(R, R, R, false, Predicates);
10540     }
10541     return getCouldNotCompute();
10542   }
10543 
10544   // Otherwise we can only handle this if it is affine.
10545   if (!AddRec->isAffine())
10546     return getCouldNotCompute();
10547 
10548   // If this is an affine expression, the execution count of this branch is
10549   // the minimum unsigned root of the following equation:
10550   //
10551   //     Start + Step*N = 0 (mod 2^BW)
10552   //
10553   // equivalent to:
10554   //
10555   //             Step*N = -Start (mod 2^BW)
10556   //
10557   // where BW is the common bit width of Start and Step.
10558 
10559   // Get the initial value for the loop.
10560   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10561   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10562 
10563   if (!isLoopInvariant(Step, L))
10564     return getCouldNotCompute();
10565 
10566   LoopGuards Guards = LoopGuards::collect(L, *this);
10567   // Specialize step for this loop so we get context sensitive facts below.
10568   const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10569 
10570   // For positive steps (counting up until unsigned overflow):
10571   //   N = -Start/Step (as unsigned)
10572   // For negative steps (counting down to zero):
10573   //   N = Start/-Step
10574   // First compute the unsigned distance from zero in the direction of Step.
10575   bool CountDown = isKnownNegative(StepWLG);
10576   if (!CountDown && !isKnownNonNegative(StepWLG))
10577     return getCouldNotCompute();
10578 
10579   const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10580   // Handle unitary steps, which cannot wrap around.
10581   //   1*N = -Start; -1*N = Start (mod 2^BW), so:
10582   //   N = Distance (as unsigned)
10583 
10584   if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10585     APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10586     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10587 
10588     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10589     // we end up with a loop whose backedge-taken count is n - 1.  Detect this
10590     // case, and see if we can improve the bound.
10591     //
10592     // Explicitly handling this here is necessary because getUnsignedRange
10593     // isn't context-sensitive; it doesn't know that we only care about the
10594     // range inside the loop.
10595     const SCEV *Zero = getZero(Distance->getType());
10596     const SCEV *One = getOne(Distance->getType());
10597     const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10598     if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10599       // If Distance + 1 doesn't overflow, we can compute the maximum distance
10600       // as "unsigned_max(Distance + 1) - 1".
10601       ConstantRange CR = getUnsignedRange(DistancePlusOne);
10602       MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10603     }
10604     return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10605                      Predicates);
10606   }
10607 
10608   // If the condition controls loop exit (the loop exits only if the expression
10609   // is true) and the addition is no-wrap we can use unsigned divide to
10610   // compute the backedge count.  In this case, the step may not divide the
10611   // distance, but we don't care because if the condition is "missed" the loop
10612   // will have undefined behavior due to wrapping.
10613   if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10614       loopHasNoAbnormalExits(AddRec->getLoop())) {
10615 
10616     // If the stride is zero and the start is non-zero, the loop must be
10617     // infinite. In C++, most loops are finite by assumption, in which case the
10618     // step being zero implies UB must execute if the loop is entered.
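    // For example, "for (unsigned i = 16; i != 0; i += Step)" with Step == 0
    // never takes this exit; if the loop is finite by assumption (e.g.
    // mustprogress), entering it with a non-zero start and a zero step would
    // be UB, which is what the check below exploits.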
10619 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) && 10620 !isKnownNonZero(StepWLG)) 10621 return getCouldNotCompute(); 10622 10623 const SCEV *Exact = 10624 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); 10625 const SCEV *ConstantMax = getCouldNotCompute(); 10626 if (Exact != getCouldNotCompute()) { 10627 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards)); 10628 ConstantMax = 10629 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact))); 10630 } 10631 const SCEV *SymbolicMax = 10632 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact; 10633 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates); 10634 } 10635 10636 // Solve the general equation. 10637 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step); 10638 if (!StepC || StepC->getValue()->isZero()) 10639 return getCouldNotCompute(); 10640 const SCEV *E = SolveLinEquationWithOverflow( 10641 StepC->getAPInt(), getNegativeSCEV(Start), 10642 AllowPredicates ? &Predicates : nullptr, *this); 10643 10644 const SCEV *M = E; 10645 if (E != getCouldNotCompute()) { 10646 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards)); 10647 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E))); 10648 } 10649 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E; 10650 return ExitLimit(E, M, S, false, Predicates); 10651 } 10652 10653 ScalarEvolution::ExitLimit 10654 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { 10655 // Loops that look like: while (X == 0) are very strange indeed. We don't 10656 // handle them yet except for the trivial case. This could be expanded in the 10657 // future as needed. 10658 10659 // If the value is a constant, check to see if it is known to be non-zero 10660 // already. If so, the backedge will execute zero times. 10661 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { 10662 if (!C->getValue()->isZero()) 10663 return getZero(C->getType()); 10664 return getCouldNotCompute(); // Otherwise it will loop infinitely. 10665 } 10666 10667 // We could implement others, but I really doubt anyone writes loops like 10668 // this, and if they did, they would already be constant folded. 10669 return getCouldNotCompute(); 10670 } 10671 10672 std::pair<const BasicBlock *, const BasicBlock *> 10673 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) 10674 const { 10675 // If the block has a unique predecessor, then there is no path from the 10676 // predecessor to the block that does not go through the direct edge 10677 // from the predecessor to the block. 10678 if (const BasicBlock *Pred = BB->getSinglePredecessor()) 10679 return {Pred, BB}; 10680 10681 // A loop's header is defined to be a block that dominates the loop. 10682 // If the header has a unique predecessor outside the loop, it must be 10683 // a block that has exactly one successor that can reach the loop. 10684 if (const Loop *L = LI.getLoopFor(BB)) 10685 return {L->getLoopPredecessor(), L->getHeader()}; 10686 10687 return {nullptr, BB}; 10688 } 10689 10690 /// SCEV structural equivalence is usually sufficient for testing whether two 10691 /// expressions are equal, however for the purposes of looking for a condition 10692 /// guarding a loop, it can be useful to be a little more general, since a 10693 /// front-end may have replicated the controlling expression. 10694 static bool HasSameValue(const SCEV *A, const SCEV *B) { 10695 // Quick check to see if they are the same SCEV. 
10696 if (A == B) return true; 10697 10698 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { 10699 // Not all instructions that are "identical" compute the same value. For 10700 // instance, two distinct alloca instructions allocating the same type are 10701 // identical and do not read memory; but compute distinct values. 10702 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A)); 10703 }; 10704 10705 // Otherwise, if they're both SCEVUnknown, it's possible that they hold 10706 // two different instructions with the same value. Check for this case. 10707 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A)) 10708 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B)) 10709 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue())) 10710 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue())) 10711 if (ComputesEqualValues(AI, BI)) 10712 return true; 10713 10714 // Otherwise assume they may have a different value. 10715 return false; 10716 } 10717 10718 static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) { 10719 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S); 10720 if (!Add || Add->getNumOperands() != 2) 10721 return false; 10722 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0)); 10723 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { 10724 LHS = Add->getOperand(1); 10725 RHS = ME->getOperand(1); 10726 return true; 10727 } 10728 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); 10729 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { 10730 LHS = Add->getOperand(0); 10731 RHS = ME->getOperand(1); 10732 return true; 10733 } 10734 return false; 10735 } 10736 10737 bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, 10738 const SCEV *&RHS, unsigned Depth) { 10739 bool Changed = false; 10740 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or 10741 // '0 != 0'. 10742 auto TrivialCase = [&](bool TriviallyTrue) { 10743 LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); 10744 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; 10745 return true; 10746 }; 10747 // If we hit the max recursion limit bail out. 10748 if (Depth >= 3) 10749 return false; 10750 10751 // Canonicalize a constant to the right side. 10752 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { 10753 // Check for both operands constant. 10754 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { 10755 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred)) 10756 return TrivialCase(false); 10757 return TrivialCase(true); 10758 } 10759 // Otherwise swap the operands to put the constant on the right. 10760 std::swap(LHS, RHS); 10761 Pred = ICmpInst::getSwappedCmpPredicate(Pred); 10762 Changed = true; 10763 } 10764 10765 // If we're comparing an addrec with a value which is loop-invariant in the 10766 // addrec's loop, put the addrec on the left. Also make a dominance check, 10767 // as both operands could be addrecs loop-invariant in each other's loop. 
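  // E.g., "%n u> {0,+,1}<L>" is canonicalized to "{0,+,1}<L> u< %n" when %n
  // is invariant in L and its definition properly dominates L's header.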
10768 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) { 10769 const Loop *L = AR->getLoop(); 10770 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) { 10771 std::swap(LHS, RHS); 10772 Pred = ICmpInst::getSwappedCmpPredicate(Pred); 10773 Changed = true; 10774 } 10775 } 10776 10777 // If there's a constant operand, canonicalize comparisons with boundary 10778 // cases, and canonicalize *-or-equal comparisons to regular comparisons. 10779 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) { 10780 const APInt &RA = RC->getAPInt(); 10781 10782 bool SimplifiedByConstantRange = false; 10783 10784 if (!ICmpInst::isEquality(Pred)) { 10785 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); 10786 if (ExactCR.isFullSet()) 10787 return TrivialCase(true); 10788 if (ExactCR.isEmptySet()) 10789 return TrivialCase(false); 10790 10791 APInt NewRHS; 10792 CmpInst::Predicate NewPred; 10793 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) && 10794 ICmpInst::isEquality(NewPred)) { 10795 // We were able to convert an inequality to an equality. 10796 Pred = NewPred; 10797 RHS = getConstant(NewRHS); 10798 Changed = SimplifiedByConstantRange = true; 10799 } 10800 } 10801 10802 if (!SimplifiedByConstantRange) { 10803 switch (Pred) { 10804 default: 10805 break; 10806 case ICmpInst::ICMP_EQ: 10807 case ICmpInst::ICMP_NE: 10808 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. 10809 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS)) 10810 Changed = true; 10811 break; 10812 10813 // The "Should have been caught earlier!" messages refer to the fact 10814 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above 10815 // should have fired on the corresponding cases, and canonicalized the 10816 // check to trivial case. 10817 10818 case ICmpInst::ICMP_UGE: 10819 assert(!RA.isMinValue() && "Should have been caught earlier!"); 10820 Pred = ICmpInst::ICMP_UGT; 10821 RHS = getConstant(RA - 1); 10822 Changed = true; 10823 break; 10824 case ICmpInst::ICMP_ULE: 10825 assert(!RA.isMaxValue() && "Should have been caught earlier!"); 10826 Pred = ICmpInst::ICMP_ULT; 10827 RHS = getConstant(RA + 1); 10828 Changed = true; 10829 break; 10830 case ICmpInst::ICMP_SGE: 10831 assert(!RA.isMinSignedValue() && "Should have been caught earlier!"); 10832 Pred = ICmpInst::ICMP_SGT; 10833 RHS = getConstant(RA - 1); 10834 Changed = true; 10835 break; 10836 case ICmpInst::ICMP_SLE: 10837 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!"); 10838 Pred = ICmpInst::ICMP_SLT; 10839 RHS = getConstant(RA + 1); 10840 Changed = true; 10841 break; 10842 } 10843 } 10844 } 10845 10846 // Check for obvious equality. 10847 if (HasSameValue(LHS, RHS)) { 10848 if (ICmpInst::isTrueWhenEqual(Pred)) 10849 return TrivialCase(true); 10850 if (ICmpInst::isFalseWhenEqual(Pred)) 10851 return TrivialCase(false); 10852 } 10853 10854 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by 10855 // adding or subtracting 1 from one of the operands. 
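  // E.g., "X s<= Y" becomes "X s< Y+1" when Y is known not to be SINT_MAX
  // (so the increment cannot sign-wrap); failing that, it becomes "X-1 s< Y"
  // when X is known not to be SINT_MIN.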
10856 switch (Pred) { 10857 case ICmpInst::ICMP_SLE: 10858 if (!getSignedRangeMax(RHS).isMaxSignedValue()) { 10859 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, 10860 SCEV::FlagNSW); 10861 Pred = ICmpInst::ICMP_SLT; 10862 Changed = true; 10863 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) { 10864 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, 10865 SCEV::FlagNSW); 10866 Pred = ICmpInst::ICMP_SLT; 10867 Changed = true; 10868 } 10869 break; 10870 case ICmpInst::ICMP_SGE: 10871 if (!getSignedRangeMin(RHS).isMinSignedValue()) { 10872 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, 10873 SCEV::FlagNSW); 10874 Pred = ICmpInst::ICMP_SGT; 10875 Changed = true; 10876 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) { 10877 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, 10878 SCEV::FlagNSW); 10879 Pred = ICmpInst::ICMP_SGT; 10880 Changed = true; 10881 } 10882 break; 10883 case ICmpInst::ICMP_ULE: 10884 if (!getUnsignedRangeMax(RHS).isMaxValue()) { 10885 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, 10886 SCEV::FlagNUW); 10887 Pred = ICmpInst::ICMP_ULT; 10888 Changed = true; 10889 } else if (!getUnsignedRangeMin(LHS).isMinValue()) { 10890 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); 10891 Pred = ICmpInst::ICMP_ULT; 10892 Changed = true; 10893 } 10894 break; 10895 case ICmpInst::ICMP_UGE: 10896 // If RHS is an op we can fold the -1, try that first. 10897 // Otherwise prefer LHS to preserve the nuw flag. 10898 if ((isa<SCEVConstant>(RHS) || 10899 (isa<SCEVAddExpr, SCEVAddRecExpr>(RHS) && 10900 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) && 10901 !getUnsignedRangeMin(RHS).isMinValue()) { 10902 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); 10903 Pred = ICmpInst::ICMP_UGT; 10904 Changed = true; 10905 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) { 10906 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, 10907 SCEV::FlagNUW); 10908 Pred = ICmpInst::ICMP_UGT; 10909 Changed = true; 10910 } else if (!getUnsignedRangeMin(RHS).isMinValue()) { 10911 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); 10912 Pred = ICmpInst::ICMP_UGT; 10913 Changed = true; 10914 } 10915 break; 10916 default: 10917 break; 10918 } 10919 10920 // TODO: More simplifications are possible here. 10921 10922 // Recursively simplify until we either hit a recursion limit or nothing 10923 // changes. 10924 if (Changed) 10925 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1); 10926 10927 return Changed; 10928 } 10929 10930 bool ScalarEvolution::isKnownNegative(const SCEV *S) { 10931 return getSignedRangeMax(S).isNegative(); 10932 } 10933 10934 bool ScalarEvolution::isKnownPositive(const SCEV *S) { 10935 return getSignedRangeMin(S).isStrictlyPositive(); 10936 } 10937 10938 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { 10939 return !getSignedRangeMin(S).isNegative(); 10940 } 10941 10942 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { 10943 return !getSignedRangeMax(S).isStrictlyPositive(); 10944 } 10945 10946 bool ScalarEvolution::isKnownNonZero(const SCEV *S) { 10947 // Query push down for cases where the unsigned range is 10948 // less than sufficient. 
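  // E.g., sign-extending an i8 value that is known non-zero can produce a
  // wrapped i64 range whose unsigned minimum degrades to 0; recursing into
  // the narrower operand recovers the non-zero fact.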
10949 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S)) 10950 return isKnownNonZero(SExt->getOperand(0)); 10951 return getUnsignedRangeMin(S) != 0; 10952 } 10953 10954 bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero, 10955 bool OrNegative) { 10956 auto NonRecursive = [this, OrNegative](const SCEV *S) { 10957 if (auto *C = dyn_cast<SCEVConstant>(S)) 10958 return C->getAPInt().isPowerOf2() || 10959 (OrNegative && C->getAPInt().isNegatedPowerOf2()); 10960 10961 // The vscale_range indicates vscale is a power-of-two. 10962 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange); 10963 }; 10964 10965 if (NonRecursive(S)) 10966 return true; 10967 10968 auto *Mul = dyn_cast<SCEVMulExpr>(S); 10969 if (!Mul) 10970 return false; 10971 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S)); 10972 } 10973 10974 bool ScalarEvolution::isKnownMultipleOf( 10975 const SCEV *S, uint64_t M, 10976 SmallVectorImpl<const SCEVPredicate *> &Assumptions) { 10977 if (M == 0) 10978 return false; 10979 if (M == 1) 10980 return true; 10981 10982 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S 10983 // starts with a multiple of M and at every iteration step S only adds 10984 // multiples of M. 10985 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 10986 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) && 10987 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions); 10988 10989 // For a constant, check that "S % M == 0". 10990 if (auto *Cst = dyn_cast<SCEVConstant>(S)) { 10991 APInt C = Cst->getAPInt(); 10992 return C.urem(M) == 0; 10993 } 10994 10995 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc. 10996 10997 // Basic tests have failed. 10998 // Check "S % M == 0" at compile time and record runtime Assumptions. 10999 auto *STy = dyn_cast<IntegerType>(S->getType()); 11000 const SCEV *SmodM = 11001 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false))); 11002 const SCEV *Zero = getZero(STy); 11003 11004 // Check whether "S % M == 0" is known at compile time. 11005 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero)) 11006 return true; 11007 11008 // Check whether "S % M != 0" is known at compile time. 11009 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero)) 11010 return false; 11011 11012 const SCEVPredicate *P = getComparePredicate(ICmpInst::ICMP_EQ, SmodM, Zero); 11013 11014 // Detect redundant predicates. 11015 for (auto *A : Assumptions) 11016 if (A->implies(P, *this)) 11017 return true; 11018 11019 // Only record non-redundant predicates. 11020 Assumptions.push_back(P); 11021 return true; 11022 } 11023 11024 std::pair<const SCEV *, const SCEV *> 11025 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) { 11026 // Compute SCEV on entry of loop L. 11027 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this); 11028 if (Start == getCouldNotCompute()) 11029 return { Start, Start }; 11030 // Compute post increment SCEV for loop L. 11031 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this); 11032 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute"); 11033 return { Start, PostInc }; 11034 } 11035 11036 bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, 11037 const SCEV *RHS) { 11038 // First collect all loops. 
11039   SmallPtrSet<const Loop *, 8> LoopsUsed;
11040   getUsedLoops(LHS, LoopsUsed);
11041   getUsedLoops(RHS, LoopsUsed);
11042 
11043   if (LoopsUsed.empty())
11044     return false;
11045 
11046   // Domination relationship must be a linear order on collected loops.
11047 #ifndef NDEBUG
11048   for (const auto *L1 : LoopsUsed)
11049     for (const auto *L2 : LoopsUsed)
11050       assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11051               DT.dominates(L2->getHeader(), L1->getHeader())) &&
11052              "Domination relationship is not a linear order");
11053 #endif
11054 
11055   const Loop *MDL =
11056       *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11057         return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11058       });
11059 
11060   // Get init and post-increment value for LHS.
11061   auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11062   // If LHS contains an unknown non-invariant SCEV then bail out.
11063   if (SplitLHS.first == getCouldNotCompute())
11064     return false;
11065   assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11066   // Get init and post-increment value for RHS.
11067   auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11068   // If RHS contains an unknown non-invariant SCEV then bail out.
11069   if (SplitRHS.first == getCouldNotCompute())
11070     return false;
11071   assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11072   // It is possible that the init SCEV contains an invariant load but it does
11073   // not dominate MDL and is not available at MDL loop entry, so we should
11074   // check it here.
11075   if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11076       !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11077     return false;
11078 
11079   // The backedge guard check tends to be cheaper than the entry check, so
11080   // checking it first can short-circuit and speed up the whole estimation.
11081   return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11082                                      SplitRHS.second) &&
11083          isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11084 }
11085 
11086 bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11087                                        const SCEV *RHS) {
11088   // Canonicalize the inputs first.
11089   (void)SimplifyICmpOperands(Pred, LHS, RHS);
11090 
11091   if (isKnownViaInduction(Pred, LHS, RHS))
11092     return true;
11093 
11094   if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11095     return true;
11096 
11097   // Otherwise see what can be done with some simple reasoning.
11098   return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11099 }
11100 
11101 std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11102                                                        const SCEV *LHS,
11103                                                        const SCEV *RHS) {
11104   if (isKnownPredicate(Pred, LHS, RHS))
11105     return true;
11106   if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11107     return false;
11108   return std::nullopt;
11109 }
11110 
11111 bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11112                                          const SCEV *RHS,
11113                                          const Instruction *CtxI) {
11114   // TODO: Analyze guards and assumes from Context's block.
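  // E.g., for a context instruction in the "then" block of "if (%a u< %n)",
  // the query (ICMP_ULT, %a, %n) can be answered via the dominating branch
  // even when constant ranges alone are inconclusive. A minimal usage sketch
  // (hypothetical caller, for illustration only):
  //   if (SE.isKnownPredicateAt(ICmpInst::ICMP_ULT, SE.getSCEV(A),
  //                             SE.getSCEV(B), &I))
  //     ...; // %a u< %n is known to hold at instruction I.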
11115 return isKnownPredicate(Pred, LHS, RHS) || 11116 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS); 11117 } 11118 11119 std::optional<bool> 11120 ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, 11121 const SCEV *RHS, const Instruction *CtxI) { 11122 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS); 11123 if (KnownWithoutContext) 11124 return KnownWithoutContext; 11125 11126 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS)) 11127 return true; 11128 if (isBasicBlockEntryGuardedByCond( 11129 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS)) 11130 return false; 11131 return std::nullopt; 11132 } 11133 11134 bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred, 11135 const SCEVAddRecExpr *LHS, 11136 const SCEV *RHS) { 11137 const Loop *L = LHS->getLoop(); 11138 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) && 11139 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS); 11140 } 11141 11142 std::optional<ScalarEvolution::MonotonicPredicateType> 11143 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS, 11144 ICmpInst::Predicate Pred) { 11145 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred); 11146 11147 #ifndef NDEBUG 11148 // Verify an invariant: inverting the predicate should turn a monotonically 11149 // increasing change to a monotonically decreasing one, and vice versa. 11150 if (Result) { 11151 auto ResultSwapped = 11152 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred)); 11153 11154 assert(*ResultSwapped != *Result && 11155 "monotonicity should flip as we flip the predicate"); 11156 } 11157 #endif 11158 11159 return Result; 11160 } 11161 11162 std::optional<ScalarEvolution::MonotonicPredicateType> 11163 ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS, 11164 ICmpInst::Predicate Pred) { 11165 // A zero step value for LHS means the induction variable is essentially a 11166 // loop invariant value. We don't really depend on the predicate actually 11167 // flipping from false to true (for increasing predicates, and the other way 11168 // around for decreasing predicates), all we care about is that *if* the 11169 // predicate changes then it only changes from false to true. 11170 // 11171 // A zero step value in itself is not very useful, but there may be places 11172 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be 11173 // as general as possible. 11174 11175 // Only handle LE/LT/GE/GT predicates. 11176 if (!ICmpInst::isRelational(Pred)) 11177 return std::nullopt; 11178 11179 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred); 11180 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) && 11181 "Should be greater or less!"); 11182 11183 // Check that AR does not wrap. 11184 if (ICmpInst::isUnsigned(Pred)) { 11185 if (!LHS->hasNoUnsignedWrap()) 11186 return std::nullopt; 11187 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; 11188 } 11189 assert(ICmpInst::isSigned(Pred) && 11190 "Relational predicate is either signed or unsigned!"); 11191 if (!LHS->hasNoSignedWrap()) 11192 return std::nullopt; 11193 11194 const SCEV *Step = LHS->getStepRecurrence(*this); 11195 11196 if (isKnownNonNegative(Step)) 11197 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; 11198 11199 if (isKnownNonPositive(Step)) 11200 return !IsGreater ? 
MonotonicallyIncreasing : MonotonicallyDecreasing; 11201 11202 return std::nullopt; 11203 } 11204 11205 std::optional<ScalarEvolution::LoopInvariantPredicate> 11206 ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, 11207 const SCEV *RHS, const Loop *L, 11208 const Instruction *CtxI) { 11209 // If there is a loop-invariant, force it into the RHS, otherwise bail out. 11210 if (!isLoopInvariant(RHS, L)) { 11211 if (!isLoopInvariant(LHS, L)) 11212 return std::nullopt; 11213 11214 std::swap(LHS, RHS); 11215 Pred = ICmpInst::getSwappedCmpPredicate(Pred); 11216 } 11217 11218 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS); 11219 if (!ArLHS || ArLHS->getLoop() != L) 11220 return std::nullopt; 11221 11222 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred); 11223 if (!MonotonicType) 11224 return std::nullopt; 11225 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to 11226 // true as the loop iterates, and the backedge is control dependent on 11227 // "ArLHS `Pred` RHS" == true then we can reason as follows: 11228 // 11229 // * if the predicate was false in the first iteration then the predicate 11230 // is never evaluated again, since the loop exits without taking the 11231 // backedge. 11232 // * if the predicate was true in the first iteration then it will 11233 // continue to be true for all future iterations since it is 11234 // monotonically increasing. 11235 // 11236 // For both the above possibilities, we can replace the loop varying 11237 // predicate with its value on the first iteration of the loop (which is 11238 // loop invariant). 11239 // 11240 // A similar reasoning applies for a monotonically decreasing predicate, by 11241 // replacing true with false and false with true in the above two bullets. 11242 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing; 11243 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred); 11244 11245 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) 11246 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), 11247 RHS); 11248 11249 if (!CtxI) 11250 return std::nullopt; 11251 // Try to prove via context. 11252 // TODO: Support other cases. 11253 switch (Pred) { 11254 default: 11255 break; 11256 case ICmpInst::ICMP_ULE: 11257 case ICmpInst::ICMP_ULT: { 11258 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!"); 11259 // Given preconditions 11260 // (1) ArLHS does not cross the border of positive and negative parts of 11261 // range because of: 11262 // - Positive step; (TODO: lift this limitation) 11263 // - nuw - does not cross zero boundary; 11264 // - nsw - does not cross SINT_MAX boundary; 11265 // (2) ArLHS <s RHS 11266 // (3) RHS >=s 0 11267 // we can replace the loop variant ArLHS <u RHS condition with loop 11268 // invariant Start(ArLHS) <u RHS. 11269 // 11270 // Because of (1) there are two options: 11271 // - ArLHS is always negative. It means that ArLHS <u RHS is always false; 11272 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative. 11273 // It means that ArLHS <s RHS <=> ArLHS <u RHS. 11274 // Because of (2) ArLHS <u RHS is trivially true. 11275 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0. 11276 // We can strengthen this to Start(ArLHS) <u RHS. 
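    // Concretely: for ArLHS = {0,+,1}<nuw><nsw> and a context where
    // "ArLHS s< %n" holds with %n >=s 0, the loop-variant "ArLHS u< %n"
    // may be replaced by the invariant "0 u< %n".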
11277     auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11278     if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11279         isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11280         isKnownNonNegative(RHS) &&
11281         isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11282       return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11283                                                      RHS);
11284   }
11285   }
11286 
11287   return std::nullopt;
11288 }
11289 
11290 std::optional<ScalarEvolution::LoopInvariantPredicate>
11291 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11292     CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11293     const Instruction *CtxI, const SCEV *MaxIter) {
11294   if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11295           Pred, LHS, RHS, L, CtxI, MaxIter))
11296     return LIP;
11297   if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11298     // Number of iterations expressed as UMIN isn't always great for expressing
11299     // the value on the last iteration. If the straightforward approach didn't
11300     // work, try the following trick: if a predicate is invariant for X, it
11301     // is also invariant for umin(X, ...). So try to find something that works
11302     // among subexpressions of MaxIter expressed as umin.
11303     for (auto *Op : UMin->operands())
11304       if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11305               Pred, LHS, RHS, L, CtxI, Op))
11306         return LIP;
11307   return std::nullopt;
11308 }
11309 
11310 std::optional<ScalarEvolution::LoopInvariantPredicate>
11311 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11312     CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11313     const Instruction *CtxI, const SCEV *MaxIter) {
11314   // Try to prove the following set of facts:
11315   // - The predicate is monotonic in the iteration space.
11316   // - If the check does not fail on the 1st iteration:
11317   //   - No overflow will happen during first MaxIter iterations;
11318   //   - It will not fail on the MaxIter'th iteration.
11319   // If the check does fail on the 1st iteration, we leave the loop and no
11320   // other checks matter.
11321 
11322   // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11323   if (!isLoopInvariant(RHS, L)) {
11324     if (!isLoopInvariant(LHS, L))
11325       return std::nullopt;
11326 
11327     std::swap(LHS, RHS);
11328     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11329   }
11330 
11331   auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11332   if (!AR || AR->getLoop() != L)
11333     return std::nullopt;
11334 
11335   // The predicate must be relational (i.e. <, <=, >=, >).
11336   if (!ICmpInst::isRelational(Pred))
11337     return std::nullopt;
11338 
11339   // TODO: Support steps other than +/- 1.
11340   const SCEV *Step = AR->getStepRecurrence(*this);
11341   auto *One = getOne(Step->getType());
11342   auto *MinusOne = getNegativeSCEV(One);
11343   if (Step != One && Step != MinusOne)
11344     return std::nullopt;
11345 
11346   // Type mismatch here means that MaxIter is potentially larger than the max
11347   // unsigned value of the start type, which means we cannot prove no-wrap
11348   // for the indvar.
11349   if (AR->getType() != MaxIter->getType())
11350     return std::nullopt;
11351 
11352   // Value of IV on suggested last iteration.
11353   const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11354   // Does it still meet the requirement?
11355   if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11356     return std::nullopt;
11357   // Because step is +/- 1 and MaxIter has same type as Start (i.e.
it does 11358 // not exceed max unsigned value of this type), this effectively proves 11359 // that there is no wrap during the iteration. To prove that there is no 11360 // signed/unsigned wrap, we need to check that 11361 // Start <= Last for step = 1 or Start >= Last for step = -1. 11362 ICmpInst::Predicate NoOverflowPred = 11363 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; 11364 if (Step == MinusOne) 11365 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred); 11366 const SCEV *Start = AR->getStart(); 11367 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI)) 11368 return std::nullopt; 11369 11370 // Everything is fine. 11371 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS); 11372 } 11373 11374 bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred, 11375 const SCEV *LHS, 11376 const SCEV *RHS) { 11377 if (HasSameValue(LHS, RHS)) 11378 return ICmpInst::isTrueWhenEqual(Pred); 11379 11380 auto CheckRange = [&](bool IsSigned) { 11381 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS); 11382 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS); 11383 return RangeLHS.icmp(Pred, RangeRHS); 11384 }; 11385 11386 // The check at the top of the function catches the case where the values are 11387 // known to be equal. 11388 if (Pred == CmpInst::ICMP_EQ) 11389 return false; 11390 11391 if (Pred == CmpInst::ICMP_NE) { 11392 if (CheckRange(true) || CheckRange(false)) 11393 return true; 11394 auto *Diff = getMinusSCEV(LHS, RHS); 11395 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff); 11396 } 11397 11398 return CheckRange(CmpInst::isSigned(Pred)); 11399 } 11400 11401 bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred, 11402 const SCEV *LHS, 11403 const SCEV *RHS) { 11404 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where 11405 // C1 and C2 are constant integers. If either X or Y are not add expressions, 11406 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via 11407 // OutC1 and OutC2. 11408 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y, 11409 APInt &OutC1, APInt &OutC2, 11410 SCEV::NoWrapFlags ExpectedFlags) { 11411 const SCEV *XNonConstOp, *XConstOp; 11412 const SCEV *YNonConstOp, *YConstOp; 11413 SCEV::NoWrapFlags XFlagsPresent; 11414 SCEV::NoWrapFlags YFlagsPresent; 11415 11416 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) { 11417 XConstOp = getZero(X->getType()); 11418 XNonConstOp = X; 11419 XFlagsPresent = ExpectedFlags; 11420 } 11421 if (!isa<SCEVConstant>(XConstOp) || 11422 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) 11423 return false; 11424 11425 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) { 11426 YConstOp = getZero(Y->getType()); 11427 YNonConstOp = Y; 11428 YFlagsPresent = ExpectedFlags; 11429 } 11430 11431 if (!isa<SCEVConstant>(YConstOp) || 11432 (YFlagsPresent & ExpectedFlags) != ExpectedFlags) 11433 return false; 11434 11435 if (YNonConstOp != XNonConstOp) 11436 return false; 11437 11438 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt(); 11439 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt(); 11440 11441 return true; 11442 }; 11443 11444 APInt C1; 11445 APInt C2; 11446 11447 switch (Pred) { 11448 default: 11449 break; 11450 11451 case ICmpInst::ICMP_SGE: 11452 std::swap(LHS, RHS); 11453 [[fallthrough]]; 11454 case ICmpInst::ICMP_SLE: 11455 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2. 
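    // E.g., (%a + 2)<nsw> s<= (%a + 5)<nsw> holds for every %a, since both
    // additions are nsw and 2 s<= 5.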
11456     if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11457       return true;
11458 
11459     break;
11460 
11461   case ICmpInst::ICMP_SGT:
11462     std::swap(LHS, RHS);
11463     [[fallthrough]];
11464   case ICmpInst::ICMP_SLT:
11465     // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11466     if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11467       return true;
11468 
11469     break;
11470 
11471   case ICmpInst::ICMP_UGE:
11472     std::swap(LHS, RHS);
11473     [[fallthrough]];
11474   case ICmpInst::ICMP_ULE:
11475     // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11476     if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11477       return true;
11478 
11479     break;
11480 
11481   case ICmpInst::ICMP_UGT:
11482     std::swap(LHS, RHS);
11483     [[fallthrough]];
11484   case ICmpInst::ICMP_ULT:
11485     // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11486     if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11487       return true;
11488     break;
11489   }
11490 
11491   return false;
11492 }
11493 
11494 bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11495                                                    const SCEV *LHS,
11496                                                    const SCEV *RHS) {
11497   if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11498     return false;
11499 
11500   // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11501   // on the stack can result in exponential time complexity.
11502   SaveAndRestore Restore(ProvingSplitPredicate, true);
11503 
11504   // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11505   //
11506   // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11507   // isKnownPredicate. isKnownPredicate is more powerful, but also more
11508   // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11509   // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11510   // use isKnownPredicate later if needed.
11511   return isKnownNonNegative(RHS) &&
11512          isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11513          isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11514 }
11515 
11516 bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11517                                         const SCEV *LHS, const SCEV *RHS) {
11518   // No need to even try if we know the module has no guards.
11519   if (!HasGuards)
11520     return false;
11521 
11522   return any_of(*BB, [&](const Instruction &I) {
11523     using namespace llvm::PatternMatch;
11524 
11525     Value *Condition;
11526     return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11527                          m_Value(Condition))) &&
11528            isImpliedCond(Pred, LHS, RHS, Condition, false);
11529   });
11530 }
11531 
11532 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11533 /// protected by a conditional between LHS and RHS. This is used to
11534 /// eliminate casts.
11535 bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11536                                                   CmpPredicate Pred,
11537                                                   const SCEV *LHS,
11538                                                   const SCEV *RHS) {
11539   // Interpret a null as meaning no loop, where there is obviously no guard
11540   // (interprocedural conditions notwithstanding). Do not bother about
11541   // unreachable loops.
11542   if (!L || !DT.isReachableFromEntry(L->getHeader()))
11543     return true;
11544 
11545   if (VerifyIR)
11546     assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11547            "This cannot be done on broken IR!");
11548 
11549 
11550   if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11551     return true;
11552 
11553   BasicBlock *Latch = L->getLoopLatch();
11554   if (!Latch)
11555     return false;
11556 
11557   BranchInst *LoopContinuePredicate =
11558       dyn_cast<BranchInst>(Latch->getTerminator());
11559   if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11560       isImpliedCond(Pred, LHS, RHS,
11561                     LoopContinuePredicate->getCondition(),
11562                     LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11563     return true;
11564 
11565   // We don't want more than one activation of the following loops on the stack
11566   // -- that can lead to O(n!) time complexity.
11567   if (WalkingBEDominatingConds)
11568     return false;
11569 
11570   SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11571 
11572   // See if we can exploit a trip count to prove the predicate.
11573   const auto &BETakenInfo = getBackedgeTakenInfo(L);
11574   const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11575   if (LatchBECount != getCouldNotCompute()) {
11576     // We know that Latch branches back to the loop header exactly
11577     // LatchBECount times. This means the backedge condition at Latch is
11578     // equivalent to "{0,+,1} u< LatchBECount".
11579     Type *Ty = LatchBECount->getType();
11580     auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11581     const SCEV *LoopCounter =
11582         getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11583     if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11584                       LatchBECount))
11585       return true;
11586   }
11587 
11588   // Check conditions due to any @llvm.assume intrinsics.
11589   for (auto &AssumeVH : AC.assumptions()) {
11590     if (!AssumeVH)
11591       continue;
11592     auto *CI = cast<CallInst>(AssumeVH);
11593     if (!DT.dominates(CI, Latch->getTerminator()))
11594       continue;
11595 
11596     if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11597       return true;
11598   }
11599 
11600   if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11601     return true;
11602 
11603   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11604        DTN != HeaderDTN; DTN = DTN->getIDom()) {
11605     assert(DTN && "should reach the loop header before reaching the root!");
11606 
11607     BasicBlock *BB = DTN->getBlock();
11608     if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11609       return true;
11610 
11611     BasicBlock *PBB = BB->getSinglePredecessor();
11612     if (!PBB)
11613       continue;
11614 
11615     BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11616     if (!ContinuePredicate || !ContinuePredicate->isConditional())
11617       continue;
11618 
11619     Value *Condition = ContinuePredicate->getCondition();
11620 
11621     // If we have an edge `E` within the loop body that dominates the only
11622     // latch, the condition guarding `E` also guards the backedge.  This
11623     // reasoning works only for loops with a single latch.
11624 
11625     BasicBlockEdge DominatingEdge(PBB, BB);
11626     if (DominatingEdge.isSingleEdge()) {
11627       // We're constructively (and conservatively) enumerating edges within the
11628       // loop body that dominate the latch.
The dominator tree better agree 11629 // with us on this: 11630 assert(DT.dominates(DominatingEdge, Latch) && "should be!"); 11631 11632 if (isImpliedCond(Pred, LHS, RHS, Condition, 11633 BB != ContinuePredicate->getSuccessor(0))) 11634 return true; 11635 } 11636 } 11637 11638 return false; 11639 } 11640 11641 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, 11642 CmpPredicate Pred, 11643 const SCEV *LHS, 11644 const SCEV *RHS) { 11645 // Do not bother proving facts for unreachable code. 11646 if (!DT.isReachableFromEntry(BB)) 11647 return true; 11648 if (VerifyIR) 11649 assert(!verifyFunction(*BB->getParent(), &dbgs()) && 11650 "This cannot be done on broken IR!"); 11651 11652 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove 11653 // the facts (a >= b && a != b) separately. A typical situation is when the 11654 // non-strict comparison is known from ranges and non-equality is known from 11655 // dominating predicates. If we are proving strict comparison, we always try 11656 // to prove non-equality and non-strict comparison separately. 11657 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred); 11658 const bool ProvingStrictComparison = 11659 Pred != NonStrictPredicate.dropSameSign(); 11660 bool ProvedNonStrictComparison = false; 11661 bool ProvedNonEquality = false; 11662 11663 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool { 11664 if (!ProvedNonStrictComparison) 11665 ProvedNonStrictComparison = Fn(NonStrictPredicate); 11666 if (!ProvedNonEquality) 11667 ProvedNonEquality = Fn(ICmpInst::ICMP_NE); 11668 if (ProvedNonStrictComparison && ProvedNonEquality) 11669 return true; 11670 return false; 11671 }; 11672 11673 if (ProvingStrictComparison) { 11674 auto ProofFn = [&](CmpPredicate P) { 11675 return isKnownViaNonRecursiveReasoning(P, LHS, RHS); 11676 }; 11677 if (SplitAndProve(ProofFn)) 11678 return true; 11679 } 11680 11681 // Try to prove (Pred, LHS, RHS) using isImpliedCond. 11682 auto ProveViaCond = [&](const Value *Condition, bool Inverse) { 11683 const Instruction *CtxI = &BB->front(); 11684 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI)) 11685 return true; 11686 if (ProvingStrictComparison) { 11687 auto ProofFn = [&](CmpPredicate P) { 11688 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI); 11689 }; 11690 if (SplitAndProve(ProofFn)) 11691 return true; 11692 } 11693 return false; 11694 }; 11695 11696 // Starting at the block's predecessor, climb up the predecessor chain, as long 11697 // as there are predecessors that can be found that have unique successors 11698 // leading to the original block. 11699 const Loop *ContainingLoop = LI.getLoopFor(BB); 11700 const BasicBlock *PredBB; 11701 if (ContainingLoop && ContainingLoop->getHeader() == BB) 11702 PredBB = ContainingLoop->getLoopPredecessor(); 11703 else 11704 PredBB = BB->getSinglePredecessor(); 11705 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB); 11706 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { 11707 const BranchInst *BlockEntryPredicate = 11708 dyn_cast<BranchInst>(Pair.first->getTerminator()); 11709 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional()) 11710 continue; 11711 11712 if (ProveViaCond(BlockEntryPredicate->getCondition(), 11713 BlockEntryPredicate->getSuccessor(0) != Pair.second)) 11714 return true; 11715 } 11716 11717 // Check conditions due to any @llvm.assume intrinsics. 
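  // E.g., a dominating "call void @llvm.assume(i1 %c)" where
  // "%c = icmp ult i64 %i, %n" makes the fact "%i u< %n" usable at the
  // entry of BB.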
11718   for (auto &AssumeVH : AC.assumptions()) {
11719     if (!AssumeVH)
11720       continue;
11721     auto *CI = cast<CallInst>(AssumeVH);
11722     if (!DT.dominates(CI, BB))
11723       continue;
11724 
11725     if (ProveViaCond(CI->getArgOperand(0), false))
11726       return true;
11727   }
11728 
11729   // Check conditions due to any @llvm.experimental.guard intrinsics.
11730   auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11731       F.getParent(), Intrinsic::experimental_guard);
11732   if (GuardDecl)
11733     for (const auto *GU : GuardDecl->users())
11734       if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11735         if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11736           if (ProveViaCond(Guard->getArgOperand(0), false))
11737             return true;
11738   return false;
11739 }
11740 
11741 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11742                                                const SCEV *LHS,
11743                                                const SCEV *RHS) {
11744   // Interpret a null as meaning no loop, where there is obviously no guard
11745   // (interprocedural conditions notwithstanding).
11746   if (!L)
11747     return false;
11748 
11749   // Both LHS and RHS must be available at loop entry.
11750   assert(isAvailableAtLoopEntry(LHS, L) &&
11751          "LHS is not available at Loop Entry");
11752   assert(isAvailableAtLoopEntry(RHS, L) &&
11753          "RHS is not available at Loop Entry");
11754 
11755   if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11756     return true;
11757 
11758   return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11759 }
11760 
11761 bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11762                                     const SCEV *RHS,
11763                                     const Value *FoundCondValue, bool Inverse,
11764                                     const Instruction *CtxI) {
11765   // A false condition implies anything. Do not bother analyzing it further.
11766   if (FoundCondValue ==
11767       ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11768     return true;
11769 
11770   if (!PendingLoopPredicates.insert(FoundCondValue).second)
11771     return false;
11772 
11773   auto ClearOnExit =
11774       make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11775 
11776   // Recursively handle And and Or conditions.
11777   const Value *Op0, *Op1;
11778   if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11779     if (!Inverse)
11780       return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11781              isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11782   } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11783     if (Inverse)
11784       return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11785              isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11786   }
11787 
11788   const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11789   if (!ICI) return false;
11790 
11791   // Now that we have found a conditional branch that dominates the loop or
11792   // controls the loop latch, check to see if it is the comparison we are
11793   // looking for.
11794   CmpPredicate FoundPred;
11795   if (Inverse)
11796     FoundPred = ICI->getInverseCmpPredicate();
11797   else
11798     FoundPred = ICI->getCmpPredicate();
11799 
11800   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11801   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11802 
11803   return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11804 }
11805 
11806 bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11807                                     const SCEV *RHS, CmpPredicate FoundPred,
11808                                     const SCEV *FoundLHS, const SCEV *FoundRHS,
11809                                     const Instruction *CtxI) {
  // Balance the types.
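  // E.g., if LHS/RHS are i32 while FoundLHS/FoundRHS are i64, we either
  // truncate the found operands (when they provably fit in 32 bits) or
  // extend LHS/RHS, so that the implication is checked at a single width.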
11810 if (getTypeSizeInBits(LHS->getType()) < 11811 getTypeSizeInBits(FoundLHS->getType())) { 11812 // For unsigned and equality predicates, try to prove that both found 11813 // operands fit into narrow unsigned range. If so, try to prove facts in 11814 // narrow types. 11815 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() && 11816 !FoundRHS->getType()->isPointerTy()) { 11817 auto *NarrowType = LHS->getType(); 11818 auto *WideType = FoundLHS->getType(); 11819 auto BitWidth = getTypeSizeInBits(NarrowType); 11820 const SCEV *MaxValue = getZeroExtendExpr( 11821 getConstant(APInt::getMaxValue(BitWidth)), WideType); 11822 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS, 11823 MaxValue) && 11824 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS, 11825 MaxValue)) { 11826 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType); 11827 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType); 11828 // We cannot preserve samesign after truncation. 11829 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(), 11830 TruncFoundLHS, TruncFoundRHS, CtxI)) 11831 return true; 11832 } 11833 } 11834 11835 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy()) 11836 return false; 11837 if (CmpInst::isSigned(Pred)) { 11838 LHS = getSignExtendExpr(LHS, FoundLHS->getType()); 11839 RHS = getSignExtendExpr(RHS, FoundLHS->getType()); 11840 } else { 11841 LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); 11842 RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); 11843 } 11844 } else if (getTypeSizeInBits(LHS->getType()) > 11845 getTypeSizeInBits(FoundLHS->getType())) { 11846 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy()) 11847 return false; 11848 if (CmpInst::isSigned(FoundPred)) { 11849 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); 11850 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType()); 11851 } else { 11852 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType()); 11853 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType()); 11854 } 11855 } 11856 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS, 11857 FoundRHS, CtxI); 11858 } 11859 11860 bool ScalarEvolution::isImpliedCondBalancedTypes( 11861 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred, 11862 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) { 11863 assert(getTypeSizeInBits(LHS->getType()) == 11864 getTypeSizeInBits(FoundLHS->getType()) && 11865 "Types should be balanced!"); 11866 // Canonicalize the query to match the way instcombine will have 11867 // canonicalized the comparison. 11868 if (SimplifyICmpOperands(Pred, LHS, RHS)) 11869 if (LHS == RHS) 11870 return CmpInst::isTrueWhenEqual(Pred); 11871 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS)) 11872 if (FoundLHS == FoundRHS) 11873 return CmpInst::isFalseWhenEqual(FoundPred); 11874 11875 // Check to see if we can make the LHS or RHS match. 11876 if (LHS == FoundRHS || RHS == FoundLHS) { 11877 if (isa<SCEVConstant>(RHS)) { 11878 std::swap(FoundLHS, FoundRHS); 11879 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred); 11880 } else { 11881 std::swap(LHS, RHS); 11882 Pred = ICmpInst::getSwappedCmpPredicate(Pred); 11883 } 11884 } 11885 11886 // Check whether the found predicate is the same as the desired predicate. 
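// E.g. proving "a u< b" from a dominating "a u< b" needs no predicate
// adjustment. getMatching may also fold in samesign information, e.g.
// answering a signed query from an equivalent samesign unsigned fact.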
11887 if (auto P = CmpPredicate::getMatching(FoundPred, Pred)) 11888 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI); 11889 11890 // Check whether swapping the found predicate makes it the same as the 11891 // desired predicate. 11892 if (auto P = CmpPredicate::getMatching( 11893 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) { 11894 // We can write the implication 11895 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS 11896 // using one of the following ways: 11897 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS 11898 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS 11899 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS 11900 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS 11901 // Forms 1. and 2. require swapping the operands of one condition. Don't 11902 // do this if it would break canonical constant/addrec ordering. 11903 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS)) 11904 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS, 11905 LHS, FoundLHS, FoundRHS, CtxI); 11906 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS)) 11907 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI); 11908 11909 // There's no clear preference between forms 3. and 4., try both. Avoid 11910 // forming getNotSCEV of pointer values as the resulting subtract is 11911 // not legal. 11912 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() && 11913 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), 11914 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS, 11915 FoundRHS, CtxI)) 11916 return true; 11917 11918 if (!FoundLHS->getType()->isPointerTy() && 11919 !FoundRHS->getType()->isPointerTy() && 11920 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS), 11921 getNotSCEV(FoundRHS), CtxI)) 11922 return true; 11923 11924 return false; 11925 } 11926 11927 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1, 11928 CmpInst::Predicate P2) { 11929 assert(P1 != P2 && "Handled earlier!"); 11930 return CmpInst::isRelational(P2) && 11931 P1 == ICmpInst::getFlippedSignednessPredicate(P2); 11932 }; 11933 if (IsSignFlippedPredicate(Pred, FoundPred)) { 11934 // Unsigned comparison is the same as signed comparison when both the 11935 // operands are non-negative or negative. 11936 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) || 11937 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS))) 11938 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI); 11939 // Create local copies that we can freely swap and canonicalize our 11940 // conditions to "le/lt". 11941 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred; 11942 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS, 11943 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS; 11944 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) { 11945 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred); 11946 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred); 11947 std::swap(CanonicalLHS, CanonicalRHS); 11948 std::swap(CanonicalFoundLHS, CanonicalFoundRHS); 11949 } 11950 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) && 11951 "Must be!"); 11952 assert((ICmpInst::isLT(CanonicalFoundPred) || 11953 ICmpInst::isLE(CanonicalFoundPred)) && 11954 "Must be!"); 11955 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS)) 11956 // Use implication: 11957 // x <u y && y >=s 0 --> x <s y. 
11958 // If we can prove the left part, the right part is also proven. 11959 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS, 11960 CanonicalRHS, CanonicalFoundLHS, 11961 CanonicalFoundRHS); 11962 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS)) 11963 // Use implication: 11964 // x <s y && y <s 0 --> x <u y. 11965 // If we can prove the left part, the right part is also proven. 11966 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS, 11967 CanonicalRHS, CanonicalFoundLHS, 11968 CanonicalFoundRHS); 11969 } 11970 11971 // Check if we can make progress by sharpening ranges. 11972 if (FoundPred == ICmpInst::ICMP_NE && 11973 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) { 11974 11975 const SCEVConstant *C = nullptr; 11976 const SCEV *V = nullptr; 11977 11978 if (isa<SCEVConstant>(FoundLHS)) { 11979 C = cast<SCEVConstant>(FoundLHS); 11980 V = FoundRHS; 11981 } else { 11982 C = cast<SCEVConstant>(FoundRHS); 11983 V = FoundLHS; 11984 } 11985 11986 // The guarding predicate tells us that C != V. If the known range 11987 // of V is [C, t), we can sharpen the range to [C + 1, t). The 11988 // range we consider has to correspond to same signedness as the 11989 // predicate we're interested in folding. 11990 11991 APInt Min = ICmpInst::isSigned(Pred) ? 11992 getSignedRangeMin(V) : getUnsignedRangeMin(V); 11993 11994 if (Min == C->getAPInt()) { 11995 // Given (V >= Min && V != Min) we conclude V >= (Min + 1). 11996 // This is true even if (Min + 1) wraps around -- in case of 11997 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). 11998 11999 APInt SharperMin = Min + 1; 12000 12001 switch (Pred) { 12002 case ICmpInst::ICMP_SGE: 12003 case ICmpInst::ICMP_UGE: 12004 // We know V `Pred` SharperMin. If this implies LHS `Pred` 12005 // RHS, we're done. 12006 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin), 12007 CtxI)) 12008 return true; 12009 [[fallthrough]]; 12010 12011 case ICmpInst::ICMP_SGT: 12012 case ICmpInst::ICMP_UGT: 12013 // We know from the range information that (V `Pred` Min || 12014 // V == Min). We know from the guarding condition that !(V 12015 // == Min). This gives us 12016 // 12017 // V `Pred` Min || V == Min && !(V == Min) 12018 // => V `Pred` Min 12019 // 12020 // If V `Pred` Min implies LHS `Pred` RHS, we're done. 12021 12022 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI)) 12023 return true; 12024 break; 12025 12026 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively. 12027 case ICmpInst::ICMP_SLE: 12028 case ICmpInst::ICMP_ULE: 12029 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS, 12030 LHS, V, getConstant(SharperMin), CtxI)) 12031 return true; 12032 [[fallthrough]]; 12033 12034 case ICmpInst::ICMP_SLT: 12035 case ICmpInst::ICMP_ULT: 12036 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS, 12037 LHS, V, getConstant(Min), CtxI)) 12038 return true; 12039 break; 12040 12041 default: 12042 // No change 12043 break; 12044 } 12045 } 12046 } 12047 12048 // Check whether the actual condition is beyond sufficient. 
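// E.g. a dominating "a == b" is stronger than needed to prove "a u>= b":
// any predicate that is true when its operands are equal follows from the
// equality, so checking the operands below suffices.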
12049 if (FoundPred == ICmpInst::ICMP_EQ) 12050 if (ICmpInst::isTrueWhenEqual(Pred)) 12051 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI)) 12052 return true; 12053 if (Pred == ICmpInst::ICMP_NE) 12054 if (!ICmpInst::isTrueWhenEqual(FoundPred)) 12055 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI)) 12056 return true; 12057 12058 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS)) 12059 return true; 12060 12061 // Otherwise assume the worst. 12062 return false; 12063 } 12064 12065 bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, 12066 const SCEV *&L, const SCEV *&R, 12067 SCEV::NoWrapFlags &Flags) { 12068 const auto *AE = dyn_cast<SCEVAddExpr>(Expr); 12069 if (!AE || AE->getNumOperands() != 2) 12070 return false; 12071 12072 L = AE->getOperand(0); 12073 R = AE->getOperand(1); 12074 Flags = AE->getNoWrapFlags(); 12075 return true; 12076 } 12077 12078 std::optional<APInt> 12079 ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { 12080 // We avoid subtracting expressions here because this function is usually 12081 // fairly deep in the call stack (i.e. is called many times). 12082 12083 unsigned BW = getTypeSizeInBits(More->getType()); 12084 APInt Diff(BW, 0); 12085 APInt DiffMul(BW, 1); 12086 // Try various simplifications to reduce the difference to a constant. Limit 12087 // the number of allowed simplifications to keep compile-time low. 12088 for (unsigned I = 0; I < 8; ++I) { 12089 if (More == Less) 12090 return Diff; 12091 12092 // Reduce addrecs with identical steps to their start value. 12093 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) { 12094 const auto *LAR = cast<SCEVAddRecExpr>(Less); 12095 const auto *MAR = cast<SCEVAddRecExpr>(More); 12096 12097 if (LAR->getLoop() != MAR->getLoop()) 12098 return std::nullopt; 12099 12100 // We look at affine expressions only; not for correctness but to keep 12101 // getStepRecurrence cheap. 12102 if (!LAR->isAffine() || !MAR->isAffine()) 12103 return std::nullopt; 12104 12105 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) 12106 return std::nullopt; 12107 12108 Less = LAR->getStart(); 12109 More = MAR->getStart(); 12110 continue; 12111 } 12112 12113 // Try to match a common constant multiply. 12114 auto MatchConstMul = 12115 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> { 12116 auto *M = dyn_cast<SCEVMulExpr>(S); 12117 if (!M || M->getNumOperands() != 2 || 12118 !isa<SCEVConstant>(M->getOperand(0))) 12119 return std::nullopt; 12120 return { 12121 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}}; 12122 }; 12123 if (auto MatchedMore = MatchConstMul(More)) { 12124 if (auto MatchedLess = MatchConstMul(Less)) { 12125 if (MatchedMore->second == MatchedLess->second) { 12126 More = MatchedMore->first; 12127 Less = MatchedLess->first; 12128 DiffMul *= MatchedMore->second; 12129 continue; 12130 } 12131 } 12132 } 12133 12134 // Try to cancel out common factors in two add expressions. 
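// Illustrative example (hypothetical SCEVs): More = (%a + %b + 16) and
// Less = (%a + %b + 6) decompose so that %a and %b cancel (multiplicity 0)
// while the constants accumulate into Diff, yielding 16 - 6 = 10.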
12135 SmallDenseMap<const SCEV *, int, 8> Multiplicity; 12136 auto Add = [&](const SCEV *S, int Mul) { 12137 if (auto *C = dyn_cast<SCEVConstant>(S)) { 12138 if (Mul == 1) { 12139 Diff += C->getAPInt() * DiffMul; 12140 } else { 12141 assert(Mul == -1); 12142 Diff -= C->getAPInt() * DiffMul; 12143 } 12144 } else 12145 Multiplicity[S] += Mul; 12146 }; 12147 auto Decompose = [&](const SCEV *S, int Mul) { 12148 if (isa<SCEVAddExpr>(S)) { 12149 for (const SCEV *Op : S->operands()) 12150 Add(Op, Mul); 12151 } else 12152 Add(S, Mul); 12153 }; 12154 Decompose(More, 1); 12155 Decompose(Less, -1); 12156 12157 // Check whether all the non-constants cancel out, or reduce to new 12158 // More/Less values. 12159 const SCEV *NewMore = nullptr, *NewLess = nullptr; 12160 for (const auto &[S, Mul] : Multiplicity) { 12161 if (Mul == 0) 12162 continue; 12163 if (Mul == 1) { 12164 if (NewMore) 12165 return std::nullopt; 12166 NewMore = S; 12167 } else if (Mul == -1) { 12168 if (NewLess) 12169 return std::nullopt; 12170 NewLess = S; 12171 } else 12172 return std::nullopt; 12173 } 12174 12175 // Values stayed the same, no point in trying further. 12176 if (NewMore == More || NewLess == Less) 12177 return std::nullopt; 12178 12179 More = NewMore; 12180 Less = NewLess; 12181 12182 // Reduced to constant. 12183 if (!More && !Less) 12184 return Diff; 12185 12186 // Left with variable on only one side, bail out. 12187 if (!More || !Less) 12188 return std::nullopt; 12189 } 12190 12191 // Did not reduce to constant. 12192 return std::nullopt; 12193 } 12194 12195 bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart( 12196 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, 12197 const SCEV *FoundRHS, const Instruction *CtxI) { 12198 // Try to recognize the following pattern: 12199 // 12200 // FoundRHS = ... 12201 // ... 12202 // loop: 12203 // FoundLHS = {Start,+,W} 12204 // context_bb: // Basic block from the same loop 12205 // known(Pred, FoundLHS, FoundRHS) 12206 // 12207 // If some predicate is known in the context of a loop, it is also known on 12208 // each iteration of this loop, including the first iteration. Therefore, in 12209 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to 12210 // prove the original pred using this fact. 12211 if (!CtxI) 12212 return false; 12213 const BasicBlock *ContextBB = CtxI->getParent(); 12214 // Make sure AR varies in the context block. 12215 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) { 12216 const Loop *L = AR->getLoop(); 12217 // Make sure that context belongs to the loop and executes on 1st iteration 12218 // (if it ever executes at all). 12219 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch())) 12220 return false; 12221 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop())) 12222 return false; 12223 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS); 12224 } 12225 12226 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) { 12227 const Loop *L = AR->getLoop(); 12228 // Make sure that context belongs to the loop and executes on 1st iteration 12229 // (if it ever executes at all). 
12230 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch())) 12231 return false; 12232 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop())) 12233 return false; 12234 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart()); 12235 } 12236 12237 return false; 12238 } 12239 12240 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred, 12241 const SCEV *LHS, 12242 const SCEV *RHS, 12243 const SCEV *FoundLHS, 12244 const SCEV *FoundRHS) { 12245 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) 12246 return false; 12247 12248 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS); 12249 if (!AddRecLHS) 12250 return false; 12251 12252 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS); 12253 if (!AddRecFoundLHS) 12254 return false; 12255 12256 // We'd like to let SCEV reason about control dependencies, so we constrain 12257 // both the inequalities to be about add recurrences on the same loop. This 12258 // way we can use isLoopEntryGuardedByCond later. 12259 12260 const Loop *L = AddRecFoundLHS->getLoop(); 12261 if (L != AddRecLHS->getLoop()) 12262 return false; 12263 12264 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1) 12265 // 12266 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C) 12267 // ... (2) 12268 // 12269 // Informal proof for (2), assuming (1) [*]: 12270 // 12271 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**] 12272 // 12273 // Then 12274 // 12275 // FoundLHS s< FoundRHS s< INT_MIN - C 12276 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ] 12277 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ] 12278 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s< 12279 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ] 12280 // <=> FoundLHS + C s< FoundRHS + C 12281 // 12282 // [*]: (1) can be proved by ruling out overflow. 12283 // 12284 // [**]: This can be proved by analyzing all the four possibilities: 12285 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and 12286 // (A s>= 0, B s>= 0). 12287 // 12288 // Note: 12289 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C" 12290 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS 12291 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS 12292 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is 12293 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + 12294 // C)". 12295 12296 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS); 12297 if (!LDiff) 12298 return false; 12299 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS); 12300 if (!RDiff || *LDiff != *RDiff) 12301 return false; 12302 12303 if (LDiff->isMinValue()) 12304 return true; 12305 12306 APInt FoundRHSLimit; 12307 12308 if (Pred == CmpInst::ICMP_ULT) { 12309 FoundRHSLimit = -(*RDiff); 12310 } else { 12311 assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); 12312 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff; 12313 } 12314 12315 // Try to prove (1) or (2), as needed. 
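// Illustration of (1) in i8 (hypothetical values): with C = 5, FoundRHSLimit
// is -5, i.e. u< 251, so proving FoundRHS u< 251 rules out wraparound in
// "FoundRHS + 5"; FoundLHS u< FoundRHS then gives
// (FoundLHS + 5) u< (FoundRHS + 5).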
12316 return isAvailableAtLoopEntry(FoundRHS, L) &&
12317 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12318 getConstant(FoundRHSLimit));
12319 }
12320
12321 bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12322 const SCEV *RHS, const SCEV *FoundLHS,
12323 const SCEV *FoundRHS, unsigned Depth) {
12324 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12325
12326 auto ClearOnExit = make_scope_exit([&]() {
12327 if (LPhi) {
12328 bool Erased = PendingMerges.erase(LPhi);
12329 assert(Erased && "Failed to erase LPhi!");
12330 (void)Erased;
12331 }
12332 if (RPhi) {
12333 bool Erased = PendingMerges.erase(RPhi);
12334 assert(Erased && "Failed to erase RPhi!");
12335 (void)Erased;
12336 }
12337 });
12338
12339 // Find the respective Phis and check that they are not already pending.
12340 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12341 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12342 if (!PendingMerges.insert(Phi).second)
12343 return false;
12344 LPhi = Phi;
12345 }
12346 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12347 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12348 // If we detect a loop of Phi nodes being processed by this method, for
12349 // example:
12350 //
12351 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12352 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12353 //
12354 // we don't want to deal with a case that complex, so we conservatively
12355 // return false.
12356 if (!PendingMerges.insert(Phi).second)
12357 return false;
12358 RPhi = Phi;
12359 }
12360
12361 // If neither LHS nor RHS is a Phi, there is nothing to do here.
12362 if (!LPhi && !RPhi)
12363 return false;
12364
12365 // If there is a SCEVUnknown Phi we are interested in, make it left.
12366 if (!LPhi) {
12367 std::swap(LHS, RHS);
12368 std::swap(FoundLHS, FoundRHS);
12369 std::swap(LPhi, RPhi);
12370 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12371 }
12372
12373 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12374 const BasicBlock *LBB = LPhi->getParent();
12375 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12376
12377 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12378 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12379 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12380 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12381 };
12382
12383 if (RPhi && RPhi->getParent() == LBB) {
12384 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12385 // If we compare two Phis from the same block, and for each entry block
12386 // the predicate is true for incoming values from this block, then the
12387 // predicate is also true for the Phis.
12388 for (const BasicBlock *IncBB : predecessors(LBB)) {
12389 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12390 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12391 if (!ProvedEasily(L, R))
12392 return false;
12393 }
12394 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12395 // Case two: RHS is also a Phi from the same basic block, and it is an
12396 // AddRec. That means there is a loop which has both an AddRec and an
12397 // Unknown PHI; for it, we can compare the AddRec's incoming values from
12398 // above the loop and from the latch with the respective incoming values of LPhi.
12399 // TODO: Generalize to handle loops with many inputs in a header.
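// E.g. (hypothetical IR): %lphi = phi i32 [ %init, %preheader ], [ %next, %latch ]
// compared against {Start,+,Step}: %init is checked against Start (the value
// on loop entry) and %next against the post-increment expression (the value
// after each latch traversal).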
12400 if (LPhi->getNumIncomingValues() != 2) return false;
12401
12402 auto *RLoop = RAR->getLoop();
12403 auto *Predecessor = RLoop->getLoopPredecessor();
12404 assert(Predecessor && "Loop with AddRec with no predecessor?");
12405 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12406 if (!ProvedEasily(L1, RAR->getStart()))
12407 return false;
12408 auto *Latch = RLoop->getLoopLatch();
12409 assert(Latch && "Loop with AddRec with no latch?");
12410 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12411 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12412 return false;
12413 } else {
12414 // In all other cases, go over the inputs of LHS and compare each of them
12415 // to RHS; the predicate is true for (LHS, RHS) if it is true for all such pairs.
12416 // At this point RHS is either a non-Phi, or it is a Phi from some block
12417 // different from LBB.
12418 for (const BasicBlock *IncBB : predecessors(LBB)) {
12419 // Check that RHS is available in this block.
12420 if (!dominates(RHS, IncBB))
12421 return false;
12422 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12423 // Make sure L does not refer to a value from a potentially previous
12424 // iteration of a loop.
12425 if (!properlyDominates(L, LBB))
12426 return false;
12427 // Addrecs are considered to properly dominate their loop, so are missed
12428 // by the previous check. Discard any values that have computable
12429 // evolution in this loop.
12430 if (auto *Loop = LI.getLoopFor(LBB))
12431 if (hasComputableLoopEvolution(L, Loop))
12432 return false;
12433 if (!ProvedEasily(L, RHS))
12434 return false;
12435 }
12436 }
12437 return true;
12438 }
12439
12440 bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12441 const SCEV *LHS,
12442 const SCEV *RHS,
12443 const SCEV *FoundLHS,
12444 const SCEV *FoundRHS) {
12445 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12446 // sure that we are dealing with the same LHS.
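// E.g. from a known "LHS u< (X >> 4)" together with "X u<= RHS" we may
// conclude "LHS u< RHS": a logical shift right never increases an unsigned
// value, so LHS u< (X >> 4) u<= X u<= RHS.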
12447 if (RHS == FoundRHS) { 12448 std::swap(LHS, RHS); 12449 std::swap(FoundLHS, FoundRHS); 12450 Pred = ICmpInst::getSwappedCmpPredicate(Pred); 12451 } 12452 if (LHS != FoundLHS) 12453 return false; 12454 12455 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS); 12456 if (!SUFoundRHS) 12457 return false; 12458 12459 Value *Shiftee, *ShiftValue; 12460 12461 using namespace PatternMatch; 12462 if (match(SUFoundRHS->getValue(), 12463 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) { 12464 auto *ShifteeS = getSCEV(Shiftee); 12465 // Prove one of the following: 12466 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS 12467 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS 12468 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0 12469 // ---> LHS <s RHS 12470 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0 12471 // ---> LHS <=s RHS 12472 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) 12473 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS); 12474 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) 12475 if (isKnownNonNegative(ShifteeS)) 12476 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS); 12477 } 12478 12479 return false; 12480 } 12481 12482 bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS, 12483 const SCEV *RHS, 12484 const SCEV *FoundLHS, 12485 const SCEV *FoundRHS, 12486 const Instruction *CtxI) { 12487 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, 12488 FoundRHS) || 12489 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, 12490 FoundRHS) || 12491 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) || 12492 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS, 12493 CtxI) || 12494 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS); 12495 } 12496 12497 /// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values? 12498 template <typename MinMaxExprType> 12499 static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, 12500 const SCEV *Candidate) { 12501 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr); 12502 if (!MinMaxExpr) 12503 return false; 12504 12505 return is_contained(MinMaxExpr->operands(), Candidate); 12506 } 12507 12508 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, 12509 CmpPredicate Pred, const SCEV *LHS, 12510 const SCEV *RHS) { 12511 // If both sides are affine addrecs for the same loop, with equal 12512 // steps, and we know the recurrences don't wrap, then we only 12513 // need to check the predicate on the starting values. 12514 12515 if (!ICmpInst::isRelational(Pred)) 12516 return false; 12517 12518 const SCEV *LStart, *RStart, *Step; 12519 const Loop *L; 12520 if (!match(LHS, 12521 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) || 12522 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step), 12523 m_SpecificLoop(L)))) 12524 return false; 12525 const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS); 12526 const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS); 12527 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? 12528 SCEV::FlagNSW : SCEV::FlagNUW; 12529 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW)) 12530 return false; 12531 12532 return SE.isKnownPredicate(Pred, LStart, RStart); 12533 } 12534 12535 /// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max 12536 /// expression? 
12537 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12538 const SCEV *LHS, const SCEV *RHS) {
12539 switch (Pred) {
12540 default:
12541 return false;
12542
12543 case ICmpInst::ICMP_SGE:
12544 std::swap(LHS, RHS);
12545 [[fallthrough]];
12546 case ICmpInst::ICMP_SLE:
12547 return
12548 // min(A, ...) <= A
12549 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12550 // A <= max(A, ...)
12551 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12552
12553 case ICmpInst::ICMP_UGE:
12554 std::swap(LHS, RHS);
12555 [[fallthrough]];
12556 case ICmpInst::ICMP_ULE:
12557 return
12558 // min(A, ...) <= A
12559 // FIXME: what about umin_seq?
12560 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12561 // A <= max(A, ...)
12562 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12563 }
12564
12565 llvm_unreachable("covered switch fell through?!");
12566 }
12567
12568 bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12569 const SCEV *RHS,
12570 const SCEV *FoundLHS,
12571 const SCEV *FoundRHS,
12572 unsigned Depth) {
12573 assert(getTypeSizeInBits(LHS->getType()) ==
12574 getTypeSizeInBits(RHS->getType()) &&
12575 "LHS and RHS have different sizes?");
12576 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12577 getTypeSizeInBits(FoundRHS->getType()) &&
12578 "FoundLHS and FoundRHS have different sizes?");
12579 // We want to avoid hurting the compile time with analysis of too big trees.
12580 if (Depth > MaxSCEVOperationsImplicationDepth)
12581 return false;
12582
12583 // We only want to work with GT comparisons so far.
12584 if (ICmpInst::isLT(Pred)) {
12585 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12586 std::swap(LHS, RHS);
12587 std::swap(FoundLHS, FoundRHS);
12588 }
12589
12590 CmpInst::Predicate P = Pred.getPreferredSignedPredicate();
12591
12592 // For unsigned, try to reduce it to the corresponding signed comparison.
12593 if (P == ICmpInst::ICMP_UGT)
12594 // We can replace an unsigned predicate with its signed counterpart if all
12595 // involved values are non-negative.
12596 // TODO: We could have better support for unsigned.
12597 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12598 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12599 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12600 // use this fact to prove that LHS and RHS are non-negative.
12601 const SCEV *MinusOne = getMinusOne(LHS->getType());
12602 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12603 FoundRHS) &&
12604 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12605 FoundRHS))
12606 P = ICmpInst::ICMP_SGT;
12607 }
12608
12609 if (P != ICmpInst::ICMP_SGT)
12610 return false;
12611
12612 auto GetOpFromSExt = [&](const SCEV *S) {
12613 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12614 return Ext->getOperand();
12615 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12616 // the constant in some cases.
12617 return S;
12618 };
12619
12620 // Acquire values from extensions.
12621 auto *OrigLHS = LHS;
12622 auto *OrigFoundLHS = FoundLHS;
12623 LHS = GetOpFromSExt(LHS);
12624 FoundLHS = GetOpFromSExt(FoundLHS);
12625
12626 // Check whether the SGT predicate can be proved trivially or using the found context.
12627 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) { 12628 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) || 12629 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS, 12630 FoundRHS, Depth + 1); 12631 }; 12632 12633 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) { 12634 // We want to avoid creation of any new non-constant SCEV. Since we are 12635 // going to compare the operands to RHS, we should be certain that we don't 12636 // need any size extensions for this. So let's decline all cases when the 12637 // sizes of types of LHS and RHS do not match. 12638 // TODO: Maybe try to get RHS from sext to catch more cases? 12639 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType())) 12640 return false; 12641 12642 // Should not overflow. 12643 if (!LHSAddExpr->hasNoSignedWrap()) 12644 return false; 12645 12646 auto *LL = LHSAddExpr->getOperand(0); 12647 auto *LR = LHSAddExpr->getOperand(1); 12648 auto *MinusOne = getMinusOne(RHS->getType()); 12649 12650 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context. 12651 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) { 12652 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS); 12653 }; 12654 // Try to prove the following rule: 12655 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS). 12656 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS). 12657 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL)) 12658 return true; 12659 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) { 12660 Value *LL, *LR; 12661 // FIXME: Once we have SDiv implemented, we can get rid of this matching. 12662 12663 using namespace llvm::PatternMatch; 12664 12665 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { 12666 // Rules for division. 12667 // We are going to perform some comparisons with Denominator and its 12668 // derivative expressions. In general case, creating a SCEV for it may 12669 // lead to a complex analysis of the entire graph, and in particular it 12670 // can request trip count recalculation for the same loop. This would 12671 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid 12672 // this, we only want to create SCEVs that are constants in this section. 12673 // So we bail if Denominator is not a constant. 12674 if (!isa<ConstantInt>(LR)) 12675 return false; 12676 12677 auto *Denominator = cast<SCEVConstant>(getSCEV(LR)); 12678 12679 // We want to make sure that LHS = FoundLHS / Denominator. If it is so, 12680 // then a SCEV for the numerator already exists and matches with FoundLHS. 12681 auto *Numerator = getExistingSCEV(LL); 12682 if (!Numerator || Numerator->getType() != FoundLHS->getType()) 12683 return false; 12684 12685 // Make sure that the numerator matches with FoundLHS and the denominator 12686 // is positive. 12687 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator)) 12688 return false; 12689 12690 auto *DTy = Denominator->getType(); 12691 auto *FRHSTy = FoundRHS->getType(); 12692 if (DTy->isPointerTy() != FRHSTy->isPointerTy()) 12693 // One of types is a pointer and another one is not. We cannot extend 12694 // them properly to a wider type, so let us just reject this case. 12695 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help 12696 // to avoid this check. 12697 return false; 12698 12699 // Given that: 12700 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0. 
12701 auto *WTy = getWiderType(DTy, FRHSTy); 12702 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy); 12703 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy); 12704 12705 // Try to prove the following rule: 12706 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS). 12707 // For example, given that FoundLHS > 2. It means that FoundLHS is at 12708 // least 3. If we divide it by Denominator < 4, we will have at least 1. 12709 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2)); 12710 if (isKnownNonPositive(RHS) && 12711 IsSGTViaContext(FoundRHSExt, DenomMinusTwo)) 12712 return true; 12713 12714 // Try to prove the following rule: 12715 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS). 12716 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2. 12717 // If we divide it by Denominator > 2, then: 12718 // 1. If FoundLHS is negative, then the result is 0. 12719 // 2. If FoundLHS is non-negative, then the result is non-negative. 12720 // Anyways, the result is non-negative. 12721 auto *MinusOne = getMinusOne(WTy); 12722 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); 12723 if (isKnownNegative(RHS) && 12724 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) 12725 return true; 12726 } 12727 } 12728 12729 // If our expression contained SCEVUnknown Phis, and we split it down and now 12730 // need to prove something for them, try to prove the predicate for every 12731 // possible incoming values of those Phis. 12732 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1)) 12733 return true; 12734 12735 return false; 12736 } 12737 12738 static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, 12739 const SCEV *RHS) { 12740 // zext x u<= sext x, sext x s<= zext x 12741 const SCEV *Op; 12742 switch (Pred) { 12743 case ICmpInst::ICMP_SGE: 12744 std::swap(LHS, RHS); 12745 [[fallthrough]]; 12746 case ICmpInst::ICMP_SLE: { 12747 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt. 12748 return match(LHS, m_scev_SExt(m_SCEV(Op))) && 12749 match(RHS, m_scev_ZExt(m_scev_Specific(Op))); 12750 } 12751 case ICmpInst::ICMP_UGE: 12752 std::swap(LHS, RHS); 12753 [[fallthrough]]; 12754 case ICmpInst::ICMP_ULE: { 12755 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt. 
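// E.g. for i8 x = 0x80: zext to i16 gives 0x0080 while sext gives 0xFF80,
// and 0x0080 u<= 0xFF80, matching "zext x u<= sext x".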
12756 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12757 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12758 }
12759 default:
12760 return false;
12761 };
12762 llvm_unreachable("unhandled case");
12763 }
12764
12765 bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12766 const SCEV *LHS,
12767 const SCEV *RHS) {
12768 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12769 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12770 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12771 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12772 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12773 }
12774
12775 bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12776 const SCEV *LHS,
12777 const SCEV *RHS,
12778 const SCEV *FoundLHS,
12779 const SCEV *FoundRHS) {
12780 switch (Pred) {
12781 default:
12782 llvm_unreachable("Unexpected CmpPredicate value!");
12783 case ICmpInst::ICMP_EQ:
12784 case ICmpInst::ICMP_NE:
12785 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12786 return true;
12787 break;
12788 case ICmpInst::ICMP_SLT:
12789 case ICmpInst::ICMP_SLE:
12790 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12791 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12792 return true;
12793 break;
12794 case ICmpInst::ICMP_SGT:
12795 case ICmpInst::ICMP_SGE:
12796 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12797 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12798 return true;
12799 break;
12800 case ICmpInst::ICMP_ULT:
12801 case ICmpInst::ICMP_ULE:
12802 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12803 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12804 return true;
12805 break;
12806 case ICmpInst::ICMP_UGT:
12807 case ICmpInst::ICMP_UGE:
12808 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12809 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12810 return true;
12811 break;
12812 }
12813
12814 // Maybe it can be proved via operations?
12815 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12816 return true;
12817
12818 return false;
12819 }
12820
12821 bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12822 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12823 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12824 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12825 // The restriction on `FoundRHS` could be lifted easily -- it exists only to
12826 // reduce the compile time impact of this optimization.
12827 return false;
12828
12829 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12830 if (!Addend)
12831 return false;
12832
12833 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12834
12835 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12836 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
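// Worked example (hypothetical constants): FoundPred = u< and ConstFoundRHS
// = 8 give FoundLHSRange = [0, 8); with Addend = 4 this yields LHSRange =
// [4, 12), so a consequent such as "LHS u< 12" holds for every admissible
// LHS.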
12837 ConstantRange FoundLHSRange = 12838 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS); 12839 12840 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`: 12841 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend)); 12842 12843 // We can also compute the range of values for `LHS` that satisfy the 12844 // consequent, "`LHS` `Pred` `RHS`": 12845 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); 12846 // The antecedent implies the consequent if every value of `LHS` that 12847 // satisfies the antecedent also satisfies the consequent. 12848 return LHSRange.icmp(Pred, ConstRHS); 12849 } 12850 12851 bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, 12852 bool IsSigned) { 12853 assert(isKnownPositive(Stride) && "Positive stride expected!"); 12854 12855 unsigned BitWidth = getTypeSizeInBits(RHS->getType()); 12856 const SCEV *One = getOne(Stride->getType()); 12857 12858 if (IsSigned) { 12859 APInt MaxRHS = getSignedRangeMax(RHS); 12860 APInt MaxValue = APInt::getSignedMaxValue(BitWidth); 12861 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); 12862 12863 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! 12864 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS); 12865 } 12866 12867 APInt MaxRHS = getUnsignedRangeMax(RHS); 12868 APInt MaxValue = APInt::getMaxValue(BitWidth); 12869 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); 12870 12871 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! 12872 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); 12873 } 12874 12875 bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, 12876 bool IsSigned) { 12877 12878 unsigned BitWidth = getTypeSizeInBits(RHS->getType()); 12879 const SCEV *One = getOne(Stride->getType()); 12880 12881 if (IsSigned) { 12882 APInt MinRHS = getSignedRangeMin(RHS); 12883 APInt MinValue = APInt::getSignedMinValue(BitWidth); 12884 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); 12885 12886 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! 12887 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS); 12888 } 12889 12890 APInt MinRHS = getUnsignedRangeMin(RHS); 12891 APInt MinValue = APInt::getMinValue(BitWidth); 12892 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); 12893 12894 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! 12895 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); 12896 } 12897 12898 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) { 12899 // umin(N, 1) + floor((N - umin(N, 1)) / D) 12900 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin 12901 // expression fixes the case of N=0. 12902 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType())); 12903 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne); 12904 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D)); 12905 } 12906 12907 const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, 12908 const SCEV *Stride, 12909 const SCEV *End, 12910 unsigned BitWidth, 12911 bool IsSigned) { 12912 // The logic in this function assumes we can represent a positive stride. 12913 // If we can't, the backedge-taken count must be zero. 
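// E.g. a signed i1 holds only 0 and -1, so no positive stride is
// representable and the backedge cannot be taken.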
12914 if (IsSigned && BitWidth == 1)
12915 return getZero(Stride->getType());
12916
12917 // The code below has only been closely audited for negative strides in
12918 // the unsigned comparison case; it may be correct for signed comparisons,
12919 // but that needs to be established.
12920 if (IsSigned && isKnownNegative(Stride))
12921 return getCouldNotCompute();
12922
12923 // Calculate the maximum backedge count based on the range of values
12924 // permitted by Start, End, and Stride.
12925 APInt MinStart =
12926 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12927
12928 APInt MinStride =
12929 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12930
12931 // We assume either the stride is positive, or the backedge-taken count
12932 // is zero. So force StrideForMaxBECount to be at least one.
12933 APInt One(BitWidth, 1);
12934 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12935 : APIntOps::umax(One, MinStride);
12936
12937 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12938 : APInt::getMaxValue(BitWidth);
12939 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12940
12941 // Although End can be a MAX expression, we estimate MaxEnd considering only
12942 // the case End = RHS of the loop termination condition. This is safe because
12943 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12944 // taken count.
12945 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12946 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12947
12948 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12949 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12950 : APIntOps::umax(MaxEnd, MinStart);
12951
12952 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12953 getConstant(StrideForMaxBECount) /* Step */);
12954 }
12955
12956 ScalarEvolution::ExitLimit
12957 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12958 const Loop *L, bool IsSigned,
12959 bool ControlsOnlyExit, bool AllowPredicates) {
12960 SmallVector<const SCEVPredicate *> Predicates;
12961
12962 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12963 bool PredicatedIV = false;
12964 if (!IV) {
12965 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12966 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12967 if (AR && AR->getLoop() == L && AR->isAffine()) {
12968 auto canProveNUW = [&]() {
12969 // We can use the comparison to infer no-wrap flags only if it fully
12970 // controls the loop exit.
12971 if (!ControlsOnlyExit)
12972 return false;
12973
12974 if (!isLoopInvariant(RHS, L))
12975 return false;
12976
12977 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12978 // We need the sequence defined by AR to strictly increase in the
12979 // unsigned integer domain for the logic below to hold.
12980 return false;
12981
12982 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12983 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12984 // If RHS <=u Limit, then there must exist a value V in the sequence
12985 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12986 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12987 // overflow occurs. This limit also implies that a signed comparison
12988 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12989 // the high bits on both sides must be zero.
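// E.g. (hypothetical widths): InnerBitWidth = 8 with a maximum stride of 1
// gives Limit = 255; if every value RHS can take (after applying loop
// guards) is u<= 255, the IV must reach RHS before it can wrap in i8.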
12990 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12991 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12992 Limit = Limit.zext(OuterBitWidth);
12993 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12994 };
12995 auto Flags = AR->getNoWrapFlags();
12996 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12997 Flags = setFlags(Flags, SCEV::FlagNUW);
12998
12999 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13000 if (AR->hasNoUnsignedWrap()) {
13001 // Emulate what getZeroExtendExpr would have done during construction
13002 // if we'd been able to infer the fact just above at that time.
13003 const SCEV *Step = AR->getStepRecurrence(*this);
13004 Type *Ty = ZExt->getType();
13005 auto *S = getAddRecExpr(
13006 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
13007 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13008 IV = dyn_cast<SCEVAddRecExpr>(S);
13009 }
13010 }
13011 }
13012 }
13013
13014
13015 if (!IV && AllowPredicates) {
13016 // Try to make this an AddRec using runtime tests, in the first X
13017 // iterations of this loop, where X is the SCEV expression found by the
13018 // algorithm below.
13019 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13020 PredicatedIV = true;
13021 }
13022
13023 // Avoid weird loops
13024 if (!IV || IV->getLoop() != L || !IV->isAffine())
13025 return getCouldNotCompute();
13026
13027 // A precondition of this method is that the condition being analyzed
13028 // reaches an exiting branch which dominates the latch. Given that, we can
13029 // assume that an increment which violates the nowrap specification and
13030 // produces poison must cause undefined behavior when the resulting poison
13031 // value is branched upon and thus we can conclude that the backedge is
13032 // taken no more often than would be required to produce that poison value.
13033 // Note that a well defined loop can exit on the iteration which violates
13034 // the nowrap specification if there is another exit (either explicit or
13035 // implicit/exceptional) which causes the loop to execute before the
13036 // exiting instruction we're analyzing would trigger UB.
13037 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13038 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13039 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13040
13041 const SCEV *Stride = IV->getStepRecurrence(*this);
13042
13043 bool PositiveStride = isKnownPositive(Stride);
13044
13045 // Avoid negative or zero stride values.
13046 if (!PositiveStride) {
13047 // We can compute the correct backedge taken count for loops with unknown
13048 // strides if we can prove that the loop is not an infinite loop with side
13049 // effects. Here's the loop structure we are trying to handle -
13050 //
13051 // i = start
13052 // do {
13053 // A[i] = i;
13054 // i += s;
13055 // } while (i < end);
13056 //
13057 // The backedge taken count for such loops is evaluated as -
13058 // (max(end, start + stride) - start - 1) /u stride
13059 //
13060 // The additional preconditions that we need to check to prove correctness
13061 // of the above formula are as follows -
13062 //
13063 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13064 // NoWrap flag).
13065 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13066 // no side effects within the loop)
13067 // c) loop has a single static exit (with no abnormal exits)
13068 //
13069 // Precondition a) implies that if the stride is negative, this is a single
13070 // trip loop. The backedge taken count formula reduces to zero in this case.
13071 //
13072 // Preconditions b) and c) combine to imply that if rhs is invariant in L,
13073 // then a zero stride means the backedge can't be taken without executing
13074 // undefined behavior.
13075 //
13076 // The positive stride case is the same as isKnownPositive(Stride) returning
13077 // true (original behavior of the function).
13078 //
13079 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13080 !loopHasNoAbnormalExits(L))
13081 return getCouldNotCompute();
13082
13083 if (!isKnownNonZero(Stride)) {
13084 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13085 // if it might eventually be greater than start and if so, on which
13086 // iteration. We can't even produce a useful upper bound.
13087 if (!isLoopInvariant(RHS, L))
13088 return getCouldNotCompute();
13089
13090 // We allow a potentially zero stride, but we need to divide by stride
13091 // below. Since the loop can't be infinite and this check must control
13092 // the sole exit, we can infer the exit must be taken on the first
13093 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13094 // we know the numerator in the divides below must be zero, so we can
13095 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13096 // and produce the right result.
13097 // FIXME: Handle the case where Stride is poison?
13098 auto wouldZeroStrideBeUB = [&]() {
13099 // Proof by contradiction. Suppose the stride were zero. If we can
13100 // prove that the backedge *is* taken on the first iteration, then since
13101 // we know this condition controls the sole exit, we must have an
13102 // infinite loop. We can't have a (well defined) infinite loop per
13103 // check just above.
13104 // Note: The (Start - Stride) term is used to get the start' term from
13105 // (start' + stride,+,stride). Remember that we only care about the
13106 // result of this expression when stride == 0 at runtime.
13107 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13108 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13109 };
13110 if (!wouldZeroStrideBeUB()) {
13111 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13112 }
13113 }
13114 } else if (!NoWrap) {
13115 // Avoid proven overflow cases: this will ensure that the backedge taken
13116 // count will not generate any unsigned overflow.
13117 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13118 return getCouldNotCompute();
13119 }
13120
13121 // On all paths just preceding, we established the following invariant:
13122 // IV can be assumed not to overflow up to and including the exiting
13123 // iteration. We proved this in one of two ways:
13124 // 1) We can show overflow doesn't occur before the exiting iteration
13125 // 1a) canIVOverflowOnLT, and 1b) a step of one
13126 // 2) We can show that if overflow occurs, the loop must execute UB
13127 // before any possible exit.
13128 // Note that we have not yet proved RHS invariant (in general).
13129
13130 const SCEV *Start = IV->getStart();
13131
13132 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13133 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13134 // Use integer-typed versions for actual computation; we can't subtract
13135 // pointers in general.
13136 const SCEV *OrigStart = Start;
13137 const SCEV *OrigRHS = RHS;
13138 if (Start->getType()->isPointerTy()) {
13139 Start = getLosslessPtrToIntExpr(Start);
13140 if (isa<SCEVCouldNotCompute>(Start))
13141 return Start;
13142 }
13143 if (RHS->getType()->isPointerTy()) {
13144 RHS = getLosslessPtrToIntExpr(RHS);
13145 if (isa<SCEVCouldNotCompute>(RHS))
13146 return RHS;
13147 }
13148
13149 const SCEV *End = nullptr, *BECount = nullptr,
13150 *BECountIfBackedgeTaken = nullptr;
13151 if (!isLoopInvariant(RHS, L)) {
13152 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13153 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13154 RHSAddRec->getNoWrapFlags()) {
13155 // The structure of the loop whose backedge count we are trying to calculate:
13156 //
13157 // left = left_start
13158 // right = right_start
13159 //
13160 // while(left < right){
13161 // ... do something here ...
13162 // left += s1; // stride of left is s1 (s1 > 0)
13163 // right += s2; // stride of right is s2 (s2 < 0)
13164 // }
13165 //
13166
13167 const SCEV *RHSStart = RHSAddRec->getStart();
13168 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13169
13170 // If Stride - RHSStride is positive and does not overflow, we can write
13171 // backedge count as ->
13172 // ceil((End - Start) /u (Stride - RHSStride))
13173 // Where, End = max(RHSStart, Start)
13174
13175 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13176 if (isKnownNegative(RHSStride) &&
13177 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13178 RHSStride)) {
13179
13180 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13181 if (isKnownPositive(Denominator)) {
13182 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13183 : getUMaxExpr(RHSStart, Start);
13184
13185 // We can do this because End >= Start, as End = max(RHSStart, Start)
13186 const SCEV *Delta = getMinusSCEV(End, Start);
13187
13188 BECount = getUDivCeilSCEV(Delta, Denominator);
13189 BECountIfBackedgeTaken =
13190 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13191 }
13192 }
13193 }
13194 if (BECount == nullptr) {
13195 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13196 // given the start, stride and max value for the end bound of the
13197 // loop (RHS), and the fact that IV does not overflow (which is
13198 // checked above).
13199 const SCEV *MaxBECount = computeMaxBECountForLT(
13200 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13201 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13202 MaxBECount, false /*MaxOrZero*/, Predicates);
13203 }
13204 } else {
13205 // We use the expression (max(End,Start)-Start)/Stride to describe the
13206 // backedge count because, if the backedge is taken at least once,
13207 // max(End,Start) is End and so the result is as above; if not,
13208 // max(End,Start) is Start, so we get a backedge count of zero.
13209 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13210 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13211 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13212 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13213 // Can we prove max(RHS,Start) > Start - Stride?
        }
      }
    }
    if (BECount == nullptr) {
      // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
      // given the start, stride and max value for the end bound of the
      // loop (RHS), and the fact that IV does not overflow (which is
      // checked above).
      const SCEV *MaxBECount = computeMaxBECountForLT(
          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
      return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                       MaxBECount, false /*MaxOrZero*/, Predicates);
    }
  } else {
    // We use the expression (max(End,Start)-Start)/Stride to describe the
    // backedge count: if the backedge is taken at least once, max(End,Start)
    // is End and so the result is as above; if not, max(End,Start) is Start
    // and we get a backedge count of zero.
    auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
    assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
    // Can we prove max(RHS,Start) > Start - Stride?
    if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
        isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
      // In this case, we can use a refined formula for computing backedge
      // taken count. The general formula remains:
      //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
      // We want to use the alternate formula:
      //   "((End - 1) - (Start - Stride)) /u Stride"
      // Let's do a quick case analysis to show these are equivalent under
      // our precondition that max(RHS,Start) > Start - Stride.
      // * For RHS <= Start, the backedge-taken count must be zero.
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
      //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
      //   of Stride. For 0 stride, we've used umax(Stride, 1) above,
      //   reducing this to the stride of 1 case.
      // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
      //   Stride".
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((RHS - 1) - (Start - Stride)) /u Stride" which reassociates to
      //   "(RHS - (Start - Stride) - 1) /u Stride".
      //   Our preconditions trivially imply no overflow in that form.
      const SCEV *MinusOne = getMinusOne(Stride->getType());
      const SCEV *Numerator =
          getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
      BECount = getUDivExpr(Numerator, Stride);
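      // For example, with Start = 0, RHS = 7 and Stride = 3, both formulas
      // agree: ceil(7 /u 3) == ((7 - 1) - (0 - 3)) /u 3 == 9 /u 3 == 3. The
      // IV takes the values 0, 3 and 6 below RHS, so the backedge is taken
      // exactly 3 times.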
    }

    if (!BECount) {
      auto canProveRHSGreaterThanEqualStart = [&]() {
        auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

        if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
            isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
          return true;

        // (RHS > Start - 1) implies RHS >= Start.
        // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
        //   "Start - 1" doesn't overflow.
        // * For signed comparison, if Start - 1 does overflow, it's equal
        //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
        // * For unsigned comparison, if Start - 1 does overflow, it's equal
        //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
        //
        // FIXME: Should isLoopEntryGuardedByCond do this for us?
        auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
        auto *StartMinusOne =
            getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
        return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
      };

      // If we know that RHS >= Start in the context of loop, then we know
      // that max(RHS, Start) = RHS at this point.
      if (canProveRHSGreaterThanEqualStart()) {
        End = RHS;
      } else {
        // If RHS < Start, the backedge will be taken zero times. So in
        // general, we can write the backedge-taken count as:
        //
        //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
        //
        // We convert it to the following to make it more convenient for SCEV:
        //
        //     ceil(max(RHS, Start) - Start) / Stride
        End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

        // See what would happen if we assume the backedge is taken. This is
        // used to compute MaxBECount.
        BECountIfBackedgeTaken =
            getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
      }

      // At this point, we know:
      //
      // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
      // 2. The index variable doesn't overflow.
      //
      // Therefore, we know N exists such that
      // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
      // doesn't overflow.
      //
      // Using this information, try to prove whether the addition in
      // "(Start - End) + (Stride - 1)" has unsigned overflow.
      const SCEV *One = getOne(Stride->getType());
      bool MayAddOverflow = [&] {
        if (isKnownToBeAPowerOfTwo(Stride)) {
          // Suppose Stride is a power of two, and Start/End are unsigned
          // integers. Let UMAX be the largest representable unsigned
          // integer.
          //
          // By the preconditions of this function, we know
          // "(Start + Stride * N) >= End", and this doesn't overflow.
          // As a formula:
          //
          //   End <= (Start + Stride * N) <= UMAX
          //
          // Subtracting Start from all the terms:
          //
          //   End - Start <= Stride * N <= UMAX - Start
          //
          // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
          //
          //   End - Start <= Stride * N <= UMAX
          //
          // Stride * N is a multiple of Stride. Therefore,
          //
          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
          //
          // Since Stride is a power of two, UMAX + 1 is divisible by
          // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
          // write:
          //
          //   End - Start <= Stride * N <= UMAX - Stride + 1
          //
          // Dropping the middle term:
          //
          //   End - Start <= UMAX - Stride + 1
          //
          // Adding Stride - 1 to both sides:
          //
          //   (End - Start) + (Stride - 1) <= UMAX
          //
          // In other words, the addition doesn't have unsigned overflow.
          //
          // A similar proof works if we treat Start/End as signed values.
          // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
          // to use signed max instead of unsigned max. Note that we're
          // trying to prove a lack of unsigned overflow in either case.
          return false;
        }
        if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
          // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
          // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
          // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
          // 1 <s End.
          //
          // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
          // End.
          return false;
        }
        return true;
      }();
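      // As a concrete instance of the power-of-two case above: with 8-bit
      // values, UMAX = 255 and Stride = 4, the largest multiple of 4 that
      // fits is UMAX - (UMAX mod 4) = 255 - 3 = 252 = UMAX - Stride + 1, so
      // End - Start <= 252 and (End - Start) + (Stride - 1) <= 255 cannot
      // wrap.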

      const SCEV *Delta = getMinusSCEV(End, Start);
      if (!MayAddOverflow) {
        // floor((D + (S - 1)) / S)
        // We prefer this formulation if it's legal because it's fewer
        // operations.
        BECount =
            getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
      } else {
        BECount = getUDivCeilSCEV(Delta, Stride);
      }
    }
  }

  const SCEV *ConstantMaxBECount;
  bool MaxOrZero = false;
  if (isa<SCEVConstant>(BECount)) {
    ConstantMaxBECount = BECount;
  } else if (BECountIfBackedgeTaken &&
             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
    // If we know exactly how many times the backedge will be taken if it's
    // taken at least once, then the backedge count will either be that or
    // zero.
    ConstantMaxBECount = BECountIfBackedgeTaken;
    MaxOrZero = true;
  } else {
    ConstantMaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
  }

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));

  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
                   Predicates);
}

ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
    const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
    bool ControlsOnlyExit, bool AllowPredicates) {
  SmallVector<const SCEVPredicate *> Predicates;
  // We handle only IV > Invariant.
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!IV && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);

  // Avoid weird loops.
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;

  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));

  // Avoid negative or zero stride values.
  if (!isKnownPositive(Stride))
    return getCouldNotCompute();

  // Avoid proven overflow cases: this will ensure that the backedge taken
  // count will not generate any unsigned overflow. Relaxed no-overflow
  // conditions exploit NoWrapFlags, allowing optimization in the presence of
  // undefined behavior, as in C.
  if (!Stride->isOne() && !NoWrap)
    if (canIVOverflowOnGT(RHS, Stride, IsSigned))
      return getCouldNotCompute();

  const SCEV *Start = IV->getStart();
  const SCEV *End = RHS;
  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
    // If we know that Start >= RHS in the context of loop, then we know that
    // min(RHS, Start) = RHS at this point.
    if (isLoopEntryGuardedByCond(
            L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
      End = RHS;
    else
      End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
  }

  if (Start->getType()->isPointerTy()) {
    Start = getLosslessPtrToIntExpr(Start);
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (End->getType()->isPointerTy()) {
    End = getLosslessPtrToIntExpr(End);
    if (isa<SCEVCouldNotCompute>(End))
      return End;
  }

  // Compute ((Start - End) + (Stride - 1)) / Stride.
  // FIXME: This can overflow. Holding off on fixing this for now;
  // howManyGreaterThans will hopefully be gone soon.
  const SCEV *One = getOne(Stride->getType());
  const SCEV *BECount = getUDivExpr(
      getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
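  // For example, counting down with Start = 10, End = 0 and Stride = 3 (i.e.
  // a step of -3): BECount = ((10 - 0) + (3 - 1)) /u 3 == 12 /u 3 == 4,
  // matching the four iterations on which the IV is 10, 7, 4 and 1.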

  APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
                            : getUnsignedRangeMax(Start);

  APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
                             : getUnsignedRangeMin(Stride);

  unsigned BitWidth = getTypeSizeInBits(LHS->getType());
  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
                         : APInt::getMinValue(BitWidth) + (MinStride - 1);

  // Although End can be a MIN expression we estimate MinEnd considering only
  // the case End = RHS. This is safe because in the other case (Start - End)
  // is zero, leading to a zero maximum backedge taken count.
  APInt MinEnd = IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
                          : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);

  const SCEV *ConstantMaxBECount =
      isa<SCEVConstant>(BECount)
          ? BECount
          : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
                            getConstant(MinStride));

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
    ConstantMaxBECount = BECount;
  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;

  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                   Predicates);
}

const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
                                                    ScalarEvolution &SE) const {
  if (Range.isFullSet()) // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      SmallVector<const SCEV *, 4> Operands(operands());
      Operands[0] = SE.getZero(SC->getType());
      const SCEV *Shifted =
          SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW));
      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
            Range.subtract(SC->getAPInt()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
    return SE.getCouldNotCompute();

  // Okay at this point we know that all elements of the chrec are constants
  // and that the start element is zero.

  // First check to see if the range contains zero. If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getZero(getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range

    // We know that zero is in the range. If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value. Also note that we already checked for a full range.
    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();

    // The exit value should be (End + A) / A.
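    // For example, solving {0,+,3} in the range [0, 10): End = 9 and the
    // exit value is (9 + 3) /u 3 == 4; the chrec takes the in-range values
    // 0, 3, 6 and 9 on iterations 0-3 and first leaves the range (at 12) on
    // iteration 4.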
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);

    // Evaluate at the exit value. If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute(); // Something strange happened

    // Ensure that the previous value is in the range.
    assert(Range.contains(
               EvaluateConstantChrecAtConstant(
                   this, ConstantInt::get(SE.getContext(), ExitVal - 1), SE)
                   ->getValue()) &&
           "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  }

  if (isQuadratic()) {
    if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
      return SE.getConstant(*S);
  }

  return SE.getCouldNotCompute();
}

const SCEVAddRecExpr *
SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
  assert(getNumOperands() > 1 && "AddRec with zero step?");
  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
  // but in this case we cannot guarantee that the value returned will be an
  // AddRec because SCEV does not have a fixed point where it stops
  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
  // may happen if we reach arithmetic depth limit while simplifying. So we
  // construct the returned value explicitly.
  SmallVector<const SCEV *, 3> Ops;
  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
  // (this + Step) is {A+B,+,B+C,+...,+,N}.
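  // For example, the post-increment form of {0,+,1,+,2} (whose value on
  // iteration n is n^2) is {1,+,3,+,2}, whose value on iteration n is
  // (n+1)^2, i.e. the original chrec evaluated one iteration later.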
  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
  // We know that the last operand is not a constant zero (otherwise it would
  // have been popped out earlier). This guarantees us that if the result has
  // the same last operand, then it will also not be popped out, meaning that
  // the returned value will be an AddRec.
  const SCEV *Last = getOperand(getNumOperands() - 1);
  assert(!Last->isZero() && "Recurrence with zero step?");
  Ops.push_back(Last);
  return cast<SCEVAddRecExpr>(
      SE.getAddRecExpr(Ops, getLoop(), SCEV::FlagAnyWrap));
}

// Return true when S contains at least an undef value.
bool ScalarEvolution::containsUndefs(const SCEV *S) const {
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return isa<UndefValue>(SU->getValue());
    return false;
  });
}

// Return true when S contains a value that is a nullptr.
bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return SU->getValue() == nullptr;
    return false;
  });
}

/// Return the size of an element read or written by Inst.
const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    return nullptr;

  Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Inst->getContext()));
  return getSizeOfExpr(ETy, Ty);
}

//===----------------------------------------------------------------------===//
//                   SCEVCallbackVH Class Implementation
//===----------------------------------------------------------------------===//

void ScalarEvolution::SCEVCallbackVH::deleted() {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  SE->eraseValueFromMap(getValPtr());
  // this now dangles!
}

void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.
  SE->forgetValue(getValPtr());
  // this now dangles!
}

ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
    : CallbackVH(V), SE(se) {}

//===----------------------------------------------------------------------===//
//                   ScalarEvolution Class Implementation
//===----------------------------------------------------------------------===//

ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
                                 AssumptionCache &AC, DominatorTree &DT,
                                 LoopInfo &LI)
    : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
      LoopDispositions(64), BlockDispositions(64) {
  // To use guards for proving predicates, we need to scan every instruction in
  // relevant basic blocks, and not just terminators. Doing this is a waste of
  // time if the IR does not actually contain any calls to
  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
  //
  // This pessimizes the case where a pass that preserves ScalarEvolution wants
  // to _add_ guards to the module when there weren't any before, and wants
  // ScalarEvolution to optimize based on those guards. For now we prefer to be
  // efficient in lieu of being smart in that rather obscure case.

  auto *GuardDecl = Intrinsic::getDeclarationIfExists(
      F.getParent(), Intrinsic::experimental_guard);
  HasGuards = GuardDecl && !GuardDecl->use_empty();
}

ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
    : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
      DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
      ValueExprMap(std::move(Arg.ValueExprMap)),
      PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
      PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
      PendingMerges(std::move(Arg.PendingMerges)),
      ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
      BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
      PredicatedBackedgeTakenCounts(
          std::move(Arg.PredicatedBackedgeTakenCounts)),
      BECountUsers(std::move(Arg.BECountUsers)),
      ConstantEvolutionLoopExitValue(
          std::move(Arg.ConstantEvolutionLoopExitValue)),
      ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
      ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
      LoopDispositions(std::move(Arg.LoopDispositions)),
      LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
      BlockDispositions(std::move(Arg.BlockDispositions)),
      SCEVUsers(std::move(Arg.SCEVUsers)),
      UnsignedRanges(std::move(Arg.UnsignedRanges)),
      SignedRanges(std::move(Arg.SignedRanges)),
      UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
      UniquePreds(std::move(Arg.UniquePreds)),
      SCEVAllocator(std::move(Arg.SCEVAllocator)),
      LoopUsers(std::move(Arg.LoopUsers)),
      PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
      FirstUnknown(Arg.FirstUnknown) {
  Arg.FirstUnknown = nullptr;
}

ScalarEvolution::~ScalarEvolution() {
  // Iterate through all the SCEVUnknown instances and call their
  // destructors, so that they release their references to their values.
  for (SCEVUnknown *U = FirstUnknown; U;) {
    SCEVUnknown *Tmp = U;
    U = U->Next;
    Tmp->~SCEVUnknown();
  }
  FirstUnknown = nullptr;

  ExprValueMap.clear();
  ValueExprMap.clear();
  HasRecMap.clear();
  BackedgeTakenCounts.clear();
  PredicatedBackedgeTakenCounts.clear();

  assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
  assert(PendingPhiRanges.empty() && "getRangeRef garbage");
  assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
  assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
  assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
}

bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}

/// When printing a top-level SCEV for trip counts, it's helpful to include
/// a type for constants which are otherwise hard to disambiguate.
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S) {
  if (isa<SCEVConstant>(S))
    OS << *S->getType() << " ";
  OS << *S;
}

static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                          const Loop *L) {
  // Print all inner loops first.
  for (Loop *I : *L)
    PrintLoopInfo(OS, SE, I);

  OS << "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  SmallVector<BasicBlock *, 8> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);
  if (ExitingBlocks.size() != 1)
    OS << "<multiple exits> ";

  auto *BTC = SE->getBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(BTC)) {
    OS << "backedge-taken count is ";
    PrintSCEVWithTypeHint(OS, BTC);
  } else
    OS << "Unpredictable backedge-taken count.";
  OS << "\n";

  if (ExitingBlocks.size() > 1)
    for (BasicBlock *ExitingBlock : ExitingBlocks) {
      OS << "  exit count for " << ExitingBlock->getName() << ": ";
      const SCEV *EC = SE->getExitCount(L, ExitingBlock);
      PrintSCEVWithTypeHint(OS, EC);
      if (isa<SCEVCouldNotCompute>(EC)) {
        // Retry with predicates.
        SmallVector<const SCEVPredicate *> Predicates;
        EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
        if (!isa<SCEVCouldNotCompute>(EC)) {
          OS << "\n  predicated exit count for " << ExitingBlock->getName()
             << ": ";
          PrintSCEVWithTypeHint(OS, EC);
          OS << "\n   Predicates:\n";
          for (const auto *P : Predicates)
            P->print(OS, 4);
        }
      }
      OS << "\n";
    }

  OS << "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
    OS << "constant max backedge-taken count is ";
    PrintSCEVWithTypeHint(OS, ConstantBTC);
    if (SE->isBackedgeTakenCountMaxOrZero(L))
      OS << ", actual taken count either this or zero.";
  } else {
    OS << "Unpredictable constant max backedge-taken count. ";
  }

  OS << "\n"
        "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
    OS << "symbolic max backedge-taken count is ";
    PrintSCEVWithTypeHint(OS, SymbolicBTC);
    if (SE->isBackedgeTakenCountMaxOrZero(L))
      OS << ", actual taken count either this or zero.";
  } else {
    OS << "Unpredictable symbolic max backedge-taken count. ";
  }
  OS << "\n";

  if (ExitingBlocks.size() > 1)
    for (BasicBlock *ExitingBlock : ExitingBlocks) {
      OS << "  symbolic max exit count for " << ExitingBlock->getName()
         << ": ";
      auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
                                       ScalarEvolution::SymbolicMaximum);
      PrintSCEVWithTypeHint(OS, ExitBTC);
      if (isa<SCEVCouldNotCompute>(ExitBTC)) {
        // Retry with predicates.
        SmallVector<const SCEVPredicate *> Predicates;
        ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
                                             ScalarEvolution::SymbolicMaximum);
        if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
          OS << "\n  predicated symbolic max exit count for "
             << ExitingBlock->getName() << ": ";
          PrintSCEVWithTypeHint(OS, ExitBTC);
          OS << "\n   Predicates:\n";
          for (const auto *P : Predicates)
            P->print(OS, 4);
        }
      }
      OS << "\n";
    }

  SmallVector<const SCEVPredicate *, 4> Preds;
  auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
  if (PBT != BTC) {
    assert(!Preds.empty() && "Different predicated BTC, but no predicates");
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    if (!isa<SCEVCouldNotCompute>(PBT)) {
      OS << "Predicated backedge-taken count is ";
      PrintSCEVWithTypeHint(OS, PBT);
    } else
      OS << "Unpredictable predicated backedge-taken count.";
    OS << "\n";
    OS << " Predicates:\n";
    for (const auto *P : Preds)
      P->print(OS, 4);
  }
  Preds.clear();

  auto *PredConstantMax =
      SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
  if (PredConstantMax != ConstantBTC) {
    assert(!Preds.empty() &&
           "different predicated constant max BTC but no predicates");
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
      OS << "Predicated constant max backedge-taken count is ";
      PrintSCEVWithTypeHint(OS, PredConstantMax);
    } else
      OS << "Unpredictable predicated constant max backedge-taken count.";
    OS << "\n";
    OS << " Predicates:\n";
    for (const auto *P : Preds)
      P->print(OS, 4);
  }
  Preds.clear();

  auto *PredSymbolicMax =
      SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
  if (SymbolicBTC != PredSymbolicMax) {
    assert(!Preds.empty() &&
           "Different predicated symbolic max BTC, but no predicates");
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
      OS << "Predicated symbolic max backedge-taken count is ";
      PrintSCEVWithTypeHint(OS, PredSymbolicMax);
    } else
      OS << "Unpredictable predicated symbolic max backedge-taken count.";
    OS << "\n";
    OS << " Predicates:\n";
    for (const auto *P : Preds)
      P->print(OS, 4);
  }

  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
  }
}

namespace llvm {
raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
  switch (LD) {
  case ScalarEvolution::LoopVariant:
    OS << "Variant";
    break;
  case ScalarEvolution::LoopInvariant:
    OS << "Invariant";
    break;
  case ScalarEvolution::LoopComputable:
    OS << "Computable";
    break;
  }
  return OS;
}

raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
  switch (BD) {
  case ScalarEvolution::DoesNotDominateBlock:
    OS << "DoesNotDominate";
    break;
  case ScalarEvolution::DominatesBlock:
    OS << "Dominates";
    break;
  case ScalarEvolution::ProperlyDominatesBlock:
    OS << "ProperlyDominates";
    break;
  }
  return OS;
}
} // namespace llvm

void ScalarEvolution::print(raw_ostream &OS) const {
  // ScalarEvolution's implementation of the print method is to print
  // out SCEV values of all instructions that are interesting. Doing
  // this potentially causes it to create new SCEV objects though,
  // which technically conflicts with the const qualifier. This isn't
  // observable from outside the class though, so casting away the
  // const isn't dangerous.
  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);

  if (ClassifyExpressions) {
    OS << "Classifying expressions for: ";
    F.printAsOperand(OS, /*PrintType=*/false);
    OS << "\n";
    for (Instruction &I : instructions(F))
      if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
        OS << I << '\n';
        OS << "  --> ";
        const SCEV *SV = SE.getSCEV(&I);
        SV->print(OS);
        if (!isa<SCEVCouldNotCompute>(SV)) {
          OS << " U: ";
          SE.getUnsignedRange(SV).print(OS);
          OS << " S: ";
          SE.getSignedRange(SV).print(OS);
        }

        const Loop *L = LI.getLoopFor(I.getParent());

        const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
        if (AtUse != SV) {
          OS << "  --> ";
          AtUse->print(OS);
          if (!isa<SCEVCouldNotCompute>(AtUse)) {
            OS << " U: ";
            SE.getUnsignedRange(AtUse).print(OS);
            OS << " S: ";
            SE.getSignedRange(AtUse).print(OS);
          }
        }

        if (L) {
          OS << "\t\t" "Exits: ";
          const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
          if (!SE.isLoopInvariant(ExitValue, L)) {
            OS << "<<Unknown>>";
          } else {
            OS << *ExitValue;
          }

          bool First = true;
          for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
            if (First) {
              OS << "\t\t" "LoopDispositions: { ";
              First = false;
            } else {
              OS << ", ";
            }

            Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
            OS << ": " << SE.getLoopDisposition(SV, Iter);
          }

          for (const auto *InnerL : depth_first(L)) {
            if (InnerL == L)
              continue;
            if (First) {
              OS << "\t\t" "LoopDispositions: { ";
              First = false;
            } else {
              OS << ", ";
            }

            InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
            OS << ": " << SE.getLoopDisposition(SV, InnerL);
          }

          OS << " }";
        }

        OS << "\n";
      }
  }

  OS << "Determining loop execution counts for: ";
  F.printAsOperand(OS, /*PrintType=*/false);
  OS << "\n";
  for (Loop *I : LI)
    PrintLoopInfo(OS, &SE, I);
}

ScalarEvolution::LoopDisposition
ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
  auto &Values = LoopDispositions[S];
  for (auto &V : Values) {
    if (V.getPointer() == L)
      return V.getInt();
  }
  Values.emplace_back(L, LoopVariant);
  LoopDisposition D = computeLoopDisposition(S, L);
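  // The recursive computeLoopDisposition call may have grown (and thus
  // reallocated) LoopDispositions, invalidating the Values reference above,
  // so re-acquire the bucket before recording the result.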
  auto &Values2 = LoopDispositions[S];
  for (auto &V : llvm::reverse(Values2)) {
    if (V.getPointer() == L) {
      V.setInt(D);
      break;
    }
  }
  return D;
}

ScalarEvolution::LoopDisposition
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
  switch (S->getSCEVType()) {
  case scConstant:
  case scVScale:
    return LoopInvariant;
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);

    // If L is the addrec's loop, it's computable.
    if (AR->getLoop() == L)
      return LoopComputable;

    // Add recurrences are never invariant in the function-body (null loop).
    if (!L)
      return LoopVariant;

    // Everything that is not defined at loop entry is variant.
    if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
      return LoopVariant;
    assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
           " dominate the contained loop's header?");

    // This recurrence is invariant w.r.t. L if AR's loop contains L.
    if (AR->getLoop()->contains(L))
      return LoopInvariant;

    // This recurrence is variant w.r.t. L if any of its operands
    // are variant.
    for (const auto *Op : AR->operands())
      if (!isLoopInvariant(Op, L))
        return LoopVariant;

    // Otherwise it's loop-invariant.
    return LoopInvariant;
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    bool HasVarying = false;
    for (const auto *Op : S->operands()) {
      LoopDisposition D = getLoopDisposition(Op, L);
      if (D == LoopVariant)
        return LoopVariant;
      if (D == LoopComputable)
        HasVarying = true;
    }
    return HasVarying ? LoopComputable : LoopInvariant;
  }
  case scUnknown:
    // All non-instruction values are loop invariant. All instructions are loop
    // invariant if they are not contained in the specified loop.
    // Instructions are never considered invariant in the function body
    // (null loop) because they are defined within the "loop".
    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
      return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
    return LoopInvariant;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
  return getLoopDisposition(S, L) == LoopInvariant;
}

bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
  return getLoopDisposition(S, L) == LoopComputable;
}

ScalarEvolution::BlockDisposition
ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
  auto &Values = BlockDispositions[S];
  for (auto &V : Values) {
    if (V.getPointer() == BB)
      return V.getInt();
  }
  Values.emplace_back(BB, DoesNotDominateBlock);
  BlockDisposition D = computeBlockDisposition(S, BB);
  auto &Values2 = BlockDispositions[S];
  for (auto &V : llvm::reverse(Values2)) {
    if (V.getPointer() == BB) {
      V.setInt(D);
      break;
    }
  }
  return D;
}

ScalarEvolution::BlockDisposition
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
  switch (S->getSCEVType()) {
  case scConstant:
  case scVScale:
    return ProperlyDominatesBlock;
  case scAddRecExpr: {
    // This uses a "dominates" query instead of "properly dominates" query
    // to test for proper dominance too, because the instruction which
    // produces the addrec's value is a PHI, and a PHI effectively properly
    // dominates its entire containing block.
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
    if (!DT.dominates(AR->getLoop()->getHeader(), BB))
      return DoesNotDominateBlock;

    // Fall through into SCEVNAryExpr handling.
    [[fallthrough]];
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    bool Proper = true;
    for (const SCEV *NAryOp : S->operands()) {
      BlockDisposition D = getBlockDisposition(NAryOp, BB);
      if (D == DoesNotDominateBlock)
        return DoesNotDominateBlock;
      if (D == DominatesBlock)
        Proper = false;
    }
    return Proper ? ProperlyDominatesBlock : DominatesBlock;
  }
  case scUnknown:
    if (Instruction *I =
            dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
      if (I->getParent() == BB)
        return DominatesBlock;
      if (DT.properlyDominates(I->getParent(), BB))
        return ProperlyDominatesBlock;
      return DoesNotDominateBlock;
    }
    return ProperlyDominatesBlock;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
  return getBlockDisposition(S, BB) >= DominatesBlock;
}

bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
  return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
}

bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
  return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}

void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
                                                bool Predicated) {
  auto &BECounts =
      Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
  auto It = BECounts.find(L);
  if (It != BECounts.end()) {
    for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
      for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
        if (!isa<SCEVConstant>(S)) {
          auto UserIt = BECountUsers.find(S);
          assert(UserIt != BECountUsers.end());
          UserIt->second.erase({L, Predicated});
        }
      }
    }
    BECounts.erase(It);
  }
}

void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
  SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
  SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());

  while (!Worklist.empty()) {
    const SCEV *Curr = Worklist.pop_back_val();
    auto Users = SCEVUsers.find(Curr);
    if (Users != SCEVUsers.end())
      for (const auto *User : Users->second)
        if (ToForget.insert(User).second)
          Worklist.push_back(User);
  }

  for (const auto *S : ToForget)
    forgetMemoizedResultsImpl(S);

  for (auto I = PredicatedSCEVRewrites.begin();
       I != PredicatedSCEVRewrites.end();) {
    std::pair<const SCEV *, const Loop *> Entry = I->first;
    if (ToForget.count(Entry.first))
      PredicatedSCEVRewrites.erase(I++);
    else
      ++I;
  }
}

void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
  LoopDispositions.erase(S);
  BlockDispositions.erase(S);
  UnsignedRanges.erase(S);
  SignedRanges.erase(S);
  HasRecMap.erase(S);
  ConstantMultipleCache.erase(S);

  if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
    UnsignedWrapViaInductionTried.erase(AR);
    SignedWrapViaInductionTried.erase(AR);
  }

  auto ExprIt = ExprValueMap.find(S);
  if (ExprIt != ExprValueMap.end()) {
    for (Value *V : ExprIt->second) {
      auto ValueIt = ValueExprMap.find_as(V);
      if (ValueIt != ValueExprMap.end())
        ValueExprMap.erase(ValueIt);
    }
    ExprValueMap.erase(ExprIt);
  }

  auto ScopeIt = ValuesAtScopes.find(S);
  if (ScopeIt != ValuesAtScopes.end()) {
    for (const auto &Pair : ScopeIt->second)
      if (!isa_and_nonnull<SCEVConstant>(Pair.second))
        llvm::erase(ValuesAtScopesUsers[Pair.second],
                    std::make_pair(Pair.first, S));
    ValuesAtScopes.erase(ScopeIt);
  }

  auto ScopeUserIt = ValuesAtScopesUsers.find(S);
  if (ScopeUserIt != ValuesAtScopesUsers.end()) {
    for (const auto &Pair : ScopeUserIt->second)
      llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
    ValuesAtScopesUsers.erase(ScopeUserIt);
  }

  auto BEUsersIt = BECountUsers.find(S);
  if (BEUsersIt != BECountUsers.end()) {
    // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
    auto Copy = BEUsersIt->second;
    for (const auto &Pair : Copy)
      forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
    BECountUsers.erase(BEUsersIt);
  }

  auto FoldUser = FoldCacheUser.find(S);
  if (FoldUser != FoldCacheUser.end())
    for (auto &KV : FoldUser->second)
      FoldCache.erase(KV);
  FoldCacheUser.erase(S);
}

void
ScalarEvolution::getUsedLoops(const SCEV *S,
                              SmallPtrSetImpl<const Loop *> &LoopsUsed) {
  struct FindUsedLoops {
    FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
        : LoopsUsed(LoopsUsed) {}
    SmallPtrSetImpl<const Loop *> &LoopsUsed;
    bool follow(const SCEV *S) {
      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
        LoopsUsed.insert(AR->getLoop());
      return true;
    }

    bool isDone() const { return false; }
  };

  FindUsedLoops F(LoopsUsed);
  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
}

void ScalarEvolution::getReachableBlocks(
    SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
  SmallVector<BasicBlock *> Worklist;
  Worklist.push_back(&F.getEntryBlock());
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    if (!Reachable.insert(BB).second)
      continue;

    Value *Cond;
    BasicBlock *TrueBB, *FalseBB;
    if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
                                        m_BasicBlock(FalseBB)))) {
      if (auto *C = dyn_cast<ConstantInt>(Cond)) {
        Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
        continue;
      }

      if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
        const SCEV *L = getSCEV(Cmp->getOperand(0));
        const SCEV *R = getSCEV(Cmp->getOperand(1));
        if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
          Worklist.push_back(TrueBB);
          continue;
        }
        if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
                                              R)) {
          Worklist.push_back(FalseBB);
          continue;
        }
      }
    }

    append_range(Worklist, successors(BB));
  }
}

void ScalarEvolution::verify() const {
  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
  ScalarEvolution SE2(F, TLI, AC, DT, LI);

  SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());

  // Maps SCEV expressions from one ScalarEvolution "universe" to another.
  struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}

    const SCEV *visitConstant(const SCEVConstant *Constant) {
      return SE.getConstant(Constant->getAPInt());
    }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      return SE.getUnknown(Expr->getValue());
    }

    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
      return SE.getCouldNotCompute();
    }
  };

  SCEVMapper SCM(SE2);
  SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
  SE2.getReachableBlocks(ReachableBlocks, F);

  auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
    if (containsUndefs(Old) || containsUndefs(New)) {
      // SCEV treats "undef" as an unknown but consistent value (i.e. it does
      // not propagate undef aggressively). This means we can (and do) fail
      // verification in cases where a transform makes a value go from "undef"
      // to "undef+1" (say). The transform is fine, since in both cases the
      // result is "undef", but SCEV thinks the value increased by 1.
      return nullptr;
    }

    // Unless VerifySCEVStrict is set, we only compare constant deltas.
    const SCEV *Delta = SE2.getMinusSCEV(Old, New);
    if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
      return nullptr;

    return Delta;
  };

  while (!LoopStack.empty()) {
    auto *L = LoopStack.pop_back_val();
    llvm::append_range(LoopStack, *L);

    // Only verify BECounts in reachable loops. For an unreachable loop,
    // any BECount is legal.
    if (!ReachableBlocks.contains(L->getHeader()))
      continue;

    // Only verify cached BECounts. Computing new BECounts may change the
    // results of subsequent SCEV uses.
    auto It = BackedgeTakenCounts.find(L);
    if (It == BackedgeTakenCounts.end())
      continue;

    auto *CurBECount =
        SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
    auto *NewBECount = SE2.getBackedgeTakenCount(L);

    if (CurBECount == SE2.getCouldNotCompute() ||
        NewBECount == SE2.getCouldNotCompute()) {
      // NB! This situation is legal, but is very suspicious -- whatever pass
      // changed the loop to make a trip count go from could-not-compute to
      // computable or vice-versa *should have* invalidated SCEV. However, we
      // choose not to assert here (for now) since we don't want false
      // positives.
      continue;
    }

    if (SE.getTypeSizeInBits(CurBECount->getType()) >
        SE.getTypeSizeInBits(NewBECount->getType()))
      NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
    else if (SE.getTypeSizeInBits(CurBECount->getType()) <
             SE.getTypeSizeInBits(NewBECount->getType()))
      CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());

    const SCEV *Delta = GetDelta(CurBECount, NewBECount);
    if (Delta && !Delta->isZero()) {
      dbgs() << "Trip Count for " << *L << " Changed!\n";
      dbgs() << "Old: " << *CurBECount << "\n";
      dbgs() << "New: " << *NewBECount << "\n";
      dbgs() << "Delta: " << *Delta << "\n";
      std::abort();
    }
  }

  // Collect all valid loops currently in LoopInfo.
  SmallPtrSet<Loop *, 32> ValidLoops;
  SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
  while (!Worklist.empty()) {
    Loop *L = Worklist.pop_back_val();
    if (ValidLoops.insert(L).second)
      Worklist.append(L->begin(), L->end());
  }
  for (const auto &KV : ValueExprMap) {
#ifndef NDEBUG
    // Check for SCEV expressions referencing invalid/deleted loops.
    if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
      assert(ValidLoops.contains(AR->getLoop()) &&
             "AddRec references invalid loop");
    }
#endif

    // Check that the value is also part of the reverse map.
    auto It = ExprValueMap.find(KV.second);
    if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
      dbgs() << "Value " << *KV.first
             << " is in ValueExprMap but not in ExprValueMap\n";
      std::abort();
    }

    if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
      if (!ReachableBlocks.contains(I->getParent()))
        continue;
      const SCEV *OldSCEV = SCM.visit(KV.second);
      const SCEV *NewSCEV = SE2.getSCEV(I);
      const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
      if (Delta && !Delta->isZero()) {
        dbgs() << "SCEV for value " << *I << " changed!\n"
               << "Old: " << *OldSCEV << "\n"
               << "New: " << *NewSCEV << "\n"
               << "Delta: " << *Delta << "\n";
        std::abort();
      }
    }
  }

  for (const auto &KV : ExprValueMap) {
    for (Value *V : KV.second) {
      const SCEV *S = ValueExprMap.lookup(V);
      if (!S) {
        dbgs() << "Value " << *V
               << " is in ExprValueMap but not in ValueExprMap\n";
        std::abort();
      }
      if (S != KV.first) {
        dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
               << *KV.first << "\n";
        std::abort();
      }
    }
  }

  // Verify integrity of SCEV users.
  for (const auto &S : UniqueSCEVs) {
    for (const auto *Op : S.operands()) {
      // We do not store dependencies of constants.
      if (isa<SCEVConstant>(Op))
        continue;
      auto It = SCEVUsers.find(Op);
      if (It != SCEVUsers.end() && It->second.count(&S))
        continue;
      dbgs() << "Use of operand " << *Op << " by user " << S
             << " is not being tracked!\n";
      std::abort();
    }
  }

  // Verify integrity of ValuesAtScopes users.
  for (const auto &ValueAndVec : ValuesAtScopes) {
    const SCEV *Value = ValueAndVec.first;
    for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
      const Loop *L = LoopAndValueAtScope.first;
      const SCEV *ValueAtScope = LoopAndValueAtScope.second;
      if (!isa<SCEVConstant>(ValueAtScope)) {
        auto It = ValuesAtScopesUsers.find(ValueAtScope);
        if (It != ValuesAtScopesUsers.end() &&
            is_contained(It->second, std::make_pair(L, Value)))
          continue;
        dbgs() << "Value: " << *Value << ", Loop: " << *L
               << ", ValueAtScope: " << *ValueAtScope
               << " missing in ValuesAtScopesUsers\n";
        std::abort();
      }
    }
  }

  for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
    const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
    for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
      const Loop *L = LoopAndValue.first;
      const SCEV *Value = LoopAndValue.second;
      assert(!isa<SCEVConstant>(Value));
      auto It = ValuesAtScopes.find(Value);
      if (It != ValuesAtScopes.end() &&
          is_contained(It->second, std::make_pair(L, ValueAtScope)))
        continue;
      dbgs() << "Value: " << *Value << ", Loop: " << *L
             << ", ValueAtScope: " << *ValueAtScope
             << " missing in ValuesAtScopes\n";
      std::abort();
    }
  }

  // Verify integrity of BECountUsers.
  auto VerifyBECountUsers = [&](bool Predicated) {
    auto &BECounts =
        Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
    for (const auto &LoopAndBEInfo : BECounts) {
      for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
        for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
          if (!isa<SCEVConstant>(S)) {
            auto UserIt = BECountUsers.find(S);
            if (UserIt != BECountUsers.end() &&
                UserIt->second.contains({LoopAndBEInfo.first, Predicated}))
              continue;
            dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
                   << " missing from BECountUsers\n";
            std::abort();
          }
        }
      }
    }
  };
  VerifyBECountUsers(/* Predicated */ false);
  VerifyBECountUsers(/* Predicated */ true);

  // Verify integrity of the loop disposition cache.
  for (auto &[S, Values] : LoopDispositions) {
    for (auto [Loop, CachedDisposition] : Values) {
      const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
      if (CachedDisposition != RecomputedDisposition) {
        dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
               << " is incorrect: cached " << CachedDisposition << ", actual "
               << RecomputedDisposition << "\n";
        std::abort();
      }
    }
  }

  // Verify integrity of the block disposition cache.
  for (auto &[S, Values] : BlockDispositions) {
    for (auto [BB, CachedDisposition] : Values) {
      const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
      if (CachedDisposition != RecomputedDisposition) {
        dbgs() << "Cached disposition of " << *S << " for block %"
               << BB->getName() << " is incorrect: cached " << CachedDisposition
               << ", actual " << RecomputedDisposition << "\n";
        std::abort();
      }
    }
  }

  // Verify FoldCache/FoldCacheUser caches.
  for (auto [FoldID, Expr] : FoldCache) {
    auto I = FoldCacheUser.find(Expr);
    if (I == FoldCacheUser.end()) {
      dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
             << "!\n";
      std::abort();
    }
    if (!is_contained(I->second, FoldID)) {
      dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
      std::abort();
    }
  }
  for (auto [Expr, IDs] : FoldCacheUser) {
    for (auto &FoldID : IDs) {
      const SCEV *S = FoldCache.lookup(FoldID);
      if (!S) {
        dbgs() << "Missing entry in FoldCache for expression " << *Expr
               << "!\n";
        std::abort();
      }
      if (S != Expr) {
        dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
               << " != " << *Expr << "!\n";
        std::abort();
      }
    }
  }

  // Verify that ConstantMultipleCache computations are correct. We check that
  // cached multiples and recomputed multiples are multiples of each other to
  // verify correctness. It is possible that a recomputed multiple is different
  // from the cached multiple due to strengthened no-wrap flags or changes in
  // KnownBits computations.
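  // For example, a cached multiple of 4 against a recomputed multiple of 8 is
  // consistent (4 divides 8, as can happen when flags are strengthened),
  // whereas a cached 6 against a recomputed 4 would indicate a stale entry.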
  for (auto [S, Multiple] : ConstantMultipleCache) {
    APInt RecomputedMultiple = SE2.getConstantMultiple(S);
    if ((Multiple != 0 && RecomputedMultiple != 0 &&
         Multiple.urem(RecomputedMultiple) != 0 &&
         RecomputedMultiple.urem(Multiple) != 0)) {
      dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
             << *S << " : Computed " << RecomputedMultiple
             << " but cache contains " << Multiple << "!\n";
      std::abort();
    }
  }
}

bool ScalarEvolution::invalidate(
    Function &F, const PreservedAnalyses &PA,
    FunctionAnalysisManager::Invalidator &Inv) {
  // Invalidate the ScalarEvolution object whenever it isn't preserved or one
  // of its dependencies is invalidated.
  auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
         Inv.invalidate<AssumptionAnalysis>(F, PA) ||
         Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
         Inv.invalidate<LoopAnalysis>(F, PA);
}

AnalysisKey ScalarEvolutionAnalysis::Key;

ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &AC = AM.getResult<AssumptionAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);
  return ScalarEvolution(F, TLI, AC, DT, LI);
}

PreservedAnalyses
ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
  AM.getResult<ScalarEvolutionAnalysis>(F).verify();
  return PreservedAnalyses::all();
}

PreservedAnalyses
ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
  // For compatibility with opt's -analyze feature under the legacy pass
  // manager, which was not ported to NPM. This keeps tests using
  // update_analyze_test_checks.py working.
  OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
     << F.getName() << "':\n";
  AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
  return PreservedAnalyses::all();
}

INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
                      "Scalar Evolution Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
                    "Scalar Evolution Analysis", false, true)

char ScalarEvolutionWrapperPass::ID = 0;

ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {}

bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
  SE.reset(new ScalarEvolution(
      F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
      getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
      getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
  return false;
}

void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }

void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
  SE->print(OS);
}

void ScalarEvolutionWrapperPass::verifyAnalysis() const {
  if (!VerifySCEV)
    return;

  SE->verify();
}

void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequiredTransitive<AssumptionCacheTracker>();
  AU.addRequiredTransitive<LoopInfoWrapperPass>();
  AU.addRequiredTransitive<DominatorTreeWrapperPass>();
  AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}

const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
                                                        const SCEV *RHS) {
  return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
}

const SCEVPredicate *
ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
                                     const SCEV *LHS, const SCEV *RHS) {
  FoldingSetNodeID ID;
  assert(LHS->getType() == RHS->getType() &&
         "Type mismatch between LHS and RHS");
  // Unique this node based on the arguments.
  ID.AddInteger(SCEVPredicate::P_Compare);
  ID.AddInteger(Pred);
  ID.AddPointer(LHS);
  ID.AddPointer(RHS);
  void *IP = nullptr;
  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
    return S;
  SCEVComparePredicate *Eq = new (SCEVAllocator)
      SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
  UniquePreds.InsertNode(Eq, IP);
  return Eq;
}

const SCEVPredicate *ScalarEvolution::getWrapPredicate(
    const SCEVAddRecExpr *AR,
    SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
  FoldingSetNodeID ID;
  // Unique this node based on the arguments.
  ID.AddInteger(SCEVPredicate::P_Wrap);
  ID.AddPointer(AR);
  ID.AddInteger(AddedFlags);
  void *IP = nullptr;
  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
    return S;
  auto *OF = new (SCEVAllocator)
      SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
  UniquePreds.InsertNode(OF, IP);
  return OF;
}

namespace {

class SCEVPredicateRewriter
    : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:

  /// Rewrites \p S in the context of a loop L and the SCEV predication
  /// infrastructure.
  ///
  /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
  /// equivalences present in \p Pred.
  ///
  /// If \p NewPreds is non-null, the rewrite is free to add further predicates
  /// to \p NewPreds such that the result will be an AddRecExpr.
  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
                             SmallVectorImpl<const SCEVPredicate *> *NewPreds,
                             const SCEVPredicate *Pred) {
    SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
    return Rewriter.visit(S);
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    if (Pred) {
      if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
        for (const auto *P : U->getPredicates())
          if (const auto *IPred = dyn_cast<SCEVComparePredicate>(P))
            if (IPred->getLHS() == Expr &&
                IPred->getPredicate() == ICmpInst::ICMP_EQ)
              return IPred->getRHS();
      } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
        if (IPred->getLHS() == Expr &&
            IPred->getPredicate() == ICmpInst::ICMP_EQ)
          return IPred->getRHS();
      }
    }
    return convertToAddRecWithPreds(Expr);
  }

  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
    const SCEV *Operand = visit(Expr->getOperand());
    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
    if (AR && AR->getLoop() == L && AR->isAffine()) {
      // This couldn't be folded because the operand didn't have the nuw
      // flag. Add the nusw flag as an assumption that we could make.
      const SCEV *Step = AR->getStepRecurrence(SE);
      Type *Ty = Expr->getType();
      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
        return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
                                SE.getSignExtendExpr(Step, Ty), L,
                                AR->getNoWrapFlags());
    }
    return SE.getZeroExtendExpr(Operand, Expr->getType());
  }

  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
    const SCEV *Operand = visit(Expr->getOperand());
    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
    if (AR && AR->getLoop() == L && AR->isAffine()) {
      // This couldn't be folded because the operand didn't have the nsw
      // flag. Add the nssw flag as an assumption that we could make.
      const SCEV *Step = AR->getStepRecurrence(SE);
      Type *Ty = Expr->getType();
      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
        return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
                                SE.getSignExtendExpr(Step, Ty), L,
                                AR->getNoWrapFlags());
    }
    return SE.getSignExtendExpr(Operand, Expr->getType());
  }

private:
  explicit SCEVPredicateRewriter(
      const Loop *L, ScalarEvolution &SE,
      SmallVectorImpl<const SCEVPredicate *> *NewPreds,
      const SCEVPredicate *Pred)
      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}

  bool addOverflowAssumption(const SCEVPredicate *P) {
    if (!NewPreds) {
      // Check if we've already made this assumption.
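      // In check-only mode (no NewPreds vector to record into), the rewrite
      // may only proceed where the given predicate already covers the
      // requested assumption. For example, if \p Pred already contains a
      // <nusw> wrap predicate for this AddRec, a second <nusw> request for
      // the same AddRec is implied and we return true without recording
      // anything.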
      return Pred && Pred->implies(P, SE);
    }
    NewPreds->push_back(P);
    return true;
  }

  bool addOverflowAssumption(const SCEVAddRecExpr *AR,
                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
    auto *A = SE.getWrapPredicate(AR, AddedFlags);
    return addOverflowAssumption(A);
  }

  // If \p Expr represents a PHINode, we try to see if it can be represented
  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
  // to add this predicate as a runtime overflow check, we return the AddRec.
  // If \p Expr does not meet these conditions (is not a PHI node, or we
  // couldn't create an AddRec for it, or couldn't add the predicate), we just
  // return \p Expr.
  const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
    if (!isa<PHINode>(Expr->getValue()))
      return Expr;
    std::optional<
        std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
        PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
    if (!PredicatedRewrite)
      return Expr;
    for (const auto *P : PredicatedRewrite->second) {
      // Wrap predicates from outer loops are not supported.
      if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
        if (L != WP->getExpr()->getLoop())
          return Expr;
      }
      if (!addOverflowAssumption(P))
        return Expr;
    }
    return PredicatedRewrite->first;
  }

  SmallVectorImpl<const SCEVPredicate *> *NewPreds;
  const SCEVPredicate *Pred;
  const Loop *L;
};

} // end anonymous namespace

const SCEV *
ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
                                       const SCEVPredicate &Preds) {
  return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
}

const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
    const SCEV *S, const Loop *L,
    SmallVectorImpl<const SCEVPredicate *> &Preds) {
  SmallVector<const SCEVPredicate *> TransformPreds;
  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);

  if (!AddRec)
    return nullptr;

  // Since the transformation was successful, we can now transfer the SCEV
  // predicates.
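  // A minimal caller-side sketch (hypothetical code): the returned AddRec is
  // only valid if runtime checks for all accumulated predicates are emitted,
  // e.g.:
  //
  //   SmallVector<const SCEVPredicate *, 4> Preds;
  //   if (const SCEVAddRecExpr *AR =
  //           SE.convertSCEVToAddRecWithPredicates(S, L, Preds)) {
  //     // ... use AR, materializing checks for Preds before entering L.
  //   }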
  Preds.append(TransformPreds.begin(), TransformPreds.end());

  return AddRec;
}

/// SCEV predicates
SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
                             SCEVPredicateKind Kind)
    : FastID(ID), Kind(Kind) {}

SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
                                           const ICmpInst::Predicate Pred,
                                           const SCEV *LHS, const SCEV *RHS)
    : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
  assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
  assert(LHS != RHS && "LHS and RHS are the same SCEV");
}

bool SCEVComparePredicate::implies(const SCEVPredicate *N,
                                   ScalarEvolution &SE) const {
  const auto *Op = dyn_cast<SCEVComparePredicate>(N);

  if (!Op)
    return false;

  if (Pred != ICmpInst::ICMP_EQ)
    return false;

  return Op->LHS == LHS && Op->RHS == RHS;
}

bool SCEVComparePredicate::isAlwaysTrue() const { return false; }

void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
  if (Pred == ICmpInst::ICMP_EQ)
    OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
  else
    OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << " "
                     << *RHS << "\n";
}

SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
                                     const SCEVAddRecExpr *AR,
                                     IncrementWrapFlags Flags)
    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}

const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }

bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
                                ScalarEvolution &SE) const {
  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
  if (!Op || setFlags(Flags, Op->Flags) != Flags)
    return false;

  if (Op->AR == AR)
    return true;

  if (Flags != SCEVWrapPredicate::IncrementNSSW &&
      Flags != SCEVWrapPredicate::IncrementNUSW)
    return false;

  const SCEV *Start = AR->getStart();
  const SCEV *OpStart = Op->AR->getStart();
  if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
    return false;

  // Reject pointers to different address spaces.
  if (Start->getType()->isPointerTy() &&
      Start->getType() != OpStart->getType())
    return false;

  const SCEV *Step = AR->getStepRecurrence(SE);
  const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
  if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
    return false;

  // If both steps are positive, this predicate implies N if N's start and
  // step are ULE/SLE (for NUSW/NSSW) compared to this predicate's.
  Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
  Step = SE.getNoopOrZeroExtend(Step, WiderTy);
  OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);

  bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
  OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
                  : SE.getNoopOrSignExtend(OpStart, WiderTy);
  Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
                : SE.getNoopOrSignExtend(Start, WiderTy);
  CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
  return SE.isKnownPredicate(Pred, OpStep, Step) &&
         SE.isKnownPredicate(Pred, OpStart, Start);
}

bool SCEVWrapPredicate::isAlwaysTrue() const {
  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
  IncrementWrapFlags IFlags = Flags;

  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
    IFlags = clearFlags(IFlags, IncrementNSSW);

  return IFlags == IncrementAnyWrap;
}

void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
  OS.indent(Depth) << *getExpr() << " Added Flags: ";
  if (SCEVWrapPredicate::IncrementNUSW & getFlags())
    OS << "<nusw>";
  if (SCEVWrapPredicate::IncrementNSSW & getFlags())
    OS << "<nssw>";
  OS << "\n";
}

SCEVWrapPredicate::IncrementWrapFlags
SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
                                   ScalarEvolution &SE) {
  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();

  // We can safely transfer the NSW flag as NSSW.
  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
    ImpliedFlags = IncrementNSSW;

  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
    // If the increment is positive, the SCEV NUW flag will also imply the
    // WrapPredicate NUSW flag.
    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
      if (Step->getValue()->getValue().isNonNegative())
        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
  }

  return ImpliedFlags;
}

/// Union predicates don't get cached, so create a dummy set ID for them.
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
                                       ScalarEvolution &SE)
    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
  for (const auto *P : Preds)
    add(P, SE);
}

bool SCEVUnionPredicate::isAlwaysTrue() const {
  return all_of(Preds,
                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}

bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
                                 ScalarEvolution &SE) const {
  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
    return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
      return this->implies(I, SE);
    });

  return any_of(Preds,
                [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}

void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
  for (const auto *Pred : Preds)
    Pred->print(OS, Depth);
}

void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
    for (const auto *Pred : Set->Preds)
      add(Pred, SE);
    return;
  }

  // Implication checks are quadratic in the number of predicates. Stop doing
  // them if there are many predicates, as they would be too expensive to be
  // useful at that point.
  bool CheckImplies = Preds.size() < 16;

  // Only add the predicate if it is not already implied by this union
  // predicate.
  if (CheckImplies && implies(N, SE))
    return;

  // Build a new vector containing the current predicates, except the ones that
  // are implied by the new predicate N.
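  // For example, adding a <nusw><nssw> wrap predicate for an AddRec prunes a
  // previously added <nssw>-only predicate for the same AddRec, since the
  // stronger flag set implies the weaker one.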
  SmallVector<const SCEVPredicate *> PrunedPreds;
  for (auto *P : Preds) {
    if (CheckImplies && N->implies(P, SE))
      continue;
    PrunedPreds.push_back(P);
  }
  Preds = std::move(PrunedPreds);
  Preds.push_back(N);
}

PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
                                                     Loop &L)
    : SE(SE), L(L) {
  SmallVector<const SCEVPredicate *, 4> Empty;
  Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}

void ScalarEvolution::registerUser(const SCEV *User,
                                   ArrayRef<const SCEV *> Ops) {
  for (const auto *Op : Ops)
    // We do not expect that forgetting cached data for SCEVConstants will ever
    // open any prospects for sharpening or introduce any correctness issues,
    // so we don't bother storing their dependencies.
    if (!isa<SCEVConstant>(Op))
      SCEVUsers[Op].insert(User);
}

const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
  const SCEV *Expr = SE.getSCEV(V);
  RewriteEntry &Entry = RewriteMap[Expr];

  // If we already have an entry and the version matches, return it.
  if (Entry.second && Generation == Entry.first)
    return Entry.second;

  // We found an entry but it's stale. Rewrite the stale entry
  // according to the current predicate.
  if (Entry.second)
    Expr = Entry.second;

  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
  Entry = {Generation, NewSCEV};

  return NewSCEV;
}

const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
  if (!BackedgeCount) {
    SmallVector<const SCEVPredicate *, 4> Preds;
    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
    for (const auto *P : Preds)
      addPredicate(*P);
  }
  return BackedgeCount;
}

const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
  if (!SymbolicMaxBackedgeCount) {
    SmallVector<const SCEVPredicate *, 4> Preds;
    SymbolicMaxBackedgeCount =
        SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
    for (const auto *P : Preds)
      addPredicate(*P);
  }
  return SymbolicMaxBackedgeCount;
}

unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
  if (!SmallConstantMaxTripCount) {
    SmallVector<const SCEVPredicate *, 4> Preds;
    SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
    for (const auto *P : Preds)
      addPredicate(*P);
  }
  return *SmallConstantMaxTripCount;
}

void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
  if (Preds->implies(&Pred, SE))
    return;

  SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
  NewPreds.push_back(&Pred);
  Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
  updateGeneration();
}

const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
  return *Preds;
}

void PredicatedScalarEvolution::updateGeneration() {
  // If the generation number wrapped, recompute everything.
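  // (Generation is a plain unsigned counter, so this only triggers on
  // overflow; re-rewriting every cached entry keeps stale generation tags
  // from aliasing the new generation.)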
  if (++Generation == 0) {
    for (auto &II : RewriteMap) {
      const SCEV *Rewritten = II.second.second;
      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
    }
  }
}

void PredicatedScalarEvolution::setNoOverflow(
    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
  const SCEV *Expr = getSCEV(V);
  const auto *AR = cast<SCEVAddRecExpr>(Expr);

  auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);

  // Clear the statically implied flags.
  Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
  addPredicate(*SE.getWrapPredicate(AR, Flags));

  auto II = FlagsMap.insert({V, Flags});
  if (!II.second)
    II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
}

bool PredicatedScalarEvolution::hasNoOverflow(
    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
  const SCEV *Expr = getSCEV(V);
  const auto *AR = cast<SCEVAddRecExpr>(Expr);

  Flags = SCEVWrapPredicate::clearFlags(
      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));

  auto II = FlagsMap.find(V);

  if (II != FlagsMap.end())
    Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);

  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
}

const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
  const SCEV *Expr = this->getSCEV(V);
  SmallVector<const SCEVPredicate *, 4> NewPreds;
  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);

  if (!New)
    return nullptr;

  for (const auto *P : NewPreds)
    addPredicate(*P);

  RewriteMap[SE.getSCEV(V)] = {Generation, New};
  return New;
}

PredicatedScalarEvolution::PredicatedScalarEvolution(
    const PredicatedScalarEvolution &Init)
    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
      Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
                                                 SE)),
      Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
  for (auto I : Init.FlagsMap)
    FlagsMap.insert(I);
}

void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
  // For each block.
  for (auto *BB : L.getBlocks())
    for (auto &I : *BB) {
      if (!SE.isSCEVable(I.getType()))
        continue;

      auto *Expr = SE.getSCEV(&I);
      auto II = RewriteMap.find(Expr);

      if (II == RewriteMap.end())
        continue;

      // Don't print things that are not interesting.
      if (II->second.second == Expr)
        continue;

      OS.indent(Depth) << "[PSE]" << I << ":\n";
      OS.indent(Depth + 2) << *Expr << "\n";
      OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
    }
}

// Match the mathematical pattern A - (A / B) * B, where A and B can be
// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
// for URem with constant power-of-2 second operands.
// It's not always easy, as A and B can be folded (imagine A is X / 2 and B is
// 4; A / B then becomes X / 8).
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
                                const SCEV *&RHS) {
  if (Expr->getType()->isPointerTy())
    return false;

  // Try to match 'zext (trunc A to iB) to iY', which is used
  // for URem with constant power-of-2 second operands. Make sure the size of
  // the operand A matches the size of the whole expression.
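  // For example, on an i32 value %a, (%a urem 8) is commonly represented as
  // (zext i3 (trunc i32 %a to i3) to i32); from that shape we recover
  // LHS = %a and RHS = 8, i.e. 1 << 3 for the 3-bit truncated type.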
  if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
    if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
      LHS = Trunc->getOperand();
      // Bail out if the type of the LHS is larger than the type of the
      // expression for now.
      if (getTypeSizeInBits(LHS->getType()) >
          getTypeSizeInBits(Expr->getType()))
        return false;
      if (LHS->getType() != Expr->getType())
        LHS = getZeroExtendExpr(LHS, Expr->getType());
      RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
                        << getTypeSizeInBits(Trunc->getType()));
      return true;
    }
  const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
  if (Add == nullptr || Add->getNumOperands() != 2)
    return false;

  const SCEV *A = Add->getOperand(1);
  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));

  if (Mul == nullptr)
    return false;

  const auto MatchURemWithDivisor = [&](const SCEV *B) {
    // (SomeExpr + (-(SomeExpr / B) * B)).
    if (Expr == getURemExpr(A, B)) {
      LHS = A;
      RHS = B;
      return true;
    }
    return false;
  };

  // (SomeExpr + (-1 * (SomeExpr / B) * B)).
  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
    return MatchURemWithDivisor(Mul->getOperand(1)) ||
           MatchURemWithDivisor(Mul->getOperand(2));

  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
  if (Mul->getNumOperands() == 2)
    return MatchURemWithDivisor(Mul->getOperand(1)) ||
           MatchURemWithDivisor(Mul->getOperand(0)) ||
           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
  return false;
}

ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
  BasicBlock *Header = L->getHeader();
  BasicBlock *Pred = L->getLoopPredecessor();
  LoopGuards Guards(SE);
  if (!Pred)
    return Guards;
  SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
  collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
  return Guards;
}

void ScalarEvolution::LoopGuards::collectFromPHI(
    ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
    const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
    SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
    unsigned Depth) {
  if (!SE.isSCEVable(Phi.getType()))
    return;

  using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
  auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
    const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
    if (!VisitedBlocks.insert(InBlock).second)
      return {nullptr, scCouldNotCompute};
    auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
    if (Inserted)
      collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
                       Depth + 1);
    auto &RewriteMap = G->second.RewriteMap;
    if (RewriteMap.empty())
      return {nullptr, scCouldNotCompute};
    auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
    if (S == RewriteMap.end())
      return {nullptr, scCouldNotCompute};
    auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
    if (!SM)
      return {nullptr, scCouldNotCompute};
    if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
      return {C0, SM->getSCEVType()};
    return {nullptr, scCouldNotCompute};
  };
  auto MergeMinMaxConst = [](MinMaxPattern P1,
                             MinMaxPattern P2) -> MinMaxPattern {
    auto [C1, T1] = P1;
    auto [C2, T2] = P2;
    if (!C1 || !C2 || T1 != T2)
      return {nullptr, scCouldNotCompute};
    switch (T1) {
    case scUMaxExpr:
      return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
    case scSMaxExpr:
      return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
    case scUMinExpr:
      return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
    case scSMinExpr:
      return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
    default:
      llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
    }
  };
  auto P = GetMinMaxConst(0);
  for (unsigned In = 1; In < Phi.getNumIncomingValues(); In++) {
    if (!P.first)
      break;
    P = MergeMinMaxConst(P, GetMinMaxConst(In));
  }
  if (P.first) {
    const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
    SmallVector<const SCEV *, 2> Ops({P.first, LHS});
    const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
    Guards.RewriteMap.insert({LHS, RHS});
  }
}

void ScalarEvolution::LoopGuards::collectFromBlock(
    ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
    const BasicBlock *Block, const BasicBlock *Pred,
    SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
  SmallVector<const SCEV *> ExprsToRewrite;
  auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                              const SCEV *RHS,
                              DenseMap<const SCEV *, const SCEV *>
                                  &RewriteMap) {
    // WARNING: It is generally unsound to apply any wrap flags to the proposed
    // replacement SCEV that aren't directly implied by the structure of that
    // SCEV. In particular, using contextual facts to imply flags is *NOT*
    // legal. See the scoping rules for flags in the header to understand why.

    // If LHS is a constant, apply information to the other expression.
    if (isa<SCEVConstant>(LHS)) {
      std::swap(LHS, RHS);
      Predicate = CmpInst::getSwappedPredicate(Predicate);
    }

    // Check for a condition of the form (-C1 + X < C2). InstCombine will
    // create this form when combining two checks of the form (X u< C2 + C1)
    // and (X >=u C1).
    auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
                                 &ExprsToRewrite]() {
      const SCEVConstant *C1;
      const SCEVUnknown *LHSUnknown;
      auto *C2 = dyn_cast<SCEVConstant>(RHS);
      if (!match(LHS,
                 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
          !C2)
        return false;

      auto ExactRegion =
          ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
              .sub(C1->getAPInt());

      // Bail out, unless we have a non-wrapping, monotonic range.
      if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
        return false;
      auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
      const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
      I->second = SE.getUMaxExpr(
          SE.getConstant(ExactRegion.getUnsignedMin()),
          SE.getUMinExpr(RewrittenLHS,
                         SE.getConstant(ExactRegion.getUnsignedMax())));
      ExprsToRewrite.push_back(LHSUnknown);
      return true;
    };
    if (MatchRangeCheckIdiom())
      return;

    // Return true if \p Expr is a MinMax SCEV expression with a non-negative
    // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
    // the non-constant operand and in \p LHS the constant operand.
    auto IsMinMaxSCEVWithNonNegativeConstant =
        [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
            const SCEV *&RHS) {
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
            if (MinMax->getNumOperands() != 2)
              return false;
            if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
              if (C->getAPInt().isNegative())
                return false;
              SCTy = MinMax->getSCEVType();
              LHS = MinMax->getOperand(0);
              RHS = MinMax->getOperand(1);
              return true;
            }
          }
          return false;
        };

    // Checks whether Expr is a non-negative constant and Divisor is a positive
    // constant, and returns their APInt values in ExprVal and DivisorVal.
    auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
                                          APInt &ExprVal, APInt &DivisorVal) {
      auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
      auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
      if (!ConstExpr || !ConstDivisor)
        return false;
      ExprVal = ConstExpr->getAPInt();
      DivisorVal = ConstDivisor->getAPInt();
      return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
    };

    // Return a new SCEV that modifies \p Expr to the closest number that is
    // divisible by \p Divisor and greater than or equal to Expr.
    // For now, only handle constant Expr and Divisor.
    auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
                                           const SCEV *Divisor) {
      APInt ExprVal;
      APInt DivisorVal;
      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
        return Expr;
      APInt Rem = ExprVal.urem(DivisorVal);
      if (!Rem.isZero())
        // Return the SCEV: Expr + Divisor - Expr % Divisor.
        return SE.getConstant(ExprVal + DivisorVal - Rem);
      return Expr;
    };

    // Return a new SCEV that modifies \p Expr to the closest number that is
    // divisible by \p Divisor and less than or equal to Expr.
    // For now, only handle constant Expr and Divisor.
    auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
                                               const SCEV *Divisor) {
      APInt ExprVal;
      APInt DivisorVal;
      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
        return Expr;
      APInt Rem = ExprVal.urem(DivisorVal);
      // Return the SCEV: Expr - Expr % Divisor.
      return SE.getConstant(ExprVal - Rem);
    };

    // Apply divisibility by \p Divisor on the MinMaxExpr with constant values,
    // recursively. This is done by aligning up/down the constant value to the
    // Divisor.
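    // For example, with Divisor = 8: umin(30, X) becomes umin(24, X') since
    // 30 is aligned down to 24, while umax(30, X) becomes umax(32, X') since
    // 30 is aligned up to 32; X' is the result of recursing into the other
    // operand.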
    std::function<const SCEV *(const SCEV *, const SCEV *)>
        ApplyDivisibilityOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
                                            const SCEV *Divisor) {
          const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
          SCEVTypes SCTy;
          if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
                                                   MinMaxRHS))
            return MinMaxExpr;
          auto IsMin =
              isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
          assert(SE.isKnownNonNegative(MinMaxLHS) &&
                 "Expected non-negative operand!");
          auto *DivisibleExpr =
              IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
                    : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
          SmallVector<const SCEV *> Ops = {
              ApplyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
          return SE.getMinMaxExpr(SCTy, Ops);
        };

    // If we have LHS == 0, check if LHS is computing a property of some
    // unknown SCEV %v, which we can rewrite to express that property
    // explicitly.
    if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
      // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
      // explicitly express that.
      const SCEV *URemLHS = nullptr;
      const SCEV *URemRHS = nullptr;
      if (SE.matchURem(LHS, URemLHS, URemRHS)) {
        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
          auto I = RewriteMap.find(LHSUnknown);
          const SCEV *RewrittenLHS =
              I != RewriteMap.end() ? I->second : LHSUnknown;
          RewrittenLHS = ApplyDivisibilityOnMinMaxExpr(RewrittenLHS, URemRHS);
          const auto *Multiple =
              SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
          RewriteMap[LHSUnknown] = Multiple;
          ExprsToRewrite.push_back(LHSUnknown);
          return;
        }
      }
    }

    // Do not apply information for constants or if RHS contains an AddRec.
    if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
      return;

    // If RHS is SCEVUnknown, make sure the information is applied to it.
    if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
      std::swap(LHS, RHS);
      Predicate = CmpInst::getSwappedPredicate(Predicate);
    }

    // Puts the rewrite rule \p From -> \p To into the rewrite map. Also, if
    // \p From and \p FromRewritten are the same (i.e. there has been no
    // rewrite registered for \p From), then puts this value in the list of
    // rewritten expressions.
    auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
                          const SCEV *To) {
      if (From == FromRewritten)
        ExprsToRewrite.push_back(From);
      RewriteMap[From] = To;
    };

    // Checks whether \p S has already been rewritten. In that case returns the
    // existing rewrite because we want to chain further rewrites onto the
    // already rewritten value. Otherwise returns \p S.
    auto GetMaybeRewritten = [&](const SCEV *S) {
      return RewriteMap.lookup_or(S, S);
    };

    // Check for the SCEV expression (A /u B) * B, where B is a constant,
    // inside \p Expr. The check is done recursively on \p Expr, which is
    // assumed to be a composition of Min/Max SCEVs. Return whether the SCEV
    // expression (A /u B) * B was found, and return the divisor B in
    // \p DividesBy. For example, if Expr = umin (umax ((A /u 8) * 8, 16), 64),
    // return true since (A /u 8) * 8 matched the pattern, and return the
    // constant SCEV 8 in \p DividesBy.
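    // Note that finding one (A /u B) * B arm via HasDivisibilityInfo is not
    // sufficient on its own; IsKnownToDivideBy below additionally checks that
    // every min/max arm is divisible by B (in the example above, 16 and 64
    // are themselves multiples of 8).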
    std::function<bool(const SCEV *, const SCEV *&)> HasDivisibilityInfo =
        [&](const SCEV *Expr, const SCEV *&DividesBy) {
          if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
            if (Mul->getNumOperands() != 2)
              return false;
            auto *MulLHS = Mul->getOperand(0);
            auto *MulRHS = Mul->getOperand(1);
            if (isa<SCEVConstant>(MulLHS))
              std::swap(MulLHS, MulRHS);
            if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
              if (Div->getOperand(1) == MulRHS) {
                DividesBy = MulRHS;
                return true;
              }
          }
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
            return HasDivisibilityInfo(MinMax->getOperand(0), DividesBy) ||
                   HasDivisibilityInfo(MinMax->getOperand(1), DividesBy);
          return false;
        };

    // Return true if \p Expr is known to be divisible by \p DividesBy.
    std::function<bool(const SCEV *, const SCEV *)> IsKnownToDivideBy =
        [&](const SCEV *Expr, const SCEV *DividesBy) {
          if (SE.getURemExpr(Expr, DividesBy)->isZero())
            return true;
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
            return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
                   IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
          return false;
        };

    const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
    const SCEV *DividesBy = nullptr;
    if (HasDivisibilityInfo(RewrittenLHS, DividesBy))
      // Check that the whole expression is divisible by DividesBy.
      DividesBy =
          IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;

    // Collect rewrites for LHS and its transitive operands based on the
    // condition.
    // For min/max expressions, also apply the guard to its operands:
    //   'min(a, b) >= c' -> '(a >= c) and (b >= c)',
    //   'min(a, b) >  c' -> '(a >  c) and (b >  c)',
    //   'max(a, b) <= c' -> '(a <= c) and (b <= c)',
    //   'max(a, b) <  c' -> '(a <  c) and (b <  c)'.

    // We cannot express strict predicates in SCEV, so instead we replace them
    // with non-strict ones against plus or minus one of RHS depending on the
    // predicate.
    const SCEV *One = SE.getOne(RHS->getType());
    switch (Predicate) {
    case CmpInst::ICMP_ULT:
      if (RHS->getType()->isPointerTy())
        return;
      RHS = SE.getUMaxExpr(RHS, One);
      [[fallthrough]];
    case CmpInst::ICMP_SLT: {
      RHS = SE.getMinusSCEV(RHS, One);
      RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    }
    case CmpInst::ICMP_UGT:
    case CmpInst::ICMP_SGT:
      RHS = SE.getAddExpr(RHS, One);
      RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    case CmpInst::ICMP_ULE:
    case CmpInst::ICMP_SLE:
      RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    case CmpInst::ICMP_UGE:
    case CmpInst::ICMP_SGE:
      RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
      break;
    default:
      break;
    }

    SmallVector<const SCEV *, 16> Worklist(1, LHS);
    SmallPtrSet<const SCEV *, 16> Visited;

    auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
      append_range(Worklist, S->operands());
    };

    while (!Worklist.empty()) {
      const SCEV *From = Worklist.pop_back_val();
      if (isa<SCEVConstant>(From))
        continue;
      if (!Visited.insert(From).second)
        continue;
      const SCEV *FromRewritten = GetMaybeRewritten(From);
      const SCEV *To = nullptr;

      switch (Predicate) {
      case CmpInst::ICMP_ULT:
      case CmpInst::ICMP_ULE:
        To = SE.getUMinExpr(FromRewritten, RHS);
        if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
          EnqueueOperands(UMax);
        break;
      case CmpInst::ICMP_SLT:
      case CmpInst::ICMP_SLE:
        To = SE.getSMinExpr(FromRewritten, RHS);
        if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
          EnqueueOperands(SMax);
        break;
      case CmpInst::ICMP_UGT:
      case CmpInst::ICMP_UGE:
        To = SE.getUMaxExpr(FromRewritten, RHS);
        if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
          EnqueueOperands(UMin);
        break;
      case CmpInst::ICMP_SGT:
      case CmpInst::ICMP_SGE:
        To = SE.getSMaxExpr(FromRewritten, RHS);
        if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
          EnqueueOperands(SMin);
        break;
      case CmpInst::ICMP_EQ:
        if (isa<SCEVConstant>(RHS))
          To = RHS;
        break;
      case CmpInst::ICMP_NE:
        if (match(RHS, m_scev_Zero())) {
          const SCEV *OneAlignedUp =
              DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
          To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
        }
        break;
      default:
        break;
      }

      if (To)
        AddRewrite(From, FromRewritten, To);
    }
  };

  SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
  // First, collect information from assumptions dominating the loop.
  for (auto &AssumeVH : SE.AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *AssumeI = cast<CallInst>(AssumeVH);
    if (!SE.DT.dominates(AssumeI, Block))
      continue;
    Terms.emplace_back(AssumeI->getOperand(0), true);
  }

  // Second, collect information from llvm.experimental.guards dominating the
  // loop.
  auto *GuardDecl = Intrinsic::getDeclarationIfExists(
      SE.F.getParent(), Intrinsic::experimental_guard);
  if (GuardDecl)
    for (const auto *GU : GuardDecl->users())
      if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
        if (Guard->getFunction() == Block->getParent() &&
            SE.DT.dominates(Guard, Block))
          Terms.emplace_back(Guard->getArgOperand(0), true);

  // Third, collect conditions from dominating branches. Starting at the loop
  // predecessor, climb up the predecessor chain as long as we can find
  // predecessors that have unique successors leading to the original header.
  // TODO: share this logic with isLoopEntryGuardedByCond.
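  // As a sketch, in a CFG like
  //   entry:  br i1 %c1, label %check, label %exit
  //   check:  br i1 %c2, label %ph, label %exit
  //   ph:     br label %header
  // starting from the loop predecessor %ph we climb through %check and
  // %entry, collecting %c2 and then %c1, because each block on the chain has
  // a unique successor leading to the header.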
  unsigned NumCollectedConditions = 0;
  VisitedBlocks.insert(Block);
  std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
  for (; Pair.first;
       Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
    VisitedBlocks.insert(Pair.second);
    const BranchInst *LoopEntryPredicate =
        dyn_cast<BranchInst>(Pair.first->getTerminator());
    if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
      continue;

    Terms.emplace_back(LoopEntryPredicate->getCondition(),
                       LoopEntryPredicate->getSuccessor(0) == Pair.second);
    NumCollectedConditions++;

    // If we are recursively collecting guards, stop after 2 conditions to
    // limit compile-time impact for now.
    if (Depth > 0 && NumCollectedConditions == 2)
      break;
  }
  // Finally, if we stopped climbing the predecessor chain because there wasn't
  // a unique predecessor to continue with, try to collect conditions for
  // PHINodes by recursively following all of their incoming blocks, and try to
  // merge the found conditions into a new one for the Phi.
  if (Pair.second->hasNPredecessorsOrMore(2) &&
      Depth < MaxLoopGuardCollectionDepth) {
    SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
    for (auto &Phi : Pair.second->phis())
      collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
  }

  // Now apply the information from the collected conditions to
  // Guards.RewriteMap. Conditions are processed in reverse order, so the
  // earliest condition is processed first. This ensures the SCEVs with the
  // shortest dependency chains are constructed first.
  for (auto [Term, EnterIfTrue] : reverse(Terms)) {
    SmallVector<Value *, 8> Worklist;
    SmallPtrSet<Value *, 8> Visited;
    Worklist.push_back(Term);
    while (!Worklist.empty()) {
      Value *Cond = Worklist.pop_back_val();
      if (!Visited.insert(Cond).second)
        continue;

      if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
        auto Predicate =
            EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
        const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
        const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
        CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
        continue;
      }

      Value *L, *R;
      if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
                      : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
        Worklist.push_back(L);
        Worklist.push_back(R);
      }
    }
  }

  // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
  // the replacement expressions are contained in the ranges of the replaced
  // expressions.
  Guards.PreserveNUW = true;
  Guards.PreserveNSW = true;
  for (const SCEV *Expr : ExprsToRewrite) {
    const SCEV *RewriteTo = Guards.RewriteMap[Expr];
    Guards.PreserveNUW &=
        SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
    Guards.PreserveNSW &=
        SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
  }

  // Now that all rewrite information is collected, rewrite the collected
  // expressions with the information in the map. This applies information to
  // sub-expressions.
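  // E.g. if the map contains %a -> umin(%a, 16) and %b -> umax(%b, %a), the
  // entry for %b is re-rewritten to umax(%b, umin(%a, 16)). Erasing %b's
  // entry before rewriting its replacement also keeps the rewrite from
  // re-entering %b's own mapping.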
  if (ExprsToRewrite.size() > 1) {
    for (const SCEV *Expr : ExprsToRewrite) {
      const SCEV *RewriteTo = Guards.RewriteMap[Expr];
      Guards.RewriteMap.erase(Expr);
      Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
    }
  }
}

const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
  /// A rewriter that replaces SCEV expressions with the corresponding entry in
  /// the rewrite map. It skips AddRecExpr because we cannot guarantee that the
  /// replacement is loop invariant in the loop of the AddRec.
  class SCEVLoopGuardRewriter
      : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
    const DenseMap<const SCEV *, const SCEV *> &Map;

    SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;

  public:
    SCEVLoopGuardRewriter(ScalarEvolution &SE,
                          const ScalarEvolution::LoopGuards &Guards)
        : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
      if (Guards.PreserveNUW)
        FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
      if (Guards.PreserveNSW)
        FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
    }

    const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      return Map.lookup_or(Expr, Expr);
    }

    const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
      if (const SCEV *S = Map.lookup(Expr))
        return S;

      // If we didn't find the exact ZExt expr in the map, check if there's
      // an entry for a smaller ZExt we can use instead.
      Type *Ty = Expr->getType();
      const SCEV *Op = Expr->getOperand(0);
      unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
      while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
             Bitwidth > Op->getType()->getScalarSizeInBits()) {
        Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
        auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
        if (const SCEV *S = Map.lookup(NarrowExt))
          return SE.getZeroExtendExpr(S, Ty);
        Bitwidth = Bitwidth / 2;
      }

      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
          Expr);
    }

    const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
      if (const SCEV *S = Map.lookup(Expr))
        return S;
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
          Expr);
    }

    const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
      if (const SCEV *S = Map.lookup(Expr))
        return S;
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
    }

    const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
      if (const SCEV *S = Map.lookup(Expr))
        return S;
      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
    }

    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (const auto *Op : Expr->operands()) {
        Operands.push_back(
            SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
        Changed |= Op != Operands.back();
      }
      // We are only replacing operands with equivalent values, so transfer the
      // flags from the original expression.
      return !Changed ? Expr
                      : SE.getAddExpr(Operands,
                                      ScalarEvolution::maskFlags(
                                          Expr->getNoWrapFlags(), FlagMask));
    }

    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (const auto *Op : Expr->operands()) {
        Operands.push_back(
            SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
        Changed |= Op != Operands.back();
      }
      // We are only replacing operands with equivalent values, so transfer the
      // flags from the original expression.
      return !Changed ? Expr
                      : SE.getMulExpr(Operands,
                                      ScalarEvolution::maskFlags(
                                          Expr->getNoWrapFlags(), FlagMask));
    }
  };

  if (RewriteMap.empty())
    return Expr;

  SCEVLoopGuardRewriter Rewriter(SE, *this);
  return Rewriter.visit(Expr);
}

const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
}

const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
                                             const LoopGuards &Guards) {
  return Guards.rewrite(Expr);
}
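
// A minimal usage sketch for loop guards (hypothetical client code): a
// dominating check such as (%n != 0) allows %n to be treated as umax(%n, 1)
// while simplifying loop-related expressions:
//
//   const SCEV *BTC = SE.getBackedgeTakenCount(L);
//   const SCEV *Guarded = SE.applyLoopGuards(BTC, L);
//
// Alternatively, collect the guards once and reuse them for several
// expressions:
//
//   auto Guards = ScalarEvolution::LoopGuards::collect(L, SE);
//   const SCEV *G1 = SE.applyLoopGuards(E1, Guards);
//   const SCEV *G2 = SE.applyLoopGuards(E2, Guards);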