1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstddef>
27
28 namespace llvm {
29
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37
/// Discriminator identifying the concrete kind of a SCEV node.
///
/// Enumerators are listed in order of increasing complexity; keeping them in
/// this order makes the folding code simpler, so a new kind should be inserted
/// at the position matching its complexity rather than appended blindly.
enum SCEVTypes : unsigned short {
  // Leaf values.
  scConstant = 0,
  scVScale,
  // Unary integral casts.
  scTruncate,
  scZeroExtend,
  scSignExtend,
  // Arithmetic and recurrences.
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  scAddRecExpr,
  // Min/max selections (commutative, then sequential).
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  scSequentialUMinExpr,
  // Pointer cast, opaque IR values, and the "could not compute" sentinel.
  scPtrToInt,
  scUnknown,
  scCouldNotCompute,
};
59
60 /// This class represents a constant integer value.
61 class SCEVConstant : public SCEV {
62 friend class ScalarEvolution;
63
64 ConstantInt *V;
65
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)66 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
67 : SCEV(ID, scConstant, 1), V(v) {}
68
69 public:
getValue()70 ConstantInt *getValue() const { return V; }
getAPInt()71 const APInt &getAPInt() const { return getValue()->getValue(); }
72
getType()73 Type *getType() const { return V->getType(); }
74
75 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)76 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
77 };
78
79 /// This class represents the value of vscale, as used when defining the length
80 /// of a scalable vector or returned by the llvm.vscale() intrinsic.
81 class SCEVVScale : public SCEV {
82 friend class ScalarEvolution;
83
SCEVVScale(const FoldingSetNodeIDRef ID,Type * ty)84 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
85 : SCEV(ID, scVScale, 0), Ty(ty) {}
86
87 Type *Ty;
88
89 public:
getType()90 Type *getType() const { return Ty; }
91
92 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)93 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
94 };
95
computeExpressionSize(ArrayRef<const SCEV * > Args)96 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
97 APInt Size(16, 1);
98 for (const auto *Arg : Args)
99 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
100 return (unsigned short)Size.getZExtValue();
101 }
102
103 /// This is the base class for unary cast operator classes.
104 class SCEVCastExpr : public SCEV {
105 protected:
106 const SCEV *Op;
107 Type *Ty;
108
109 LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
110 const SCEV *op, Type *ty);
111
112 public:
getOperand()113 const SCEV *getOperand() const { return Op; }
getOperand(unsigned i)114 const SCEV *getOperand(unsigned i) const {
115 assert(i == 0 && "Operand index out of range!");
116 return Op;
117 }
operands()118 ArrayRef<const SCEV *> operands() const { return Op; }
getNumOperands()119 size_t getNumOperands() const { return 1; }
getType()120 Type *getType() const { return Ty; }
121
122 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)123 static bool classof(const SCEV *S) {
124 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
125 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
126 }
127 };
128
129 /// This class represents a cast from a pointer to a pointer-sized integer
130 /// value.
131 class SCEVPtrToIntExpr : public SCEVCastExpr {
132 friend class ScalarEvolution;
133
134 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
135
136 public:
137 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)138 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
139 };
140
141 /// This is the base class for unary integral cast operator classes.
142 class SCEVIntegralCastExpr : public SCEVCastExpr {
143 protected:
144 LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
145 const SCEV *op, Type *ty);
146
147 public:
148 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)149 static bool classof(const SCEV *S) {
150 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
151 S->getSCEVType() == scSignExtend;
152 }
153 };
154
155 /// This class represents a truncation of an integer value to a
156 /// smaller integer value.
157 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
158 friend class ScalarEvolution;
159
160 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
161
162 public:
163 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)164 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
165 };
166
167 /// This class represents a zero extension of a small integer value
168 /// to a larger integer value.
169 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
170 friend class ScalarEvolution;
171
172 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
173
174 public:
175 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)176 static bool classof(const SCEV *S) {
177 return S->getSCEVType() == scZeroExtend;
178 }
179 };
180
181 /// This class represents a sign extension of a small integer value
182 /// to a larger integer value.
183 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
184 friend class ScalarEvolution;
185
186 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
187
188 public:
189 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)190 static bool classof(const SCEV *S) {
191 return S->getSCEVType() == scSignExtend;
192 }
193 };
194
195 /// This node is a base class providing common functionality for
196 /// n'ary operators.
197 class SCEVNAryExpr : public SCEV {
198 protected:
199 // Since SCEVs are immutable, ScalarEvolution allocates operand
200 // arrays with its SCEVAllocator, so this class just needs a simple
201 // pointer rather than a more elaborate vector-like data structure.
202 // This also avoids the need for a non-trivial destructor.
203 const SCEV *const *Operands;
204 size_t NumOperands;
205
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)206 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
207 const SCEV *const *O, size_t N)
208 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
209 NumOperands(N) {}
210
211 public:
getNumOperands()212 size_t getNumOperands() const { return NumOperands; }
213
getOperand(unsigned i)214 const SCEV *getOperand(unsigned i) const {
215 assert(i < NumOperands && "Operand index out of range!");
216 return Operands[i];
217 }
218
operands()219 ArrayRef<const SCEV *> operands() const {
220 return ArrayRef(Operands, NumOperands);
221 }
222
223 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
224 return (NoWrapFlags)(SubclassData & Mask);
225 }
226
hasNoUnsignedWrap()227 bool hasNoUnsignedWrap() const {
228 return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
229 }
230
hasNoSignedWrap()231 bool hasNoSignedWrap() const {
232 return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
233 }
234
hasNoSelfWrap()235 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
236
237 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)238 static bool classof(const SCEV *S) {
239 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
240 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
241 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
242 S->getSCEVType() == scSequentialUMinExpr ||
243 S->getSCEVType() == scAddRecExpr;
244 }
245 };
246
247 /// This node is the base class for n'ary commutative operators.
248 class SCEVCommutativeExpr : public SCEVNAryExpr {
249 protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)250 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
251 const SCEV *const *O, size_t N)
252 : SCEVNAryExpr(ID, T, O, N) {}
253
254 public:
255 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)256 static bool classof(const SCEV *S) {
257 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
258 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
259 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
260 }
261
262 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)263 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
264 };
265
266 /// This node represents an addition of some number of SCEVs.
267 class SCEVAddExpr : public SCEVCommutativeExpr {
268 friend class ScalarEvolution;
269
270 Type *Ty;
271
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)272 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
273 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
274 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
275 return Op->getType()->isPointerTy();
276 });
277 if (FirstPointerTypedOp != operands().end())
278 Ty = (*FirstPointerTypedOp)->getType();
279 else
280 Ty = getOperand(0)->getType();
281 }
282
283 public:
getType()284 Type *getType() const { return Ty; }
285
286 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)287 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
288 };
289
290 /// This node represents multiplication of some number of SCEVs.
291 class SCEVMulExpr : public SCEVCommutativeExpr {
292 friend class ScalarEvolution;
293
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)294 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
295 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
296
297 public:
getType()298 Type *getType() const { return getOperand(0)->getType(); }
299
300 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)301 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
302 };
303
304 /// This class represents a binary unsigned division operation.
305 class SCEVUDivExpr : public SCEV {
306 friend class ScalarEvolution;
307
308 std::array<const SCEV *, 2> Operands;
309
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)310 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
311 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
312 Operands[0] = lhs;
313 Operands[1] = rhs;
314 }
315
316 public:
getLHS()317 const SCEV *getLHS() const { return Operands[0]; }
getRHS()318 const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()319 size_t getNumOperands() const { return 2; }
getOperand(unsigned i)320 const SCEV *getOperand(unsigned i) const {
321 assert((i == 0 || i == 1) && "Operand index out of range!");
322 return i == 0 ? getLHS() : getRHS();
323 }
324
operands()325 ArrayRef<const SCEV *> operands() const { return Operands; }
326
getType()327 Type *getType() const {
328 // In most cases the types of LHS and RHS will be the same, but in some
329 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
330 // depend on the type for correctness, but handling types carefully can
331 // avoid extra casts in the SCEVExpander. The LHS is more likely to be
332 // a pointer type than the RHS, so use the RHS' type here.
333 return getRHS()->getType();
334 }
335
336 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)337 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
338 };
339
340 /// This node represents a polynomial recurrence on the trip count
341 /// of the specified loop. This is the primary focus of the
342 /// ScalarEvolution framework; all the other SCEV subclasses are
343 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
344 /// expressions to be created and analyzed.
345 ///
346 /// All operands of an AddRec are required to be loop invariant.
347 ///
348 class SCEVAddRecExpr : public SCEVNAryExpr {
349 friend class ScalarEvolution;
350
351 const Loop *L;
352
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)353 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
354 const Loop *l)
355 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
356
357 public:
getType()358 Type *getType() const { return getStart()->getType(); }
getStart()359 const SCEV *getStart() const { return Operands[0]; }
getLoop()360 const Loop *getLoop() const { return L; }
361
362 /// Constructs and returns the recurrence indicating how much this
363 /// expression steps by. If this is a polynomial of degree N, it
364 /// returns a chrec of degree N-1. We cannot determine whether
365 /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)366 const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
367 if (isAffine())
368 return getOperand(1);
369 return SE.getAddRecExpr(
370 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
371 FlagAnyWrap);
372 }
373
374 /// Return true if this represents an expression A + B*x where A
375 /// and B are loop invariant values.
isAffine()376 bool isAffine() const {
377 // We know that the start value is invariant. This expression is thus
378 // affine iff the step is also invariant.
379 return getNumOperands() == 2;
380 }
381
382 /// Return true if this represents an expression A + B*x + C*x^2
383 /// where A, B and C are loop invariant values. This corresponds
384 /// to an addrec of the form {L,+,M,+,N}
isQuadratic()385 bool isQuadratic() const { return getNumOperands() == 3; }
386
387 /// Set flags for a recurrence without clearing any previously set flags.
388 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
389 /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)390 void setNoWrapFlags(NoWrapFlags Flags) {
391 if (Flags & (FlagNUW | FlagNSW))
392 Flags = ScalarEvolution::setFlags(Flags, FlagNW);
393 SubclassData |= Flags;
394 }
395
396 /// Return the value of this chain of recurrences at the specified
397 /// iteration number.
398 LLVM_ABI const SCEV *evaluateAtIteration(const SCEV *It,
399 ScalarEvolution &SE) const;
400
401 /// Return the value of this chain of recurrences at the specified iteration
402 /// number. Takes an explicit list of operands to represent an AddRec.
403 LLVM_ABI static const SCEV *
404 evaluateAtIteration(ArrayRef<const SCEV *> Operands, const SCEV *It,
405 ScalarEvolution &SE);
406
407 /// Return the number of iterations of this loop that produce
408 /// values in the specified constant range. Another way of
409 /// looking at this is that it returns the first iteration number
410 /// where the value is not in the condition, thus computing the
411 /// exit count. If the iteration count can't be computed, an
412 /// instance of SCEVCouldNotCompute is returned.
413 LLVM_ABI const SCEV *getNumIterationsInRange(const ConstantRange &Range,
414 ScalarEvolution &SE) const;
415
416 /// Return an expression representing the value of this expression
417 /// one iteration of the loop ahead.
418 LLVM_ABI const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
419
420 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)421 static bool classof(const SCEV *S) {
422 return S->getSCEVType() == scAddRecExpr;
423 }
424 };
425
426 /// This node is the base class min/max selections.
427 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
428 friend class ScalarEvolution;
429
isMinMaxType(enum SCEVTypes T)430 static bool isMinMaxType(enum SCEVTypes T) {
431 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
432 T == scUMinExpr;
433 }
434
435 protected:
436 /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)437 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
438 const SCEV *const *O, size_t N)
439 : SCEVCommutativeExpr(ID, T, O, N) {
440 assert(isMinMaxType(T));
441 // Min and max never overflow
442 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
443 }
444
445 public:
getType()446 Type *getType() const { return getOperand(0)->getType(); }
447
classof(const SCEV * S)448 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
449
negate(enum SCEVTypes T)450 static enum SCEVTypes negate(enum SCEVTypes T) {
451 switch (T) {
452 case scSMaxExpr:
453 return scSMinExpr;
454 case scSMinExpr:
455 return scSMaxExpr;
456 case scUMaxExpr:
457 return scUMinExpr;
458 case scUMinExpr:
459 return scUMaxExpr;
460 default:
461 llvm_unreachable("Not a min or max SCEV type!");
462 }
463 }
464 };
465
466 /// This class represents a signed maximum selection.
467 class SCEVSMaxExpr : public SCEVMinMaxExpr {
468 friend class ScalarEvolution;
469
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)470 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
471 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
472
473 public:
474 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)475 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
476 };
477
478 /// This class represents an unsigned maximum selection.
479 class SCEVUMaxExpr : public SCEVMinMaxExpr {
480 friend class ScalarEvolution;
481
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)482 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
483 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
484
485 public:
486 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)487 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
488 };
489
490 /// This class represents a signed minimum selection.
491 class SCEVSMinExpr : public SCEVMinMaxExpr {
492 friend class ScalarEvolution;
493
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)494 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
495 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
496
497 public:
498 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)499 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
500 };
501
502 /// This class represents an unsigned minimum selection.
503 class SCEVUMinExpr : public SCEVMinMaxExpr {
504 friend class ScalarEvolution;
505
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)506 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
507 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
508
509 public:
510 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)511 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
512 };
513
514 /// This node is the base class for sequential/in-order min/max selections.
515 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
516 /// are early-returning upon reaching saturation point.
517 /// I.e. given `0 umin_seq poison`, the result will be `0`, while the result of
518 /// `0 umin poison` is `poison`. When returning early, later expressions are not
519 /// executed, so `0 umin_seq (%x u/ 0)` does not result in undefined behavior.
520 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
521 friend class ScalarEvolution;
522
isSequentialMinMaxType(enum SCEVTypes T)523 static bool isSequentialMinMaxType(enum SCEVTypes T) {
524 return T == scSequentialUMinExpr;
525 }
526
527 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)528 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
529
530 protected:
531 /// Note: Constructing subclasses via this constructor is allowed
SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)532 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
533 const SCEV *const *O, size_t N)
534 : SCEVNAryExpr(ID, T, O, N) {
535 assert(isSequentialMinMaxType(T));
536 // Min and max never overflow
537 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
538 }
539
540 public:
getType()541 Type *getType() const { return getOperand(0)->getType(); }
542
getEquivalentNonSequentialSCEVType(SCEVTypes Ty)543 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
544 assert(isSequentialMinMaxType(Ty));
545 switch (Ty) {
546 case scSequentialUMinExpr:
547 return scUMinExpr;
548 default:
549 llvm_unreachable("Not a sequential min/max type.");
550 }
551 }
552
getEquivalentNonSequentialSCEVType()553 SCEVTypes getEquivalentNonSequentialSCEVType() const {
554 return getEquivalentNonSequentialSCEVType(getSCEVType());
555 }
556
classof(const SCEV * S)557 static bool classof(const SCEV *S) {
558 return isSequentialMinMaxType(S->getSCEVType());
559 }
560 };
561
562 /// This class represents a sequential/in-order unsigned minimum selection.
563 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
564 friend class ScalarEvolution;
565
SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)566 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
567 size_t N)
568 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
569
570 public:
571 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)572 static bool classof(const SCEV *S) {
573 return S->getSCEVType() == scSequentialUMinExpr;
574 }
575 };
576
577 /// This means that we are dealing with an entirely unknown SCEV
578 /// value, and only represent it as its LLVM Value. This is the
579 /// "bottom" value for the analysis.
580 class LLVM_ABI SCEVUnknown final : public SCEV, private CallbackVH {
581 friend class ScalarEvolution;
582
583 /// The parent ScalarEvolution value. This is used to update the
584 /// parent's maps when the value associated with a SCEVUnknown is
585 /// deleted or RAUW'd.
586 ScalarEvolution *SE;
587
588 /// The next pointer in the linked list of all SCEVUnknown
589 /// instances owned by a ScalarEvolution.
590 SCEVUnknown *Next;
591
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)592 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
593 SCEVUnknown *next)
594 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
595
596 // Implement CallbackVH.
597 void deleted() override;
598 void allUsesReplacedWith(Value *New) override;
599
600 public:
getValue()601 Value *getValue() const { return getValPtr(); }
602
getType()603 Type *getType() const { return getValPtr()->getType(); }
604
605 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)606 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
607 };
608
609 /// This class defines a simple visitor class that may be used for
610 /// various SCEV analysis purposes.
611 template <typename SC, typename RetVal = void> struct SCEVVisitor {
visitSCEVVisitor612 RetVal visit(const SCEV *S) {
613 switch (S->getSCEVType()) {
614 case scConstant:
615 return ((SC *)this)->visitConstant((const SCEVConstant *)S);
616 case scVScale:
617 return ((SC *)this)->visitVScale((const SCEVVScale *)S);
618 case scPtrToInt:
619 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
620 case scTruncate:
621 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
622 case scZeroExtend:
623 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
624 case scSignExtend:
625 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
626 case scAddExpr:
627 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
628 case scMulExpr:
629 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
630 case scUDivExpr:
631 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
632 case scAddRecExpr:
633 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
634 case scSMaxExpr:
635 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
636 case scUMaxExpr:
637 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
638 case scSMinExpr:
639 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
640 case scUMinExpr:
641 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
642 case scSequentialUMinExpr:
643 return ((SC *)this)
644 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
645 case scUnknown:
646 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
647 case scCouldNotCompute:
648 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
649 }
650 llvm_unreachable("Unknown SCEV kind!");
651 }
652
visitCouldNotComputeSCEVVisitor653 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
654 llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
655 }
656 };
657
658 /// Visit all nodes in the expression tree using worklist traversal.
659 ///
660 /// Visitor implements:
661 /// // return true to follow this node.
662 /// bool follow(const SCEV *S);
663 /// // return true to terminate the search.
664 /// bool isDone();
665 template <typename SV> class SCEVTraversal {
666 SV &Visitor;
667 SmallVector<const SCEV *, 8> Worklist;
668 SmallPtrSet<const SCEV *, 8> Visited;
669
push(const SCEV * S)670 void push(const SCEV *S) {
671 if (Visited.insert(S).second && Visitor.follow(S))
672 Worklist.push_back(S);
673 }
674
675 public:
SCEVTraversal(SV & V)676 SCEVTraversal(SV &V) : Visitor(V) {}
677
visitAll(const SCEV * Root)678 void visitAll(const SCEV *Root) {
679 push(Root);
680 while (!Worklist.empty() && !Visitor.isDone()) {
681 const SCEV *S = Worklist.pop_back_val();
682
683 switch (S->getSCEVType()) {
684 case scConstant:
685 case scVScale:
686 case scUnknown:
687 continue;
688 case scPtrToInt:
689 case scTruncate:
690 case scZeroExtend:
691 case scSignExtend:
692 case scAddExpr:
693 case scMulExpr:
694 case scUDivExpr:
695 case scSMaxExpr:
696 case scUMaxExpr:
697 case scSMinExpr:
698 case scUMinExpr:
699 case scSequentialUMinExpr:
700 case scAddRecExpr:
701 for (const auto *Op : S->operands()) {
702 push(Op);
703 if (Visitor.isDone())
704 break;
705 }
706 continue;
707 case scCouldNotCompute:
708 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
709 }
710 llvm_unreachable("Unknown SCEV kind!");
711 }
712 }
713 };
714
715 /// Use SCEVTraversal to visit all nodes in the given expression tree.
visitAll(const SCEV * Root,SV & Visitor)716 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
717 SCEVTraversal<SV> T(Visitor);
718 T.visitAll(Root);
719 }
720
721 /// Return true if any node in \p Root satisfies the predicate \p Pred.
722 template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)723 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
724 struct FindClosure {
725 bool Found = false;
726 PredTy Pred;
727
728 FindClosure(PredTy Pred) : Pred(Pred) {}
729
730 bool follow(const SCEV *S) {
731 if (!Pred(S))
732 return true;
733
734 Found = true;
735 return false;
736 }
737
738 bool isDone() const { return Found; }
739 };
740
741 FindClosure FC(Pred);
742 visitAll(Root, FC);
743 return FC.Found;
744 }
745
746 /// This visitor recursively visits a SCEV expression and re-writes it.
747 /// The result from each visit is cached, so it will return the same
748 /// SCEV for the same input.
749 template <typename SC>
750 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
751 protected:
752 ScalarEvolution &SE;
753 // Memoize the result of each visit so that we only compute once for
754 // the same input SCEV. This is to avoid redundant computations when
755 // a SCEV is referenced by multiple SCEVs. Without memoization, this
756 // visit algorithm would have exponential time complexity in the worst
757 // case, causing the compiler to hang on certain tests.
758 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
759
760 public:
/// Binds the rewriter to the ScalarEvolution used to rebuild expressions.
SCEVRewriteVisitor(ScalarEvolution &ScalarEv) : SE(ScalarEv) {}
762
/// Rewrites \p S, reusing the cached result if \p S was already rewritten.
/// This avoids re-walking subexpressions shared by several parents.
const SCEV *visit(const SCEV *S) {
  auto Cached = RewriteResults.find(S);
  if (Cached != RewriteResults.end())
    return Cached->second;

  const SCEV *Rewritten = SCEVVisitor<SC, const SCEV *>::visit(S);
  auto Inserted = RewriteResults.try_emplace(S, Rewritten);
  assert(Inserted.second && "Should insert a new entry");
  return Inserted.first->second;
}
772
visitConstant(const SCEVConstant * Constant)773 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
774
visitVScale(const SCEVVScale * VScale)775 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
776
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)777 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
778 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
779 return Operand == Expr->getOperand()
780 ? Expr
781 : SE.getPtrToIntExpr(Operand, Expr->getType());
782 }
783
visitTruncateExpr(const SCEVTruncateExpr * Expr)784 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
785 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
786 return Operand == Expr->getOperand()
787 ? Expr
788 : SE.getTruncateExpr(Operand, Expr->getType());
789 }
790
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)791 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
792 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
793 return Operand == Expr->getOperand()
794 ? Expr
795 : SE.getZeroExtendExpr(Operand, Expr->getType());
796 }
797
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)798 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
799 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
800 return Operand == Expr->getOperand()
801 ? Expr
802 : SE.getSignExtendExpr(Operand, Expr->getType());
803 }
804
visitAddExpr(const SCEVAddExpr * Expr)805 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
806 SmallVector<const SCEV *, 2> Operands;
807 bool Changed = false;
808 for (const auto *Op : Expr->operands()) {
809 Operands.push_back(((SC *)this)->visit(Op));
810 Changed |= Op != Operands.back();
811 }
812 return !Changed ? Expr : SE.getAddExpr(Operands);
813 }
814
visitMulExpr(const SCEVMulExpr * Expr)815 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
816 SmallVector<const SCEV *, 2> Operands;
817 bool Changed = false;
818 for (const auto *Op : Expr->operands()) {
819 Operands.push_back(((SC *)this)->visit(Op));
820 Changed |= Op != Operands.back();
821 }
822 return !Changed ? Expr : SE.getMulExpr(Operands);
823 }
824
visitUDivExpr(const SCEVUDivExpr * Expr)825 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
826 auto *LHS = ((SC *)this)->visit(Expr->getLHS());
827 auto *RHS = ((SC *)this)->visit(Expr->getRHS());
828 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
829 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
830 }
831
visitAddRecExpr(const SCEVAddRecExpr * Expr)832 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
833 SmallVector<const SCEV *, 2> Operands;
834 bool Changed = false;
835 for (const auto *Op : Expr->operands()) {
836 Operands.push_back(((SC *)this)->visit(Op));
837 Changed |= Op != Operands.back();
838 }
839 return !Changed ? Expr
840 : SE.getAddRecExpr(Operands, Expr->getLoop(),
841 Expr->getNoWrapFlags());
842 }
843
visitSMaxExpr(const SCEVSMaxExpr * Expr)844 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
845 SmallVector<const SCEV *, 2> Operands;
846 bool Changed = false;
847 for (const auto *Op : Expr->operands()) {
848 Operands.push_back(((SC *)this)->visit(Op));
849 Changed |= Op != Operands.back();
850 }
851 return !Changed ? Expr : SE.getSMaxExpr(Operands);
852 }
853
visitUMaxExpr(const SCEVUMaxExpr * Expr)854 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
855 SmallVector<const SCEV *, 2> Operands;
856 bool Changed = false;
857 for (const auto *Op : Expr->operands()) {
858 Operands.push_back(((SC *)this)->visit(Op));
859 Changed |= Op != Operands.back();
860 }
861 return !Changed ? Expr : SE.getUMaxExpr(Operands);
862 }
863
visitSMinExpr(const SCEVSMinExpr * Expr)864 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
865 SmallVector<const SCEV *, 2> Operands;
866 bool Changed = false;
867 for (const auto *Op : Expr->operands()) {
868 Operands.push_back(((SC *)this)->visit(Op));
869 Changed |= Op != Operands.back();
870 }
871 return !Changed ? Expr : SE.getSMinExpr(Operands);
872 }
873
visitUMinExpr(const SCEVUMinExpr * Expr)874 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
875 SmallVector<const SCEV *, 2> Operands;
876 bool Changed = false;
877 for (const auto *Op : Expr->operands()) {
878 Operands.push_back(((SC *)this)->visit(Op));
879 Changed |= Op != Operands.back();
880 }
881 return !Changed ? Expr : SE.getUMinExpr(Operands);
882 }
883
visitSequentialUMinExpr(const SCEVSequentialUMinExpr * Expr)884 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
885 SmallVector<const SCEV *, 2> Operands;
886 bool Changed = false;
887 for (const auto *Op : Expr->operands()) {
888 Operands.push_back(((SC *)this)->visit(Op));
889 Changed |= Op != Operands.back();
890 }
891 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
892 }
893
visitUnknown(const SCEVUnknown * Expr)894 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
895
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)896 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
897 return Expr;
898 }
899 };
900
901 using ValueToValueMap = DenseMap<const Value *, Value *>;
902 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
903
904 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
905 /// the SCEVUnknown components following the Map (Value -> SCEV).
906 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
907 public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)908 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
909 ValueToSCEVMapTy &Map) {
910 SCEVParameterRewriter Rewriter(SE, Map);
911 return Rewriter.visit(Scev);
912 }
913
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)914 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
915 : SCEVRewriteVisitor(SE), Map(M) {}
916
visitUnknown(const SCEVUnknown * Expr)917 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
918 auto I = Map.find(Expr->getValue());
919 if (I == Map.end())
920 return Expr;
921 return I->second;
922 }
923
924 private:
925 ValueToSCEVMapTy ⤅
926 };
927
928 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
929
930 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
931 /// the Map (Loop -> SCEV) to all AddRecExprs.
932 class SCEVLoopAddRecRewriter
933 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
934 public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)935 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
936 : SCEVRewriteVisitor(SE), Map(M) {}
937
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)938 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
939 ScalarEvolution &SE) {
940 SCEVLoopAddRecRewriter Rewriter(SE, Map);
941 return Rewriter.visit(Scev);
942 }
943
visitAddRecExpr(const SCEVAddRecExpr * Expr)944 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
945 SmallVector<const SCEV *, 2> Operands;
946 for (const SCEV *Op : Expr->operands())
947 Operands.push_back(visit(Op));
948
949 const Loop *L = Expr->getLoop();
950 auto It = Map.find(L);
951 if (It == Map.end())
952 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
953
954 return SCEVAddRecExpr::evaluateAtIteration(Operands, It->second, SE);
955 }
956
957 private:
958 LoopToScevMapT ⤅
959 };
960
961 } // end namespace llvm
962
963 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
964