xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Analysis/ScalarEvolution.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/ValueHandle.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include <cassert>
26 #include <cstddef>
27 
28 namespace llvm {
29 
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37 
/// Discriminator identifying the concrete kind of a SCEV node.
enum SCEVTypes : unsigned short {
  // Keep these ordered by increasing complexity; doing so keeps the
  // folders simpler.
  scConstant = 0,       // SCEVConstant
  scVScale,             // SCEVVScale
  scTruncate,           // SCEVTruncateExpr
  scZeroExtend,         // SCEVZeroExtendExpr
  scSignExtend,         // SCEVSignExtendExpr
  scAddExpr,            // SCEVAddExpr
  scMulExpr,            // SCEVMulExpr
  scUDivExpr,           // SCEVUDivExpr
  scAddRecExpr,         // SCEVAddRecExpr
  scUMaxExpr,           // SCEVUMaxExpr
  scSMaxExpr,           // SCEVSMaxExpr
  scUMinExpr,           // SCEVUMinExpr
  scSMinExpr,           // SCEVSMinExpr
  scSequentialUMinExpr, // SCEVSequentialUMinExpr
  scPtrToInt,           // SCEVPtrToIntExpr
  scUnknown,            // SCEVUnknown
  scCouldNotCompute     // SCEVCouldNotCompute
};
59 
60 /// This class represents a constant integer value.
61 class SCEVConstant : public SCEV {
62   friend class ScalarEvolution;
63 
64   ConstantInt *V;
65 
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)66   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
67       : SCEV(ID, scConstant, 1), V(v) {}
68 
69 public:
getValue()70   ConstantInt *getValue() const { return V; }
getAPInt()71   const APInt &getAPInt() const { return getValue()->getValue(); }
72 
getType()73   Type *getType() const { return V->getType(); }
74 
75   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)76   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
77 };
78 
79 /// This class represents the value of vscale, as used when defining the length
80 /// of a scalable vector or returned by the llvm.vscale() intrinsic.
81 class SCEVVScale : public SCEV {
82   friend class ScalarEvolution;
83 
SCEVVScale(const FoldingSetNodeIDRef ID,Type * ty)84   SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
85       : SCEV(ID, scVScale, 0), Ty(ty) {}
86 
87   Type *Ty;
88 
89 public:
getType()90   Type *getType() const { return Ty; }
91 
92   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)93   static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
94 };
95 
computeExpressionSize(ArrayRef<const SCEV * > Args)96 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
97   APInt Size(16, 1);
98   for (const auto *Arg : Args)
99     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
100   return (unsigned short)Size.getZExtValue();
101 }
102 
103 /// This is the base class for unary cast operator classes.
104 class SCEVCastExpr : public SCEV {
105 protected:
106   const SCEV *Op;
107   Type *Ty;
108 
109   LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
110                         const SCEV *op, Type *ty);
111 
112 public:
getOperand()113   const SCEV *getOperand() const { return Op; }
getOperand(unsigned i)114   const SCEV *getOperand(unsigned i) const {
115     assert(i == 0 && "Operand index out of range!");
116     return Op;
117   }
operands()118   ArrayRef<const SCEV *> operands() const { return Op; }
getNumOperands()119   size_t getNumOperands() const { return 1; }
getType()120   Type *getType() const { return Ty; }
121 
122   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)123   static bool classof(const SCEV *S) {
124     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
125            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
126   }
127 };
128 
129 /// This class represents a cast from a pointer to a pointer-sized integer
130 /// value.
131 class SCEVPtrToIntExpr : public SCEVCastExpr {
132   friend class ScalarEvolution;
133 
134   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
135 
136 public:
137   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)138   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
139 };
140 
141 /// This is the base class for unary integral cast operator classes.
142 class SCEVIntegralCastExpr : public SCEVCastExpr {
143 protected:
144   LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
145                                 const SCEV *op, Type *ty);
146 
147 public:
148   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)149   static bool classof(const SCEV *S) {
150     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
151            S->getSCEVType() == scSignExtend;
152   }
153 };
154 
155 /// This class represents a truncation of an integer value to a
156 /// smaller integer value.
157 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
158   friend class ScalarEvolution;
159 
160   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
161 
162 public:
163   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)164   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
165 };
166 
167 /// This class represents a zero extension of a small integer value
168 /// to a larger integer value.
169 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
170   friend class ScalarEvolution;
171 
172   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
173 
174 public:
175   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)176   static bool classof(const SCEV *S) {
177     return S->getSCEVType() == scZeroExtend;
178   }
179 };
180 
181 /// This class represents a sign extension of a small integer value
182 /// to a larger integer value.
183 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
184   friend class ScalarEvolution;
185 
186   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
187 
188 public:
189   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)190   static bool classof(const SCEV *S) {
191     return S->getSCEVType() == scSignExtend;
192   }
193 };
194 
195 /// This node is a base class providing common functionality for
196 /// n'ary operators.
197 class SCEVNAryExpr : public SCEV {
198 protected:
199   // Since SCEVs are immutable, ScalarEvolution allocates operand
200   // arrays with its SCEVAllocator, so this class just needs a simple
201   // pointer rather than a more elaborate vector-like data structure.
202   // This also avoids the need for a non-trivial destructor.
203   const SCEV *const *Operands;
204   size_t NumOperands;
205 
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)206   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
207                const SCEV *const *O, size_t N)
208       : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
209         NumOperands(N) {}
210 
211 public:
getNumOperands()212   size_t getNumOperands() const { return NumOperands; }
213 
getOperand(unsigned i)214   const SCEV *getOperand(unsigned i) const {
215     assert(i < NumOperands && "Operand index out of range!");
216     return Operands[i];
217   }
218 
operands()219   ArrayRef<const SCEV *> operands() const {
220     return ArrayRef(Operands, NumOperands);
221   }
222 
223   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
224     return (NoWrapFlags)(SubclassData & Mask);
225   }
226 
hasNoUnsignedWrap()227   bool hasNoUnsignedWrap() const {
228     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
229   }
230 
hasNoSignedWrap()231   bool hasNoSignedWrap() const {
232     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
233   }
234 
hasNoSelfWrap()235   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
236 
237   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)238   static bool classof(const SCEV *S) {
239     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
240            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
241            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
242            S->getSCEVType() == scSequentialUMinExpr ||
243            S->getSCEVType() == scAddRecExpr;
244   }
245 };
246 
247 /// This node is the base class for n'ary commutative operators.
248 class SCEVCommutativeExpr : public SCEVNAryExpr {
249 protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)250   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
251                       const SCEV *const *O, size_t N)
252       : SCEVNAryExpr(ID, T, O, N) {}
253 
254 public:
255   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)256   static bool classof(const SCEV *S) {
257     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
258            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
259            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
260   }
261 
262   /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)263   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
264 };
265 
266 /// This node represents an addition of some number of SCEVs.
267 class SCEVAddExpr : public SCEVCommutativeExpr {
268   friend class ScalarEvolution;
269 
270   Type *Ty;
271 
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)272   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
273       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
274     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
275       return Op->getType()->isPointerTy();
276     });
277     if (FirstPointerTypedOp != operands().end())
278       Ty = (*FirstPointerTypedOp)->getType();
279     else
280       Ty = getOperand(0)->getType();
281   }
282 
283 public:
getType()284   Type *getType() const { return Ty; }
285 
286   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)287   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
288 };
289 
290 /// This node represents multiplication of some number of SCEVs.
291 class SCEVMulExpr : public SCEVCommutativeExpr {
292   friend class ScalarEvolution;
293 
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)294   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
295       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
296 
297 public:
getType()298   Type *getType() const { return getOperand(0)->getType(); }
299 
300   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)301   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
302 };
303 
304 /// This class represents a binary unsigned division operation.
305 class SCEVUDivExpr : public SCEV {
306   friend class ScalarEvolution;
307 
308   std::array<const SCEV *, 2> Operands;
309 
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)310   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
311       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
312     Operands[0] = lhs;
313     Operands[1] = rhs;
314   }
315 
316 public:
getLHS()317   const SCEV *getLHS() const { return Operands[0]; }
getRHS()318   const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()319   size_t getNumOperands() const { return 2; }
getOperand(unsigned i)320   const SCEV *getOperand(unsigned i) const {
321     assert((i == 0 || i == 1) && "Operand index out of range!");
322     return i == 0 ? getLHS() : getRHS();
323   }
324 
operands()325   ArrayRef<const SCEV *> operands() const { return Operands; }
326 
getType()327   Type *getType() const {
328     // In most cases the types of LHS and RHS will be the same, but in some
329     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
330     // depend on the type for correctness, but handling types carefully can
331     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
332     // a pointer type than the RHS, so use the RHS' type here.
333     return getRHS()->getType();
334   }
335 
336   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)337   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
338 };
339 
340 /// This node represents a polynomial recurrence on the trip count
341 /// of the specified loop.  This is the primary focus of the
342 /// ScalarEvolution framework; all the other SCEV subclasses are
343 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
344 /// expressions to be created and analyzed.
345 ///
346 /// All operands of an AddRec are required to be loop invariant.
347 ///
348 class SCEVAddRecExpr : public SCEVNAryExpr {
349   friend class ScalarEvolution;
350 
351   const Loop *L;
352 
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)353   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
354                  const Loop *l)
355       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
356 
357 public:
getType()358   Type *getType() const { return getStart()->getType(); }
getStart()359   const SCEV *getStart() const { return Operands[0]; }
getLoop()360   const Loop *getLoop() const { return L; }
361 
362   /// Constructs and returns the recurrence indicating how much this
363   /// expression steps by.  If this is a polynomial of degree N, it
364   /// returns a chrec of degree N-1.  We cannot determine whether
365   /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)366   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
367     if (isAffine())
368       return getOperand(1);
369     return SE.getAddRecExpr(
370         SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
371         FlagAnyWrap);
372   }
373 
374   /// Return true if this represents an expression A + B*x where A
375   /// and B are loop invariant values.
isAffine()376   bool isAffine() const {
377     // We know that the start value is invariant.  This expression is thus
378     // affine iff the step is also invariant.
379     return getNumOperands() == 2;
380   }
381 
382   /// Return true if this represents an expression A + B*x + C*x^2
383   /// where A, B and C are loop invariant values.  This corresponds
384   /// to an addrec of the form {L,+,M,+,N}
isQuadratic()385   bool isQuadratic() const { return getNumOperands() == 3; }
386 
387   /// Set flags for a recurrence without clearing any previously set flags.
388   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
389   /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)390   void setNoWrapFlags(NoWrapFlags Flags) {
391     if (Flags & (FlagNUW | FlagNSW))
392       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
393     SubclassData |= Flags;
394   }
395 
396   /// Return the value of this chain of recurrences at the specified
397   /// iteration number.
398   LLVM_ABI const SCEV *evaluateAtIteration(const SCEV *It,
399                                            ScalarEvolution &SE) const;
400 
401   /// Return the value of this chain of recurrences at the specified iteration
402   /// number. Takes an explicit list of operands to represent an AddRec.
403   LLVM_ABI static const SCEV *
404   evaluateAtIteration(ArrayRef<const SCEV *> Operands, const SCEV *It,
405                       ScalarEvolution &SE);
406 
407   /// Return the number of iterations of this loop that produce
408   /// values in the specified constant range.  Another way of
409   /// looking at this is that it returns the first iteration number
410   /// where the value is not in the condition, thus computing the
411   /// exit count.  If the iteration count can't be computed, an
412   /// instance of SCEVCouldNotCompute is returned.
413   LLVM_ABI const SCEV *getNumIterationsInRange(const ConstantRange &Range,
414                                                ScalarEvolution &SE) const;
415 
416   /// Return an expression representing the value of this expression
417   /// one iteration of the loop ahead.
418   LLVM_ABI const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
419 
420   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)421   static bool classof(const SCEV *S) {
422     return S->getSCEVType() == scAddRecExpr;
423   }
424 };
425 
426 /// This node is the base class min/max selections.
427 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
428   friend class ScalarEvolution;
429 
isMinMaxType(enum SCEVTypes T)430   static bool isMinMaxType(enum SCEVTypes T) {
431     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
432            T == scUMinExpr;
433   }
434 
435 protected:
436   /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)437   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
438                  const SCEV *const *O, size_t N)
439       : SCEVCommutativeExpr(ID, T, O, N) {
440     assert(isMinMaxType(T));
441     // Min and max never overflow
442     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
443   }
444 
445 public:
getType()446   Type *getType() const { return getOperand(0)->getType(); }
447 
classof(const SCEV * S)448   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
449 
negate(enum SCEVTypes T)450   static enum SCEVTypes negate(enum SCEVTypes T) {
451     switch (T) {
452     case scSMaxExpr:
453       return scSMinExpr;
454     case scSMinExpr:
455       return scSMaxExpr;
456     case scUMaxExpr:
457       return scUMinExpr;
458     case scUMinExpr:
459       return scUMaxExpr;
460     default:
461       llvm_unreachable("Not a min or max SCEV type!");
462     }
463   }
464 };
465 
466 /// This class represents a signed maximum selection.
467 class SCEVSMaxExpr : public SCEVMinMaxExpr {
468   friend class ScalarEvolution;
469 
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)470   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
471       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
472 
473 public:
474   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)475   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
476 };
477 
478 /// This class represents an unsigned maximum selection.
479 class SCEVUMaxExpr : public SCEVMinMaxExpr {
480   friend class ScalarEvolution;
481 
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)482   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
483       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
484 
485 public:
486   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)487   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
488 };
489 
490 /// This class represents a signed minimum selection.
491 class SCEVSMinExpr : public SCEVMinMaxExpr {
492   friend class ScalarEvolution;
493 
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)494   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
495       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
496 
497 public:
498   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)499   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
500 };
501 
502 /// This class represents an unsigned minimum selection.
503 class SCEVUMinExpr : public SCEVMinMaxExpr {
504   friend class ScalarEvolution;
505 
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)506   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
507       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
508 
509 public:
510   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)511   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
512 };
513 
514 /// This node is the base class for sequential/in-order min/max selections.
515 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
516 /// are early-returning upon reaching saturation point.
517 /// I.e. given `0 umin_seq poison`, the result will be `0`, while the result of
518 /// `0 umin poison` is `poison`. When returning early, later expressions are not
519 /// executed, so `0 umin_seq (%x u/ 0)` does not result in undefined behavior.
520 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
521   friend class ScalarEvolution;
522 
isSequentialMinMaxType(enum SCEVTypes T)523   static bool isSequentialMinMaxType(enum SCEVTypes T) {
524     return T == scSequentialUMinExpr;
525   }
526 
527   /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)528   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
529 
530 protected:
531   /// Note: Constructing subclasses via this constructor is allowed
SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)532   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
533                            const SCEV *const *O, size_t N)
534       : SCEVNAryExpr(ID, T, O, N) {
535     assert(isSequentialMinMaxType(T));
536     // Min and max never overflow
537     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
538   }
539 
540 public:
getType()541   Type *getType() const { return getOperand(0)->getType(); }
542 
getEquivalentNonSequentialSCEVType(SCEVTypes Ty)543   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
544     assert(isSequentialMinMaxType(Ty));
545     switch (Ty) {
546     case scSequentialUMinExpr:
547       return scUMinExpr;
548     default:
549       llvm_unreachable("Not a sequential min/max type.");
550     }
551   }
552 
getEquivalentNonSequentialSCEVType()553   SCEVTypes getEquivalentNonSequentialSCEVType() const {
554     return getEquivalentNonSequentialSCEVType(getSCEVType());
555   }
556 
classof(const SCEV * S)557   static bool classof(const SCEV *S) {
558     return isSequentialMinMaxType(S->getSCEVType());
559   }
560 };
561 
562 /// This class represents a sequential/in-order unsigned minimum selection.
563 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
564   friend class ScalarEvolution;
565 
SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)566   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
567                          size_t N)
568       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
569 
570 public:
571   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)572   static bool classof(const SCEV *S) {
573     return S->getSCEVType() == scSequentialUMinExpr;
574   }
575 };
576 
577 /// This means that we are dealing with an entirely unknown SCEV
578 /// value, and only represent it as its LLVM Value.  This is the
579 /// "bottom" value for the analysis.
580 class LLVM_ABI SCEVUnknown final : public SCEV, private CallbackVH {
581   friend class ScalarEvolution;
582 
583   /// The parent ScalarEvolution value. This is used to update the
584   /// parent's maps when the value associated with a SCEVUnknown is
585   /// deleted or RAUW'd.
586   ScalarEvolution *SE;
587 
588   /// The next pointer in the linked list of all SCEVUnknown
589   /// instances owned by a ScalarEvolution.
590   SCEVUnknown *Next;
591 
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)592   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
593               SCEVUnknown *next)
594       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
595 
596   // Implement CallbackVH.
597   void deleted() override;
598   void allUsesReplacedWith(Value *New) override;
599 
600 public:
getValue()601   Value *getValue() const { return getValPtr(); }
602 
getType()603   Type *getType() const { return getValPtr()->getType(); }
604 
605   /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)606   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
607 };
608 
609 /// This class defines a simple visitor class that may be used for
610 /// various SCEV analysis purposes.
611 template <typename SC, typename RetVal = void> struct SCEVVisitor {
visitSCEVVisitor612   RetVal visit(const SCEV *S) {
613     switch (S->getSCEVType()) {
614     case scConstant:
615       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
616     case scVScale:
617       return ((SC *)this)->visitVScale((const SCEVVScale *)S);
618     case scPtrToInt:
619       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
620     case scTruncate:
621       return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
622     case scZeroExtend:
623       return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
624     case scSignExtend:
625       return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
626     case scAddExpr:
627       return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
628     case scMulExpr:
629       return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
630     case scUDivExpr:
631       return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
632     case scAddRecExpr:
633       return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
634     case scSMaxExpr:
635       return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
636     case scUMaxExpr:
637       return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
638     case scSMinExpr:
639       return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
640     case scUMinExpr:
641       return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
642     case scSequentialUMinExpr:
643       return ((SC *)this)
644           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
645     case scUnknown:
646       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
647     case scCouldNotCompute:
648       return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
649     }
650     llvm_unreachable("Unknown SCEV kind!");
651   }
652 
visitCouldNotComputeSCEVVisitor653   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
654     llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
655   }
656 };
657 
658 /// Visit all nodes in the expression tree using worklist traversal.
659 ///
660 /// Visitor implements:
661 ///   // return true to follow this node.
662 ///   bool follow(const SCEV *S);
663 ///   // return true to terminate the search.
664 ///   bool isDone();
665 template <typename SV> class SCEVTraversal {
666   SV &Visitor;
667   SmallVector<const SCEV *, 8> Worklist;
668   SmallPtrSet<const SCEV *, 8> Visited;
669 
push(const SCEV * S)670   void push(const SCEV *S) {
671     if (Visited.insert(S).second && Visitor.follow(S))
672       Worklist.push_back(S);
673   }
674 
675 public:
SCEVTraversal(SV & V)676   SCEVTraversal(SV &V) : Visitor(V) {}
677 
visitAll(const SCEV * Root)678   void visitAll(const SCEV *Root) {
679     push(Root);
680     while (!Worklist.empty() && !Visitor.isDone()) {
681       const SCEV *S = Worklist.pop_back_val();
682 
683       switch (S->getSCEVType()) {
684       case scConstant:
685       case scVScale:
686       case scUnknown:
687         continue;
688       case scPtrToInt:
689       case scTruncate:
690       case scZeroExtend:
691       case scSignExtend:
692       case scAddExpr:
693       case scMulExpr:
694       case scUDivExpr:
695       case scSMaxExpr:
696       case scUMaxExpr:
697       case scSMinExpr:
698       case scUMinExpr:
699       case scSequentialUMinExpr:
700       case scAddRecExpr:
701         for (const auto *Op : S->operands()) {
702           push(Op);
703           if (Visitor.isDone())
704             break;
705         }
706         continue;
707       case scCouldNotCompute:
708         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
709       }
710       llvm_unreachable("Unknown SCEV kind!");
711     }
712   }
713 };
714 
715 /// Use SCEVTraversal to visit all nodes in the given expression tree.
visitAll(const SCEV * Root,SV & Visitor)716 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
717   SCEVTraversal<SV> T(Visitor);
718   T.visitAll(Root);
719 }
720 
721 /// Return true if any node in \p Root satisfies the predicate \p Pred.
722 template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)723 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
724   struct FindClosure {
725     bool Found = false;
726     PredTy Pred;
727 
728     FindClosure(PredTy Pred) : Pred(Pred) {}
729 
730     bool follow(const SCEV *S) {
731       if (!Pred(S))
732         return true;
733 
734       Found = true;
735       return false;
736     }
737 
738     bool isDone() const { return Found; }
739   };
740 
741   FindClosure FC(Pred);
742   visitAll(Root, FC);
743   return FC.Found;
744 }
745 
746 /// This visitor recursively visits a SCEV expression and re-writes it.
747 /// The result from each visit is cached, so it will return the same
748 /// SCEV for the same input.
749 template <typename SC>
750 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
751 protected:
752   ScalarEvolution &SE;
753   // Memoize the result of each visit so that we only compute once for
754   // the same input SCEV. This is to avoid redundant computations when
755   // a SCEV is referenced by multiple SCEVs. Without memoization, this
756   // visit algorithm would have exponential time complexity in the worst
757   // case, causing the compiler to hang on certain tests.
758   SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
759 
760 public:
SCEVRewriteVisitor(ScalarEvolution & SE)761   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
762 
visit(const SCEV * S)763   const SCEV *visit(const SCEV *S) {
764     auto It = RewriteResults.find(S);
765     if (It != RewriteResults.end())
766       return It->second;
767     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
768     auto Result = RewriteResults.try_emplace(S, Visited);
769     assert(Result.second && "Should insert a new entry");
770     return Result.first->second;
771   }
772 
visitConstant(const SCEVConstant * Constant)773   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
774 
visitVScale(const SCEVVScale * VScale)775   const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
776 
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)777   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
778     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
779     return Operand == Expr->getOperand()
780                ? Expr
781                : SE.getPtrToIntExpr(Operand, Expr->getType());
782   }
783 
visitTruncateExpr(const SCEVTruncateExpr * Expr)784   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
785     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
786     return Operand == Expr->getOperand()
787                ? Expr
788                : SE.getTruncateExpr(Operand, Expr->getType());
789   }
790 
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)791   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
792     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
793     return Operand == Expr->getOperand()
794                ? Expr
795                : SE.getZeroExtendExpr(Operand, Expr->getType());
796   }
797 
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)798   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
799     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
800     return Operand == Expr->getOperand()
801                ? Expr
802                : SE.getSignExtendExpr(Operand, Expr->getType());
803   }
804 
visitAddExpr(const SCEVAddExpr * Expr)805   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
806     SmallVector<const SCEV *, 2> Operands;
807     bool Changed = false;
808     for (const auto *Op : Expr->operands()) {
809       Operands.push_back(((SC *)this)->visit(Op));
810       Changed |= Op != Operands.back();
811     }
812     return !Changed ? Expr : SE.getAddExpr(Operands);
813   }
814 
visitMulExpr(const SCEVMulExpr * Expr)815   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
816     SmallVector<const SCEV *, 2> Operands;
817     bool Changed = false;
818     for (const auto *Op : Expr->operands()) {
819       Operands.push_back(((SC *)this)->visit(Op));
820       Changed |= Op != Operands.back();
821     }
822     return !Changed ? Expr : SE.getMulExpr(Operands);
823   }
824 
visitUDivExpr(const SCEVUDivExpr * Expr)825   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
826     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
827     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
828     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
829     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
830   }
831 
visitAddRecExpr(const SCEVAddRecExpr * Expr)832   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
833     SmallVector<const SCEV *, 2> Operands;
834     bool Changed = false;
835     for (const auto *Op : Expr->operands()) {
836       Operands.push_back(((SC *)this)->visit(Op));
837       Changed |= Op != Operands.back();
838     }
839     return !Changed ? Expr
840                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
841                                        Expr->getNoWrapFlags());
842   }
843 
visitSMaxExpr(const SCEVSMaxExpr * Expr)844   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
845     SmallVector<const SCEV *, 2> Operands;
846     bool Changed = false;
847     for (const auto *Op : Expr->operands()) {
848       Operands.push_back(((SC *)this)->visit(Op));
849       Changed |= Op != Operands.back();
850     }
851     return !Changed ? Expr : SE.getSMaxExpr(Operands);
852   }
853 
visitUMaxExpr(const SCEVUMaxExpr * Expr)854   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
855     SmallVector<const SCEV *, 2> Operands;
856     bool Changed = false;
857     for (const auto *Op : Expr->operands()) {
858       Operands.push_back(((SC *)this)->visit(Op));
859       Changed |= Op != Operands.back();
860     }
861     return !Changed ? Expr : SE.getUMaxExpr(Operands);
862   }
863 
visitSMinExpr(const SCEVSMinExpr * Expr)864   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
865     SmallVector<const SCEV *, 2> Operands;
866     bool Changed = false;
867     for (const auto *Op : Expr->operands()) {
868       Operands.push_back(((SC *)this)->visit(Op));
869       Changed |= Op != Operands.back();
870     }
871     return !Changed ? Expr : SE.getSMinExpr(Operands);
872   }
873 
visitUMinExpr(const SCEVUMinExpr * Expr)874   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
875     SmallVector<const SCEV *, 2> Operands;
876     bool Changed = false;
877     for (const auto *Op : Expr->operands()) {
878       Operands.push_back(((SC *)this)->visit(Op));
879       Changed |= Op != Operands.back();
880     }
881     return !Changed ? Expr : SE.getUMinExpr(Operands);
882   }
883 
visitSequentialUMinExpr(const SCEVSequentialUMinExpr * Expr)884   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
885     SmallVector<const SCEV *, 2> Operands;
886     bool Changed = false;
887     for (const auto *Op : Expr->operands()) {
888       Operands.push_back(((SC *)this)->visit(Op));
889       Changed |= Op != Operands.back();
890     }
891     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
892   }
893 
visitUnknown(const SCEVUnknown * Expr)894   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
895 
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)896   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
897     return Expr;
898   }
899 };
900 
/// Map keyed by `const Value *` whose mapped type is a (non-const) `Value *`.
901 using ValueToValueMap = DenseMap<const Value *, Value *>;
/// Map keyed by `const Value *` whose mapped type is a `const SCEV *`; this is
/// the map type SCEVParameterRewriter takes for its Value -> SCEV lookups.
902 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
903 
904 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
905 /// the SCEVUnknown components following the Map (Value -> SCEV).
906 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
907 public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)908   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
909                              ValueToSCEVMapTy &Map) {
910     SCEVParameterRewriter Rewriter(SE, Map);
911     return Rewriter.visit(Scev);
912   }
913 
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)914   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
915       : SCEVRewriteVisitor(SE), Map(M) {}
916 
visitUnknown(const SCEVUnknown * Expr)917   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
918     auto I = Map.find(Expr->getValue());
919     if (I == Map.end())
920       return Expr;
921     return I->second;
922   }
923 
924 private:
925   ValueToSCEVMapTy &Map;
926 };
927 
/// Map keyed by `const Loop *` whose mapped type is a `const SCEV *`;
/// SCEVLoopAddRecRewriter looks up the loop of each add-recurrence in it.
928 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
929 
930 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
931 /// the Map (Loop -> SCEV) to all AddRecExprs.
932 class SCEVLoopAddRecRewriter
933     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
934 public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)935   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
936       : SCEVRewriteVisitor(SE), Map(M) {}
937 
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)938   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
939                              ScalarEvolution &SE) {
940     SCEVLoopAddRecRewriter Rewriter(SE, Map);
941     return Rewriter.visit(Scev);
942   }
943 
visitAddRecExpr(const SCEVAddRecExpr * Expr)944   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
945     SmallVector<const SCEV *, 2> Operands;
946     for (const SCEV *Op : Expr->operands())
947       Operands.push_back(visit(Op));
948 
949     const Loop *L = Expr->getLoop();
950     auto It = Map.find(L);
951     if (It == Map.end())
952       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
953 
954     return SCEVAddRecExpr::evaluateAtIteration(Operands, It->second, SE);
955   }
956 
957 private:
958   LoopToScevMapT &Map;
959 };
960 
961 } // end namespace llvm
962 
963 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
964