//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the class that knows how to divide SCEVs.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "llvm/Analysis/ScalarEvolutionDivision.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Support/Casting.h"
19 #include <cassert>
20 #include <cstdint>
21
22 namespace llvm {
23 class Type;
24 } // namespace llvm
25
26 using namespace llvm;
27
28 namespace {
29
sizeOfSCEV(const SCEV * S)30 static inline int sizeOfSCEV(const SCEV *S) {
31 struct FindSCEVSize {
32 int Size = 0;
33
34 FindSCEVSize() = default;
35
36 bool follow(const SCEV *S) {
37 ++Size;
38 // Keep looking at all operands of S.
39 return true;
40 }
41
42 bool isDone() const { return false; }
43 };
44
45 FindSCEVSize F;
46 SCEVTraversal<FindSCEVSize> ST(F);
47 ST.visitAll(S);
48 return F.Size;
49 }
50
51 } // namespace
52
53 // Computes the Quotient and Remainder of the division of Numerator by
54 // Denominator.
divide(ScalarEvolution & SE,const SCEV * Numerator,const SCEV * Denominator,const SCEV ** Quotient,const SCEV ** Remainder)55 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
56 const SCEV *Denominator, const SCEV **Quotient,
57 const SCEV **Remainder) {
58 assert(Numerator && Denominator && "Uninitialized SCEV");
59
60 SCEVDivision D(SE, Numerator, Denominator);
61
62 // Check for the trivial case here to avoid having to check for it in the
63 // rest of the code.
64 if (Numerator == Denominator) {
65 *Quotient = D.One;
66 *Remainder = D.Zero;
67 return;
68 }
69
70 if (Numerator->isZero()) {
71 *Quotient = D.Zero;
72 *Remainder = D.Zero;
73 return;
74 }
75
76 // A simple case when N/1. The quotient is N.
77 if (Denominator->isOne()) {
78 *Quotient = Numerator;
79 *Remainder = D.Zero;
80 return;
81 }
82
83 // Split the Denominator when it is a product.
84 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
85 const SCEV *Q, *R;
86 *Quotient = Numerator;
87 for (const SCEV *Op : T->operands()) {
88 divide(SE, *Quotient, Op, &Q, &R);
89 *Quotient = Q;
90
91 // Bail out when the Numerator is not divisible by one of the terms of
92 // the Denominator.
93 if (!R->isZero()) {
94 *Quotient = D.Zero;
95 *Remainder = Numerator;
96 return;
97 }
98 }
99 *Remainder = D.Zero;
100 return;
101 }
102
103 D.visit(Numerator);
104 *Quotient = D.Quotient;
105 *Remainder = D.Remainder;
106 }
107
visitConstant(const SCEVConstant * Numerator)108 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
109 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
110 APInt NumeratorVal = Numerator->getAPInt();
111 APInt DenominatorVal = D->getAPInt();
112 uint32_t NumeratorBW = NumeratorVal.getBitWidth();
113 uint32_t DenominatorBW = DenominatorVal.getBitWidth();
114
115 if (NumeratorBW > DenominatorBW)
116 DenominatorVal = DenominatorVal.sext(NumeratorBW);
117 else if (NumeratorBW < DenominatorBW)
118 NumeratorVal = NumeratorVal.sext(DenominatorBW);
119
120 APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
121 APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
122 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
123 Quotient = SE.getConstant(QuotientVal);
124 Remainder = SE.getConstant(RemainderVal);
125 return;
126 }
127 }
128
visitVScale(const SCEVVScale * Numerator)129 void SCEVDivision::visitVScale(const SCEVVScale *Numerator) {
130 return cannotDivide(Numerator);
131 }
132
visitAddRecExpr(const SCEVAddRecExpr * Numerator)133 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
134 const SCEV *StartQ, *StartR, *StepQ, *StepR;
135 if (!Numerator->isAffine())
136 return cannotDivide(Numerator);
137 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
138 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
139 // Bail out if the types do not match.
140 Type *Ty = Denominator->getType();
141 if (Ty != StartQ->getType() || Ty != StartR->getType() ||
142 Ty != StepQ->getType() || Ty != StepR->getType())
143 return cannotDivide(Numerator);
144 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
145 Numerator->getNoWrapFlags());
146 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
147 Numerator->getNoWrapFlags());
148 }
149
visitAddExpr(const SCEVAddExpr * Numerator)150 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
151 SmallVector<const SCEV *, 2> Qs, Rs;
152 Type *Ty = Denominator->getType();
153
154 for (const SCEV *Op : Numerator->operands()) {
155 const SCEV *Q, *R;
156 divide(SE, Op, Denominator, &Q, &R);
157
158 // Bail out if types do not match.
159 if (Ty != Q->getType() || Ty != R->getType())
160 return cannotDivide(Numerator);
161
162 Qs.push_back(Q);
163 Rs.push_back(R);
164 }
165
166 if (Qs.size() == 1) {
167 Quotient = Qs[0];
168 Remainder = Rs[0];
169 return;
170 }
171
172 Quotient = SE.getAddExpr(Qs);
173 Remainder = SE.getAddExpr(Rs);
174 }
175
visitMulExpr(const SCEVMulExpr * Numerator)176 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
177 SmallVector<const SCEV *, 2> Qs;
178 Type *Ty = Denominator->getType();
179
180 bool FoundDenominatorTerm = false;
181 for (const SCEV *Op : Numerator->operands()) {
182 // Bail out if types do not match.
183 if (Ty != Op->getType())
184 return cannotDivide(Numerator);
185
186 if (FoundDenominatorTerm) {
187 Qs.push_back(Op);
188 continue;
189 }
190
191 // Check whether Denominator divides one of the product operands.
192 const SCEV *Q, *R;
193 divide(SE, Op, Denominator, &Q, &R);
194 if (!R->isZero()) {
195 Qs.push_back(Op);
196 continue;
197 }
198
199 // Bail out if types do not match.
200 if (Ty != Q->getType())
201 return cannotDivide(Numerator);
202
203 FoundDenominatorTerm = true;
204 Qs.push_back(Q);
205 }
206
207 if (FoundDenominatorTerm) {
208 Remainder = Zero;
209 if (Qs.size() == 1)
210 Quotient = Qs[0];
211 else
212 Quotient = SE.getMulExpr(Qs);
213 return;
214 }
215
216 if (!isa<SCEVUnknown>(Denominator))
217 return cannotDivide(Numerator);
218
219 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
220 ValueToSCEVMapTy RewriteMap;
221 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
222 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
223
224 if (Remainder->isZero()) {
225 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
226 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
227 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
228 return;
229 }
230
231 // Quotient is (Numerator - Remainder) divided by Denominator.
232 const SCEV *Q, *R;
233 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
234 // This SCEV does not seem to simplify: fail the division here.
235 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
236 return cannotDivide(Numerator);
237 divide(SE, Diff, Denominator, &Q, &R);
238 if (R != Zero)
239 return cannotDivide(Numerator);
240 Quotient = Q;
241 }
242
SCEVDivision(ScalarEvolution & S,const SCEV * Numerator,const SCEV * Denominator)243 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
244 const SCEV *Denominator)
245 : SE(S), Denominator(Denominator) {
246 Zero = SE.getZero(Denominator->getType());
247 One = SE.getOne(Denominator->getType());
248
249 // We generally do not know how to divide Expr by Denominator. We initialize
250 // the division to a "cannot divide" state to simplify the rest of the code.
251 cannotDivide(Numerator);
252 }
253
254 // Convenience function for giving up on the division. We set the quotient to
255 // be equal to zero and the remainder to be equal to the numerator.
cannotDivide(const SCEV * Numerator)256 void SCEVDivision::cannotDivide(const SCEV *Numerator) {
257 Quotient = Zero;
258 Remainder = Numerator;
259 }
260