xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Analysis/ScalarEvolution.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/ValueHandle.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include <cassert>
25 #include <cstddef>
26 
27 namespace llvm {
28 
29 class APInt;
30 class Constant;
31 class ConstantInt;
32 class ConstantRange;
33 class Loop;
34 class Type;
35 class Value;
36 
/// Kind tag carried by every SCEV node; it is what SCEV::getSCEVType()
/// returns and what the classof() predicates switch on.
///
/// The kinds are listed in order of increasing complexity, which keeps the
/// folders simpler. Any new kind should be inserted at the position that
/// preserves this ordering.
enum SCEVTypes : unsigned short {
  // Constant-like leaves.
  scConstant,
  scVScale,
  // Integral casts.
  scTruncate,
  scZeroExtend,
  scSignExtend,
  // Arithmetic.
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  // Add recurrences.
  scAddRecExpr,
  // Min/max selections (plain, then sequential).
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  scSequentialUMinExpr,
  // Pointer-to-integer cast.
  scPtrToInt,
  // Opaque leaf and the "could not compute" sentinel.
  scUnknown,
  scCouldNotCompute,
};
58 
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61   friend class ScalarEvolution;
62 
63   ConstantInt *V;
64 
65   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66       : SCEV(ID, scConstant, 1), V(v) {}
67 
68 public:
69   ConstantInt *getValue() const { return V; }
70   const APInt &getAPInt() const { return getValue()->getValue(); }
71 
72   Type *getType() const { return V->getType(); }
73 
74   /// Methods for support type inquiry through isa, cast, and dyn_cast:
75   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77 
78 /// This class represents the value of vscale, as used when defining the length
79 /// of a scalable vector or returned by the llvm.vscale() intrinsic.
80 class SCEVVScale : public SCEV {
81   friend class ScalarEvolution;
82 
83   SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
84       : SCEV(ID, scVScale, 0), Ty(ty) {}
85 
86   Type *Ty;
87 
88 public:
89   Type *getType() const { return Ty; }
90 
91   /// Methods for support type inquiry through isa, cast, and dyn_cast:
92   static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
93 };
94 
95 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
96   APInt Size(16, 1);
97   for (const auto *Arg : Args)
98     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
99   return (unsigned short)Size.getZExtValue();
100 }
101 
102 /// This is the base class for unary cast operator classes.
103 class SCEVCastExpr : public SCEV {
104 protected:
105   const SCEV *Op;
106   Type *Ty;
107 
108   SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
109                Type *ty);
110 
111 public:
112   const SCEV *getOperand() const { return Op; }
113   const SCEV *getOperand(unsigned i) const {
114     assert(i == 0 && "Operand index out of range!");
115     return Op;
116   }
117   ArrayRef<const SCEV *> operands() const { return Op; }
118   size_t getNumOperands() const { return 1; }
119   Type *getType() const { return Ty; }
120 
121   /// Methods for support type inquiry through isa, cast, and dyn_cast:
122   static bool classof(const SCEV *S) {
123     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
124            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
125   }
126 };
127 
128 /// This class represents a cast from a pointer to a pointer-sized integer
129 /// value.
130 class SCEVPtrToIntExpr : public SCEVCastExpr {
131   friend class ScalarEvolution;
132 
133   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
134 
135 public:
136   /// Methods for support type inquiry through isa, cast, and dyn_cast:
137   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
138 };
139 
140 /// This is the base class for unary integral cast operator classes.
141 class SCEVIntegralCastExpr : public SCEVCastExpr {
142 protected:
143   SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
144                        const SCEV *op, Type *ty);
145 
146 public:
147   /// Methods for support type inquiry through isa, cast, and dyn_cast:
148   static bool classof(const SCEV *S) {
149     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
150            S->getSCEVType() == scSignExtend;
151   }
152 };
153 
154 /// This class represents a truncation of an integer value to a
155 /// smaller integer value.
156 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
157   friend class ScalarEvolution;
158 
159   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160 
161 public:
162   /// Methods for support type inquiry through isa, cast, and dyn_cast:
163   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
164 };
165 
166 /// This class represents a zero extension of a small integer value
167 /// to a larger integer value.
168 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
169   friend class ScalarEvolution;
170 
171   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
172 
173 public:
174   /// Methods for support type inquiry through isa, cast, and dyn_cast:
175   static bool classof(const SCEV *S) {
176     return S->getSCEVType() == scZeroExtend;
177   }
178 };
179 
180 /// This class represents a sign extension of a small integer value
181 /// to a larger integer value.
182 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
183   friend class ScalarEvolution;
184 
185   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
186 
187 public:
188   /// Methods for support type inquiry through isa, cast, and dyn_cast:
189   static bool classof(const SCEV *S) {
190     return S->getSCEVType() == scSignExtend;
191   }
192 };
193 
194 /// This node is a base class providing common functionality for
195 /// n'ary operators.
196 class SCEVNAryExpr : public SCEV {
197 protected:
198   // Since SCEVs are immutable, ScalarEvolution allocates operand
199   // arrays with its SCEVAllocator, so this class just needs a simple
200   // pointer rather than a more elaborate vector-like data structure.
201   // This also avoids the need for a non-trivial destructor.
202   const SCEV *const *Operands;
203   size_t NumOperands;
204 
205   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
206                const SCEV *const *O, size_t N)
207       : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
208         NumOperands(N) {}
209 
210 public:
211   size_t getNumOperands() const { return NumOperands; }
212 
213   const SCEV *getOperand(unsigned i) const {
214     assert(i < NumOperands && "Operand index out of range!");
215     return Operands[i];
216   }
217 
218   ArrayRef<const SCEV *> operands() const {
219     return ArrayRef(Operands, NumOperands);
220   }
221 
222   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
223     return (NoWrapFlags)(SubclassData & Mask);
224   }
225 
226   bool hasNoUnsignedWrap() const {
227     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
228   }
229 
230   bool hasNoSignedWrap() const {
231     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
232   }
233 
234   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
235 
236   /// Methods for support type inquiry through isa, cast, and dyn_cast:
237   static bool classof(const SCEV *S) {
238     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
239            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
240            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
241            S->getSCEVType() == scSequentialUMinExpr ||
242            S->getSCEVType() == scAddRecExpr;
243   }
244 };
245 
246 /// This node is the base class for n'ary commutative operators.
247 class SCEVCommutativeExpr : public SCEVNAryExpr {
248 protected:
249   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
250                       const SCEV *const *O, size_t N)
251       : SCEVNAryExpr(ID, T, O, N) {}
252 
253 public:
254   /// Methods for support type inquiry through isa, cast, and dyn_cast:
255   static bool classof(const SCEV *S) {
256     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
257            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
258            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
259   }
260 
261   /// Set flags for a non-recurrence without clearing previously set flags.
262   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
263 };
264 
265 /// This node represents an addition of some number of SCEVs.
266 class SCEVAddExpr : public SCEVCommutativeExpr {
267   friend class ScalarEvolution;
268 
269   Type *Ty;
270 
271   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
272       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
273     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
274       return Op->getType()->isPointerTy();
275     });
276     if (FirstPointerTypedOp != operands().end())
277       Ty = (*FirstPointerTypedOp)->getType();
278     else
279       Ty = getOperand(0)->getType();
280   }
281 
282 public:
283   Type *getType() const { return Ty; }
284 
285   /// Methods for support type inquiry through isa, cast, and dyn_cast:
286   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
287 };
288 
289 /// This node represents multiplication of some number of SCEVs.
290 class SCEVMulExpr : public SCEVCommutativeExpr {
291   friend class ScalarEvolution;
292 
293   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
294       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
295 
296 public:
297   Type *getType() const { return getOperand(0)->getType(); }
298 
299   /// Methods for support type inquiry through isa, cast, and dyn_cast:
300   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
301 };
302 
303 /// This class represents a binary unsigned division operation.
304 class SCEVUDivExpr : public SCEV {
305   friend class ScalarEvolution;
306 
307   std::array<const SCEV *, 2> Operands;
308 
309   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
310       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
311     Operands[0] = lhs;
312     Operands[1] = rhs;
313   }
314 
315 public:
316   const SCEV *getLHS() const { return Operands[0]; }
317   const SCEV *getRHS() const { return Operands[1]; }
318   size_t getNumOperands() const { return 2; }
319   const SCEV *getOperand(unsigned i) const {
320     assert((i == 0 || i == 1) && "Operand index out of range!");
321     return i == 0 ? getLHS() : getRHS();
322   }
323 
324   ArrayRef<const SCEV *> operands() const { return Operands; }
325 
326   Type *getType() const {
327     // In most cases the types of LHS and RHS will be the same, but in some
328     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
329     // depend on the type for correctness, but handling types carefully can
330     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
331     // a pointer type than the RHS, so use the RHS' type here.
332     return getRHS()->getType();
333   }
334 
335   /// Methods for support type inquiry through isa, cast, and dyn_cast:
336   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
337 };
338 
339 /// This node represents a polynomial recurrence on the trip count
340 /// of the specified loop.  This is the primary focus of the
341 /// ScalarEvolution framework; all the other SCEV subclasses are
342 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
343 /// expressions to be created and analyzed.
344 ///
345 /// All operands of an AddRec are required to be loop invariant.
346 ///
347 class SCEVAddRecExpr : public SCEVNAryExpr {
348   friend class ScalarEvolution;
349 
350   const Loop *L;
351 
352   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
353                  const Loop *l)
354       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
355 
356 public:
357   Type *getType() const { return getStart()->getType(); }
358   const SCEV *getStart() const { return Operands[0]; }
359   const Loop *getLoop() const { return L; }
360 
361   /// Constructs and returns the recurrence indicating how much this
362   /// expression steps by.  If this is a polynomial of degree N, it
363   /// returns a chrec of degree N-1.  We cannot determine whether
364   /// the step recurrence has self-wraparound.
365   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
366     if (isAffine())
367       return getOperand(1);
368     return SE.getAddRecExpr(
369         SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
370         FlagAnyWrap);
371   }
372 
373   /// Return true if this represents an expression A + B*x where A
374   /// and B are loop invariant values.
375   bool isAffine() const {
376     // We know that the start value is invariant.  This expression is thus
377     // affine iff the step is also invariant.
378     return getNumOperands() == 2;
379   }
380 
381   /// Return true if this represents an expression A + B*x + C*x^2
382   /// where A, B and C are loop invariant values.  This corresponds
383   /// to an addrec of the form {L,+,M,+,N}
384   bool isQuadratic() const { return getNumOperands() == 3; }
385 
386   /// Set flags for a recurrence without clearing any previously set flags.
387   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
388   /// to make it easier to propagate flags.
389   void setNoWrapFlags(NoWrapFlags Flags) {
390     if (Flags & (FlagNUW | FlagNSW))
391       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
392     SubclassData |= Flags;
393   }
394 
395   /// Return the value of this chain of recurrences at the specified
396   /// iteration number.
397   const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
398 
399   /// Return the value of this chain of recurrences at the specified iteration
400   /// number. Takes an explicit list of operands to represent an AddRec.
401   static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
402                                          const SCEV *It, ScalarEvolution &SE);
403 
404   /// Return the number of iterations of this loop that produce
405   /// values in the specified constant range.  Another way of
406   /// looking at this is that it returns the first iteration number
407   /// where the value is not in the condition, thus computing the
408   /// exit count.  If the iteration count can't be computed, an
409   /// instance of SCEVCouldNotCompute is returned.
410   const SCEV *getNumIterationsInRange(const ConstantRange &Range,
411                                       ScalarEvolution &SE) const;
412 
413   /// Return an expression representing the value of this expression
414   /// one iteration of the loop ahead.
415   const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
416 
417   /// Methods for support type inquiry through isa, cast, and dyn_cast:
418   static bool classof(const SCEV *S) {
419     return S->getSCEVType() == scAddRecExpr;
420   }
421 };
422 
423 /// This node is the base class min/max selections.
424 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
425   friend class ScalarEvolution;
426 
427   static bool isMinMaxType(enum SCEVTypes T) {
428     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
429            T == scUMinExpr;
430   }
431 
432 protected:
433   /// Note: Constructing subclasses via this constructor is allowed
434   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
435                  const SCEV *const *O, size_t N)
436       : SCEVCommutativeExpr(ID, T, O, N) {
437     assert(isMinMaxType(T));
438     // Min and max never overflow
439     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
440   }
441 
442 public:
443   Type *getType() const { return getOperand(0)->getType(); }
444 
445   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
446 
447   static enum SCEVTypes negate(enum SCEVTypes T) {
448     switch (T) {
449     case scSMaxExpr:
450       return scSMinExpr;
451     case scSMinExpr:
452       return scSMaxExpr;
453     case scUMaxExpr:
454       return scUMinExpr;
455     case scUMinExpr:
456       return scUMaxExpr;
457     default:
458       llvm_unreachable("Not a min or max SCEV type!");
459     }
460   }
461 };
462 
463 /// This class represents a signed maximum selection.
464 class SCEVSMaxExpr : public SCEVMinMaxExpr {
465   friend class ScalarEvolution;
466 
467   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
468       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
469 
470 public:
471   /// Methods for support type inquiry through isa, cast, and dyn_cast:
472   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
473 };
474 
475 /// This class represents an unsigned maximum selection.
476 class SCEVUMaxExpr : public SCEVMinMaxExpr {
477   friend class ScalarEvolution;
478 
479   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
480       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
481 
482 public:
483   /// Methods for support type inquiry through isa, cast, and dyn_cast:
484   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
485 };
486 
487 /// This class represents a signed minimum selection.
488 class SCEVSMinExpr : public SCEVMinMaxExpr {
489   friend class ScalarEvolution;
490 
491   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
492       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
493 
494 public:
495   /// Methods for support type inquiry through isa, cast, and dyn_cast:
496   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
497 };
498 
499 /// This class represents an unsigned minimum selection.
500 class SCEVUMinExpr : public SCEVMinMaxExpr {
501   friend class ScalarEvolution;
502 
503   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
504       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
505 
506 public:
507   /// Methods for support type inquiry through isa, cast, and dyn_cast:
508   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
509 };
510 
511 /// This node is the base class for sequential/in-order min/max selections.
512 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
513 /// are early-returning upon reaching saturation point.
514 /// I.e. given `0 umin_seq poison`, the result will be `0`,
515 /// while the result of `0 umin poison` is `poison`.
516 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
517   friend class ScalarEvolution;
518 
519   static bool isSequentialMinMaxType(enum SCEVTypes T) {
520     return T == scSequentialUMinExpr;
521   }
522 
523   /// Set flags for a non-recurrence without clearing previously set flags.
524   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
525 
526 protected:
527   /// Note: Constructing subclasses via this constructor is allowed
528   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
529                            const SCEV *const *O, size_t N)
530       : SCEVNAryExpr(ID, T, O, N) {
531     assert(isSequentialMinMaxType(T));
532     // Min and max never overflow
533     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
534   }
535 
536 public:
537   Type *getType() const { return getOperand(0)->getType(); }
538 
539   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
540     assert(isSequentialMinMaxType(Ty));
541     switch (Ty) {
542     case scSequentialUMinExpr:
543       return scUMinExpr;
544     default:
545       llvm_unreachable("Not a sequential min/max type.");
546     }
547   }
548 
549   SCEVTypes getEquivalentNonSequentialSCEVType() const {
550     return getEquivalentNonSequentialSCEVType(getSCEVType());
551   }
552 
553   static bool classof(const SCEV *S) {
554     return isSequentialMinMaxType(S->getSCEVType());
555   }
556 };
557 
558 /// This class represents a sequential/in-order unsigned minimum selection.
559 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
560   friend class ScalarEvolution;
561 
562   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
563                          size_t N)
564       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
565 
566 public:
567   /// Methods for support type inquiry through isa, cast, and dyn_cast:
568   static bool classof(const SCEV *S) {
569     return S->getSCEVType() == scSequentialUMinExpr;
570   }
571 };
572 
573 /// This means that we are dealing with an entirely unknown SCEV
574 /// value, and only represent it as its LLVM Value.  This is the
575 /// "bottom" value for the analysis.
576 class SCEVUnknown final : public SCEV, private CallbackVH {
577   friend class ScalarEvolution;
578 
579   /// The parent ScalarEvolution value. This is used to update the
580   /// parent's maps when the value associated with a SCEVUnknown is
581   /// deleted or RAUW'd.
582   ScalarEvolution *SE;
583 
584   /// The next pointer in the linked list of all SCEVUnknown
585   /// instances owned by a ScalarEvolution.
586   SCEVUnknown *Next;
587 
588   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
589               SCEVUnknown *next)
590       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
591 
592   // Implement CallbackVH.
593   void deleted() override;
594   void allUsesReplacedWith(Value *New) override;
595 
596 public:
597   Value *getValue() const { return getValPtr(); }
598 
599   Type *getType() const { return getValPtr()->getType(); }
600 
601   /// Methods for support type inquiry through isa, cast, and dyn_cast:
602   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
603 };
604 
605 /// This class defines a simple visitor class that may be used for
606 /// various SCEV analysis purposes.
607 template <typename SC, typename RetVal = void> struct SCEVVisitor {
608   RetVal visit(const SCEV *S) {
609     switch (S->getSCEVType()) {
610     case scConstant:
611       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
612     case scVScale:
613       return ((SC *)this)->visitVScale((const SCEVVScale *)S);
614     case scPtrToInt:
615       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
616     case scTruncate:
617       return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
618     case scZeroExtend:
619       return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
620     case scSignExtend:
621       return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
622     case scAddExpr:
623       return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
624     case scMulExpr:
625       return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
626     case scUDivExpr:
627       return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
628     case scAddRecExpr:
629       return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
630     case scSMaxExpr:
631       return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
632     case scUMaxExpr:
633       return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
634     case scSMinExpr:
635       return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
636     case scUMinExpr:
637       return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
638     case scSequentialUMinExpr:
639       return ((SC *)this)
640           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
641     case scUnknown:
642       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
643     case scCouldNotCompute:
644       return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
645     }
646     llvm_unreachable("Unknown SCEV kind!");
647   }
648 
649   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
650     llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
651   }
652 };
653 
654 /// Visit all nodes in the expression tree using worklist traversal.
655 ///
656 /// Visitor implements:
657 ///   // return true to follow this node.
658 ///   bool follow(const SCEV *S);
659 ///   // return true to terminate the search.
660 ///   bool isDone();
661 template <typename SV> class SCEVTraversal {
662   SV &Visitor;
663   SmallVector<const SCEV *, 8> Worklist;
664   SmallPtrSet<const SCEV *, 8> Visited;
665 
666   void push(const SCEV *S) {
667     if (Visited.insert(S).second && Visitor.follow(S))
668       Worklist.push_back(S);
669   }
670 
671 public:
672   SCEVTraversal(SV &V) : Visitor(V) {}
673 
674   void visitAll(const SCEV *Root) {
675     push(Root);
676     while (!Worklist.empty() && !Visitor.isDone()) {
677       const SCEV *S = Worklist.pop_back_val();
678 
679       switch (S->getSCEVType()) {
680       case scConstant:
681       case scVScale:
682       case scUnknown:
683         continue;
684       case scPtrToInt:
685       case scTruncate:
686       case scZeroExtend:
687       case scSignExtend:
688       case scAddExpr:
689       case scMulExpr:
690       case scUDivExpr:
691       case scSMaxExpr:
692       case scUMaxExpr:
693       case scSMinExpr:
694       case scUMinExpr:
695       case scSequentialUMinExpr:
696       case scAddRecExpr:
697         for (const auto *Op : S->operands()) {
698           push(Op);
699           if (Visitor.isDone())
700             break;
701         }
702         continue;
703       case scCouldNotCompute:
704         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
705       }
706       llvm_unreachable("Unknown SCEV kind!");
707     }
708   }
709 };
710 
711 /// Use SCEVTraversal to visit all nodes in the given expression tree.
712 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
713   SCEVTraversal<SV> T(Visitor);
714   T.visitAll(Root);
715 }
716 
717 /// Return true if any node in \p Root satisfies the predicate \p Pred.
718 template <typename PredTy>
719 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
720   struct FindClosure {
721     bool Found = false;
722     PredTy Pred;
723 
724     FindClosure(PredTy Pred) : Pred(Pred) {}
725 
726     bool follow(const SCEV *S) {
727       if (!Pred(S))
728         return true;
729 
730       Found = true;
731       return false;
732     }
733 
734     bool isDone() const { return Found; }
735   };
736 
737   FindClosure FC(Pred);
738   visitAll(Root, FC);
739   return FC.Found;
740 }
741 
742 /// This visitor recursively visits a SCEV expression and re-writes it.
743 /// The result from each visit is cached, so it will return the same
744 /// SCEV for the same input.
745 template <typename SC>
746 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
747 protected:
748   ScalarEvolution &SE;
749   // Memoize the result of each visit so that we only compute once for
750   // the same input SCEV. This is to avoid redundant computations when
751   // a SCEV is referenced by multiple SCEVs. Without memoization, this
752   // visit algorithm would have exponential time complexity in the worst
753   // case, causing the compiler to hang on certain tests.
754   SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
755 
756 public:
757   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
758 
759   const SCEV *visit(const SCEV *S) {
760     auto It = RewriteResults.find(S);
761     if (It != RewriteResults.end())
762       return It->second;
763     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
764     auto Result = RewriteResults.try_emplace(S, Visited);
765     assert(Result.second && "Should insert a new entry");
766     return Result.first->second;
767   }
768 
769   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
770 
771   const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
772 
773   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
774     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
775     return Operand == Expr->getOperand()
776                ? Expr
777                : SE.getPtrToIntExpr(Operand, Expr->getType());
778   }
779 
780   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
781     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
782     return Operand == Expr->getOperand()
783                ? Expr
784                : SE.getTruncateExpr(Operand, Expr->getType());
785   }
786 
787   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
788     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
789     return Operand == Expr->getOperand()
790                ? Expr
791                : SE.getZeroExtendExpr(Operand, Expr->getType());
792   }
793 
794   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
795     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
796     return Operand == Expr->getOperand()
797                ? Expr
798                : SE.getSignExtendExpr(Operand, Expr->getType());
799   }
800 
801   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
802     SmallVector<const SCEV *, 2> Operands;
803     bool Changed = false;
804     for (const auto *Op : Expr->operands()) {
805       Operands.push_back(((SC *)this)->visit(Op));
806       Changed |= Op != Operands.back();
807     }
808     return !Changed ? Expr : SE.getAddExpr(Operands);
809   }
810 
811   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
812     SmallVector<const SCEV *, 2> Operands;
813     bool Changed = false;
814     for (const auto *Op : Expr->operands()) {
815       Operands.push_back(((SC *)this)->visit(Op));
816       Changed |= Op != Operands.back();
817     }
818     return !Changed ? Expr : SE.getMulExpr(Operands);
819   }
820 
821   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
822     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
823     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
824     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
825     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
826   }
827 
828   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
829     SmallVector<const SCEV *, 2> Operands;
830     bool Changed = false;
831     for (const auto *Op : Expr->operands()) {
832       Operands.push_back(((SC *)this)->visit(Op));
833       Changed |= Op != Operands.back();
834     }
835     return !Changed ? Expr
836                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
837                                        Expr->getNoWrapFlags());
838   }
839 
840   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
841     SmallVector<const SCEV *, 2> Operands;
842     bool Changed = false;
843     for (const auto *Op : Expr->operands()) {
844       Operands.push_back(((SC *)this)->visit(Op));
845       Changed |= Op != Operands.back();
846     }
847     return !Changed ? Expr : SE.getSMaxExpr(Operands);
848   }
849 
850   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
851     SmallVector<const SCEV *, 2> Operands;
852     bool Changed = false;
853     for (const auto *Op : Expr->operands()) {
854       Operands.push_back(((SC *)this)->visit(Op));
855       Changed |= Op != Operands.back();
856     }
857     return !Changed ? Expr : SE.getUMaxExpr(Operands);
858   }
859 
860   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
861     SmallVector<const SCEV *, 2> Operands;
862     bool Changed = false;
863     for (const auto *Op : Expr->operands()) {
864       Operands.push_back(((SC *)this)->visit(Op));
865       Changed |= Op != Operands.back();
866     }
867     return !Changed ? Expr : SE.getSMinExpr(Operands);
868   }
869 
870   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
871     SmallVector<const SCEV *, 2> Operands;
872     bool Changed = false;
873     for (const auto *Op : Expr->operands()) {
874       Operands.push_back(((SC *)this)->visit(Op));
875       Changed |= Op != Operands.back();
876     }
877     return !Changed ? Expr : SE.getUMinExpr(Operands);
878   }
879 
880   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
881     SmallVector<const SCEV *, 2> Operands;
882     bool Changed = false;
883     for (const auto *Op : Expr->operands()) {
884       Operands.push_back(((SC *)this)->visit(Op));
885       Changed |= Op != Operands.back();
886     }
887     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
888   }
889 
890   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
891 
892   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
893     return Expr;
894   }
895 };
896 
897 using ValueToValueMap = DenseMap<const Value *, Value *>;
898 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
899 
900 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
901 /// the SCEVUnknown components following the Map (Value -> SCEV).
902 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
903 public:
904   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
905                              ValueToSCEVMapTy &Map) {
906     SCEVParameterRewriter Rewriter(SE, Map);
907     return Rewriter.visit(Scev);
908   }
909 
910   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
911       : SCEVRewriteVisitor(SE), Map(M) {}
912 
913   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
914     auto I = Map.find(Expr->getValue());
915     if (I == Map.end())
916       return Expr;
917     return I->second;
918   }
919 
920 private:
921   ValueToSCEVMapTy &Map;
922 };
923 
924 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
925 
926 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
927 /// the Map (Loop -> SCEV) to all AddRecExprs.
928 class SCEVLoopAddRecRewriter
929     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
930 public:
931   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
932       : SCEVRewriteVisitor(SE), Map(M) {}
933 
934   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
935                              ScalarEvolution &SE) {
936     SCEVLoopAddRecRewriter Rewriter(SE, Map);
937     return Rewriter.visit(Scev);
938   }
939 
940   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
941     SmallVector<const SCEV *, 2> Operands;
942     for (const SCEV *Op : Expr->operands())
943       Operands.push_back(visit(Op));
944 
945     const Loop *L = Expr->getLoop();
946     if (0 == Map.count(L))
947       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
948 
949     return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
950   }
951 
952 private:
953   LoopToScevMapT &Map;
954 };
955 
956 } // end namespace llvm
957 
958 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
959