1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the classes used to represent and build scalar expressions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/Analysis/ScalarEvolution.h" 20 #include "llvm/IR/Constants.h" 21 #include "llvm/IR/ValueHandle.h" 22 #include "llvm/Support/Casting.h" 23 #include "llvm/Support/ErrorHandling.h" 24 #include <cassert> 25 #include <cstddef> 26 27 namespace llvm { 28 29 class APInt; 30 class Constant; 31 class ConstantInt; 32 class ConstantRange; 33 class Loop; 34 class Type; 35 class Value; 36 37 enum SCEVTypes : unsigned short { 38 // These should be ordered in terms of increasing complexity to make the 39 // folders simpler. 40 scConstant, 41 scVScale, 42 scTruncate, 43 scZeroExtend, 44 scSignExtend, 45 scAddExpr, 46 scMulExpr, 47 scUDivExpr, 48 scAddRecExpr, 49 scUMaxExpr, 50 scSMaxExpr, 51 scUMinExpr, 52 scSMinExpr, 53 scSequentialUMinExpr, 54 scPtrToInt, 55 scUnknown, 56 scCouldNotCompute 57 }; 58 59 /// This class represents a constant integer value. 60 class SCEVConstant : public SCEV { 61 friend class ScalarEvolution; 62 63 ConstantInt *V; 64 65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) 66 : SCEV(ID, scConstant, 1), V(v) {} 67 68 public: 69 ConstantInt *getValue() const { return V; } 70 const APInt &getAPInt() const { return getValue()->getValue(); } 71 72 Type *getType() const { return V->getType(); } 73 74 /// Methods for support type inquiry through isa, cast, and dyn_cast: 75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } 76 }; 77 78 /// This class represents the value of vscale, as used when defining the length 79 /// of a scalable vector or returned by the llvm.vscale() intrinsic. 80 class SCEVVScale : public SCEV { 81 friend class ScalarEvolution; 82 83 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty) 84 : SCEV(ID, scVScale, 0), Ty(ty) {} 85 86 Type *Ty; 87 88 public: 89 Type *getType() const { return Ty; } 90 91 /// Methods for support type inquiry through isa, cast, and dyn_cast: 92 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } 93 }; 94 95 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { 96 APInt Size(16, 1); 97 for (const auto *Arg : Args) 98 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); 99 return (unsigned short)Size.getZExtValue(); 100 } 101 102 /// This is the base class for unary cast operator classes. 103 class SCEVCastExpr : public SCEV { 104 protected: 105 const SCEV *Op; 106 Type *Ty; 107 108 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, 109 Type *ty); 110 111 public: 112 const SCEV *getOperand() const { return Op; } 113 const SCEV *getOperand(unsigned i) const { 114 assert(i == 0 && "Operand index out of range!"); 115 return Op; 116 } 117 ArrayRef<const SCEV *> operands() const { return Op; } 118 size_t getNumOperands() const { return 1; } 119 Type *getType() const { return Ty; } 120 121 /// Methods for support type inquiry through isa, cast, and dyn_cast: 122 static bool classof(const SCEV *S) { 123 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || 124 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; 125 } 126 }; 127 128 /// This class represents a cast from a pointer to a pointer-sized integer 129 /// value. 130 class SCEVPtrToIntExpr : public SCEVCastExpr { 131 friend class ScalarEvolution; 132 133 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); 134 135 public: 136 /// Methods for support type inquiry through isa, cast, and dyn_cast: 137 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } 138 }; 139 140 /// This is the base class for unary integral cast operator classes. 141 class SCEVIntegralCastExpr : public SCEVCastExpr { 142 protected: 143 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, 144 const SCEV *op, Type *ty); 145 146 public: 147 /// Methods for support type inquiry through isa, cast, and dyn_cast: 148 static bool classof(const SCEV *S) { 149 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || 150 S->getSCEVType() == scSignExtend; 151 } 152 }; 153 154 /// This class represents a truncation of an integer value to a 155 /// smaller integer value. 156 class SCEVTruncateExpr : public SCEVIntegralCastExpr { 157 friend class ScalarEvolution; 158 159 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 160 161 public: 162 /// Methods for support type inquiry through isa, cast, and dyn_cast: 163 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } 164 }; 165 166 /// This class represents a zero extension of a small integer value 167 /// to a larger integer value. 168 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { 169 friend class ScalarEvolution; 170 171 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 172 173 public: 174 /// Methods for support type inquiry through isa, cast, and dyn_cast: 175 static bool classof(const SCEV *S) { 176 return S->getSCEVType() == scZeroExtend; 177 } 178 }; 179 180 /// This class represents a sign extension of a small integer value 181 /// to a larger integer value. 182 class SCEVSignExtendExpr : public SCEVIntegralCastExpr { 183 friend class ScalarEvolution; 184 185 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 186 187 public: 188 /// Methods for support type inquiry through isa, cast, and dyn_cast: 189 static bool classof(const SCEV *S) { 190 return S->getSCEVType() == scSignExtend; 191 } 192 }; 193 194 /// This node is a base class providing common functionality for 195 /// n'ary operators. 196 class SCEVNAryExpr : public SCEV { 197 protected: 198 // Since SCEVs are immutable, ScalarEvolution allocates operand 199 // arrays with its SCEVAllocator, so this class just needs a simple 200 // pointer rather than a more elaborate vector-like data structure. 201 // This also avoids the need for a non-trivial destructor. 202 const SCEV *const *Operands; 203 size_t NumOperands; 204 205 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 206 const SCEV *const *O, size_t N) 207 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O), 208 NumOperands(N) {} 209 210 public: 211 size_t getNumOperands() const { return NumOperands; } 212 213 const SCEV *getOperand(unsigned i) const { 214 assert(i < NumOperands && "Operand index out of range!"); 215 return Operands[i]; 216 } 217 218 ArrayRef<const SCEV *> operands() const { 219 return ArrayRef(Operands, NumOperands); 220 } 221 222 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 223 return (NoWrapFlags)(SubclassData & Mask); 224 } 225 226 bool hasNoUnsignedWrap() const { 227 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 228 } 229 230 bool hasNoSignedWrap() const { 231 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 232 } 233 234 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; } 235 236 /// Methods for support type inquiry through isa, cast, and dyn_cast: 237 static bool classof(const SCEV *S) { 238 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 239 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 240 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || 241 S->getSCEVType() == scSequentialUMinExpr || 242 S->getSCEVType() == scAddRecExpr; 243 } 244 }; 245 246 /// This node is the base class for n'ary commutative operators. 247 class SCEVCommutativeExpr : public SCEVNAryExpr { 248 protected: 249 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 250 const SCEV *const *O, size_t N) 251 : SCEVNAryExpr(ID, T, O, N) {} 252 253 public: 254 /// Methods for support type inquiry through isa, cast, and dyn_cast: 255 static bool classof(const SCEV *S) { 256 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 257 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 258 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; 259 } 260 261 /// Set flags for a non-recurrence without clearing previously set flags. 262 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 263 }; 264 265 /// This node represents an addition of some number of SCEVs. 266 class SCEVAddExpr : public SCEVCommutativeExpr { 267 friend class ScalarEvolution; 268 269 Type *Ty; 270 271 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 272 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 273 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { 274 return Op->getType()->isPointerTy(); 275 }); 276 if (FirstPointerTypedOp != operands().end()) 277 Ty = (*FirstPointerTypedOp)->getType(); 278 else 279 Ty = getOperand(0)->getType(); 280 } 281 282 public: 283 Type *getType() const { return Ty; } 284 285 /// Methods for support type inquiry through isa, cast, and dyn_cast: 286 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } 287 }; 288 289 /// This node represents multiplication of some number of SCEVs. 290 class SCEVMulExpr : public SCEVCommutativeExpr { 291 friend class ScalarEvolution; 292 293 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 294 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 295 296 public: 297 Type *getType() const { return getOperand(0)->getType(); } 298 299 /// Methods for support type inquiry through isa, cast, and dyn_cast: 300 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } 301 }; 302 303 /// This class represents a binary unsigned division operation. 304 class SCEVUDivExpr : public SCEV { 305 friend class ScalarEvolution; 306 307 std::array<const SCEV *, 2> Operands; 308 309 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 310 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { 311 Operands[0] = lhs; 312 Operands[1] = rhs; 313 } 314 315 public: 316 const SCEV *getLHS() const { return Operands[0]; } 317 const SCEV *getRHS() const { return Operands[1]; } 318 size_t getNumOperands() const { return 2; } 319 const SCEV *getOperand(unsigned i) const { 320 assert((i == 0 || i == 1) && "Operand index out of range!"); 321 return i == 0 ? getLHS() : getRHS(); 322 } 323 324 ArrayRef<const SCEV *> operands() const { return Operands; } 325 326 Type *getType() const { 327 // In most cases the types of LHS and RHS will be the same, but in some 328 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 329 // depend on the type for correctness, but handling types carefully can 330 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 331 // a pointer type than the RHS, so use the RHS' type here. 332 return getRHS()->getType(); 333 } 334 335 /// Methods for support type inquiry through isa, cast, and dyn_cast: 336 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } 337 }; 338 339 /// This node represents a polynomial recurrence on the trip count 340 /// of the specified loop. This is the primary focus of the 341 /// ScalarEvolution framework; all the other SCEV subclasses are 342 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 343 /// expressions to be created and analyzed. 344 /// 345 /// All operands of an AddRec are required to be loop invariant. 346 /// 347 class SCEVAddRecExpr : public SCEVNAryExpr { 348 friend class ScalarEvolution; 349 350 const Loop *L; 351 352 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, 353 const Loop *l) 354 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 355 356 public: 357 Type *getType() const { return getStart()->getType(); } 358 const SCEV *getStart() const { return Operands[0]; } 359 const Loop *getLoop() const { return L; } 360 361 /// Constructs and returns the recurrence indicating how much this 362 /// expression steps by. If this is a polynomial of degree N, it 363 /// returns a chrec of degree N-1. We cannot determine whether 364 /// the step recurrence has self-wraparound. 365 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 366 if (isAffine()) 367 return getOperand(1); 368 return SE.getAddRecExpr( 369 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(), 370 FlagAnyWrap); 371 } 372 373 /// Return true if this represents an expression A + B*x where A 374 /// and B are loop invariant values. 375 bool isAffine() const { 376 // We know that the start value is invariant. This expression is thus 377 // affine iff the step is also invariant. 378 return getNumOperands() == 2; 379 } 380 381 /// Return true if this represents an expression A + B*x + C*x^2 382 /// where A, B and C are loop invariant values. This corresponds 383 /// to an addrec of the form {L,+,M,+,N} 384 bool isQuadratic() const { return getNumOperands() == 3; } 385 386 /// Set flags for a recurrence without clearing any previously set flags. 387 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 388 /// to make it easier to propagate flags. 389 void setNoWrapFlags(NoWrapFlags Flags) { 390 if (Flags & (FlagNUW | FlagNSW)) 391 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 392 SubclassData |= Flags; 393 } 394 395 /// Return the value of this chain of recurrences at the specified 396 /// iteration number. 397 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 398 399 /// Return the value of this chain of recurrences at the specified iteration 400 /// number. Takes an explicit list of operands to represent an AddRec. 401 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, 402 const SCEV *It, ScalarEvolution &SE); 403 404 /// Return the number of iterations of this loop that produce 405 /// values in the specified constant range. Another way of 406 /// looking at this is that it returns the first iteration number 407 /// where the value is not in the condition, thus computing the 408 /// exit count. If the iteration count can't be computed, an 409 /// instance of SCEVCouldNotCompute is returned. 410 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 411 ScalarEvolution &SE) const; 412 413 /// Return an expression representing the value of this expression 414 /// one iteration of the loop ahead. 415 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 416 417 /// Methods for support type inquiry through isa, cast, and dyn_cast: 418 static bool classof(const SCEV *S) { 419 return S->getSCEVType() == scAddRecExpr; 420 } 421 }; 422 423 /// This node is the base class min/max selections. 424 class SCEVMinMaxExpr : public SCEVCommutativeExpr { 425 friend class ScalarEvolution; 426 427 static bool isMinMaxType(enum SCEVTypes T) { 428 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || 429 T == scUMinExpr; 430 } 431 432 protected: 433 /// Note: Constructing subclasses via this constructor is allowed 434 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 435 const SCEV *const *O, size_t N) 436 : SCEVCommutativeExpr(ID, T, O, N) { 437 assert(isMinMaxType(T)); 438 // Min and max never overflow 439 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 440 } 441 442 public: 443 Type *getType() const { return getOperand(0)->getType(); } 444 445 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } 446 447 static enum SCEVTypes negate(enum SCEVTypes T) { 448 switch (T) { 449 case scSMaxExpr: 450 return scSMinExpr; 451 case scSMinExpr: 452 return scSMaxExpr; 453 case scUMaxExpr: 454 return scUMinExpr; 455 case scUMinExpr: 456 return scUMaxExpr; 457 default: 458 llvm_unreachable("Not a min or max SCEV type!"); 459 } 460 } 461 }; 462 463 /// This class represents a signed maximum selection. 464 class SCEVSMaxExpr : public SCEVMinMaxExpr { 465 friend class ScalarEvolution; 466 467 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 468 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} 469 470 public: 471 /// Methods for support type inquiry through isa, cast, and dyn_cast: 472 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } 473 }; 474 475 /// This class represents an unsigned maximum selection. 476 class SCEVUMaxExpr : public SCEVMinMaxExpr { 477 friend class ScalarEvolution; 478 479 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 480 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} 481 482 public: 483 /// Methods for support type inquiry through isa, cast, and dyn_cast: 484 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } 485 }; 486 487 /// This class represents a signed minimum selection. 488 class SCEVSMinExpr : public SCEVMinMaxExpr { 489 friend class ScalarEvolution; 490 491 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 492 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} 493 494 public: 495 /// Methods for support type inquiry through isa, cast, and dyn_cast: 496 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } 497 }; 498 499 /// This class represents an unsigned minimum selection. 500 class SCEVUMinExpr : public SCEVMinMaxExpr { 501 friend class ScalarEvolution; 502 503 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 504 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} 505 506 public: 507 /// Methods for support type inquiry through isa, cast, and dyn_cast: 508 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } 509 }; 510 511 /// This node is the base class for sequential/in-order min/max selections. 512 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they 513 /// are early-returning upon reaching saturation point. 514 /// I.e. given `0 umin_seq poison`, the result will be `0`, 515 /// while the result of `0 umin poison` is `poison`. 516 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { 517 friend class ScalarEvolution; 518 519 static bool isSequentialMinMaxType(enum SCEVTypes T) { 520 return T == scSequentialUMinExpr; 521 } 522 523 /// Set flags for a non-recurrence without clearing previously set flags. 524 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 525 526 protected: 527 /// Note: Constructing subclasses via this constructor is allowed 528 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 529 const SCEV *const *O, size_t N) 530 : SCEVNAryExpr(ID, T, O, N) { 531 assert(isSequentialMinMaxType(T)); 532 // Min and max never overflow 533 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 534 } 535 536 public: 537 Type *getType() const { return getOperand(0)->getType(); } 538 539 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { 540 assert(isSequentialMinMaxType(Ty)); 541 switch (Ty) { 542 case scSequentialUMinExpr: 543 return scUMinExpr; 544 default: 545 llvm_unreachable("Not a sequential min/max type."); 546 } 547 } 548 549 SCEVTypes getEquivalentNonSequentialSCEVType() const { 550 return getEquivalentNonSequentialSCEVType(getSCEVType()); 551 } 552 553 static bool classof(const SCEV *S) { 554 return isSequentialMinMaxType(S->getSCEVType()); 555 } 556 }; 557 558 /// This class represents a sequential/in-order unsigned minimum selection. 559 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { 560 friend class ScalarEvolution; 561 562 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, 563 size_t N) 564 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} 565 566 public: 567 /// Methods for support type inquiry through isa, cast, and dyn_cast: 568 static bool classof(const SCEV *S) { 569 return S->getSCEVType() == scSequentialUMinExpr; 570 } 571 }; 572 573 /// This means that we are dealing with an entirely unknown SCEV 574 /// value, and only represent it as its LLVM Value. This is the 575 /// "bottom" value for the analysis. 576 class SCEVUnknown final : public SCEV, private CallbackVH { 577 friend class ScalarEvolution; 578 579 /// The parent ScalarEvolution value. This is used to update the 580 /// parent's maps when the value associated with a SCEVUnknown is 581 /// deleted or RAUW'd. 582 ScalarEvolution *SE; 583 584 /// The next pointer in the linked list of all SCEVUnknown 585 /// instances owned by a ScalarEvolution. 586 SCEVUnknown *Next; 587 588 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, 589 SCEVUnknown *next) 590 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} 591 592 // Implement CallbackVH. 593 void deleted() override; 594 void allUsesReplacedWith(Value *New) override; 595 596 public: 597 Value *getValue() const { return getValPtr(); } 598 599 Type *getType() const { return getValPtr()->getType(); } 600 601 /// Methods for support type inquiry through isa, cast, and dyn_cast: 602 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } 603 }; 604 605 /// This class defines a simple visitor class that may be used for 606 /// various SCEV analysis purposes. 607 template <typename SC, typename RetVal = void> struct SCEVVisitor { 608 RetVal visit(const SCEV *S) { 609 switch (S->getSCEVType()) { 610 case scConstant: 611 return ((SC *)this)->visitConstant((const SCEVConstant *)S); 612 case scVScale: 613 return ((SC *)this)->visitVScale((const SCEVVScale *)S); 614 case scPtrToInt: 615 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); 616 case scTruncate: 617 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); 618 case scZeroExtend: 619 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); 620 case scSignExtend: 621 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); 622 case scAddExpr: 623 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); 624 case scMulExpr: 625 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); 626 case scUDivExpr: 627 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); 628 case scAddRecExpr: 629 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); 630 case scSMaxExpr: 631 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); 632 case scUMaxExpr: 633 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); 634 case scSMinExpr: 635 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); 636 case scUMinExpr: 637 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); 638 case scSequentialUMinExpr: 639 return ((SC *)this) 640 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); 641 case scUnknown: 642 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); 643 case scCouldNotCompute: 644 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); 645 } 646 llvm_unreachable("Unknown SCEV kind!"); 647 } 648 649 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 650 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 651 } 652 }; 653 654 /// Visit all nodes in the expression tree using worklist traversal. 655 /// 656 /// Visitor implements: 657 /// // return true to follow this node. 658 /// bool follow(const SCEV *S); 659 /// // return true to terminate the search. 660 /// bool isDone(); 661 template <typename SV> class SCEVTraversal { 662 SV &Visitor; 663 SmallVector<const SCEV *, 8> Worklist; 664 SmallPtrSet<const SCEV *, 8> Visited; 665 666 void push(const SCEV *S) { 667 if (Visited.insert(S).second && Visitor.follow(S)) 668 Worklist.push_back(S); 669 } 670 671 public: 672 SCEVTraversal(SV &V) : Visitor(V) {} 673 674 void visitAll(const SCEV *Root) { 675 push(Root); 676 while (!Worklist.empty() && !Visitor.isDone()) { 677 const SCEV *S = Worklist.pop_back_val(); 678 679 switch (S->getSCEVType()) { 680 case scConstant: 681 case scVScale: 682 case scUnknown: 683 continue; 684 case scPtrToInt: 685 case scTruncate: 686 case scZeroExtend: 687 case scSignExtend: 688 case scAddExpr: 689 case scMulExpr: 690 case scUDivExpr: 691 case scSMaxExpr: 692 case scUMaxExpr: 693 case scSMinExpr: 694 case scUMinExpr: 695 case scSequentialUMinExpr: 696 case scAddRecExpr: 697 for (const auto *Op : S->operands()) { 698 push(Op); 699 if (Visitor.isDone()) 700 break; 701 } 702 continue; 703 case scCouldNotCompute: 704 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 705 } 706 llvm_unreachable("Unknown SCEV kind!"); 707 } 708 } 709 }; 710 711 /// Use SCEVTraversal to visit all nodes in the given expression tree. 712 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { 713 SCEVTraversal<SV> T(Visitor); 714 T.visitAll(Root); 715 } 716 717 /// Return true if any node in \p Root satisfies the predicate \p Pred. 718 template <typename PredTy> 719 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 720 struct FindClosure { 721 bool Found = false; 722 PredTy Pred; 723 724 FindClosure(PredTy Pred) : Pred(Pred) {} 725 726 bool follow(const SCEV *S) { 727 if (!Pred(S)) 728 return true; 729 730 Found = true; 731 return false; 732 } 733 734 bool isDone() const { return Found; } 735 }; 736 737 FindClosure FC(Pred); 738 visitAll(Root, FC); 739 return FC.Found; 740 } 741 742 /// This visitor recursively visits a SCEV expression and re-writes it. 743 /// The result from each visit is cached, so it will return the same 744 /// SCEV for the same input. 745 template <typename SC> 746 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 747 protected: 748 ScalarEvolution &SE; 749 // Memoize the result of each visit so that we only compute once for 750 // the same input SCEV. This is to avoid redundant computations when 751 // a SCEV is referenced by multiple SCEVs. Without memoization, this 752 // visit algorithm would have exponential time complexity in the worst 753 // case, causing the compiler to hang on certain tests. 754 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults; 755 756 public: 757 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 758 759 const SCEV *visit(const SCEV *S) { 760 auto It = RewriteResults.find(S); 761 if (It != RewriteResults.end()) 762 return It->second; 763 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 764 auto Result = RewriteResults.try_emplace(S, Visited); 765 assert(Result.second && "Should insert a new entry"); 766 return Result.first->second; 767 } 768 769 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } 770 771 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } 772 773 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 774 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 775 return Operand == Expr->getOperand() 776 ? Expr 777 : SE.getPtrToIntExpr(Operand, Expr->getType()); 778 } 779 780 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 781 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 782 return Operand == Expr->getOperand() 783 ? Expr 784 : SE.getTruncateExpr(Operand, Expr->getType()); 785 } 786 787 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 788 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 789 return Operand == Expr->getOperand() 790 ? Expr 791 : SE.getZeroExtendExpr(Operand, Expr->getType()); 792 } 793 794 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 795 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 796 return Operand == Expr->getOperand() 797 ? Expr 798 : SE.getSignExtendExpr(Operand, Expr->getType()); 799 } 800 801 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 802 SmallVector<const SCEV *, 2> Operands; 803 bool Changed = false; 804 for (const auto *Op : Expr->operands()) { 805 Operands.push_back(((SC *)this)->visit(Op)); 806 Changed |= Op != Operands.back(); 807 } 808 return !Changed ? Expr : SE.getAddExpr(Operands); 809 } 810 811 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 812 SmallVector<const SCEV *, 2> Operands; 813 bool Changed = false; 814 for (const auto *Op : Expr->operands()) { 815 Operands.push_back(((SC *)this)->visit(Op)); 816 Changed |= Op != Operands.back(); 817 } 818 return !Changed ? Expr : SE.getMulExpr(Operands); 819 } 820 821 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 822 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 823 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 824 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 825 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 826 } 827 828 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 829 SmallVector<const SCEV *, 2> Operands; 830 bool Changed = false; 831 for (const auto *Op : Expr->operands()) { 832 Operands.push_back(((SC *)this)->visit(Op)); 833 Changed |= Op != Operands.back(); 834 } 835 return !Changed ? Expr 836 : SE.getAddRecExpr(Operands, Expr->getLoop(), 837 Expr->getNoWrapFlags()); 838 } 839 840 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 841 SmallVector<const SCEV *, 2> Operands; 842 bool Changed = false; 843 for (const auto *Op : Expr->operands()) { 844 Operands.push_back(((SC *)this)->visit(Op)); 845 Changed |= Op != Operands.back(); 846 } 847 return !Changed ? Expr : SE.getSMaxExpr(Operands); 848 } 849 850 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 851 SmallVector<const SCEV *, 2> Operands; 852 bool Changed = false; 853 for (const auto *Op : Expr->operands()) { 854 Operands.push_back(((SC *)this)->visit(Op)); 855 Changed |= Op != Operands.back(); 856 } 857 return !Changed ? Expr : SE.getUMaxExpr(Operands); 858 } 859 860 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { 861 SmallVector<const SCEV *, 2> Operands; 862 bool Changed = false; 863 for (const auto *Op : Expr->operands()) { 864 Operands.push_back(((SC *)this)->visit(Op)); 865 Changed |= Op != Operands.back(); 866 } 867 return !Changed ? Expr : SE.getSMinExpr(Operands); 868 } 869 870 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { 871 SmallVector<const SCEV *, 2> Operands; 872 bool Changed = false; 873 for (const auto *Op : Expr->operands()) { 874 Operands.push_back(((SC *)this)->visit(Op)); 875 Changed |= Op != Operands.back(); 876 } 877 return !Changed ? Expr : SE.getUMinExpr(Operands); 878 } 879 880 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 881 SmallVector<const SCEV *, 2> Operands; 882 bool Changed = false; 883 for (const auto *Op : Expr->operands()) { 884 Operands.push_back(((SC *)this)->visit(Op)); 885 Changed |= Op != Operands.back(); 886 } 887 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); 888 } 889 890 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } 891 892 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 893 return Expr; 894 } 895 }; 896 897 using ValueToValueMap = DenseMap<const Value *, Value *>; 898 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>; 899 900 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 901 /// the SCEVUnknown components following the Map (Value -> SCEV). 902 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 903 public: 904 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 905 ValueToSCEVMapTy &Map) { 906 SCEVParameterRewriter Rewriter(SE, Map); 907 return Rewriter.visit(Scev); 908 } 909 910 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) 911 : SCEVRewriteVisitor(SE), Map(M) {} 912 913 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 914 auto I = Map.find(Expr->getValue()); 915 if (I == Map.end()) 916 return Expr; 917 return I->second; 918 } 919 920 private: 921 ValueToSCEVMapTy ⤅ 922 }; 923 924 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 925 926 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 927 /// the Map (Loop -> SCEV) to all AddRecExprs. 928 class SCEVLoopAddRecRewriter 929 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 930 public: 931 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 932 : SCEVRewriteVisitor(SE), Map(M) {} 933 934 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 935 ScalarEvolution &SE) { 936 SCEVLoopAddRecRewriter Rewriter(SE, Map); 937 return Rewriter.visit(Scev); 938 } 939 940 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 941 SmallVector<const SCEV *, 2> Operands; 942 for (const SCEV *Op : Expr->operands()) 943 Operands.push_back(visit(Op)); 944 945 const Loop *L = Expr->getLoop(); 946 if (0 == Map.count(L)) 947 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 948 949 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE); 950 } 951 952 private: 953 LoopToScevMapT ⤅ 954 }; 955 956 } // end namespace llvm 957 958 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 959