xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionDivision.cpp (revision c9539b89010900499a200cdd6c0265ea5d950875)
//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the class that knows how to divide SCEVs.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/ScalarEvolutionDivision.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Support/Casting.h"
19 #include <cassert>
20 #include <cstdint>
21 
22 namespace llvm {
23 class Type;
24 }
25 
26 using namespace llvm;
27 
28 namespace {
29 
30 static inline int sizeOfSCEV(const SCEV *S) {
31   struct FindSCEVSize {
32     int Size = 0;
33 
34     FindSCEVSize() = default;
35 
36     bool follow(const SCEV *S) {
37       ++Size;
38       // Keep looking at all operands of S.
39       return true;
40     }
41 
42     bool isDone() const { return false; }
43   };
44 
45   FindSCEVSize F;
46   SCEVTraversal<FindSCEVSize> ST(F);
47   ST.visitAll(S);
48   return F.Size;
49 }
50 
51 } // namespace
52 
53 // Computes the Quotient and Remainder of the division of Numerator by
54 // Denominator.
55 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
56                           const SCEV *Denominator, const SCEV **Quotient,
57                           const SCEV **Remainder) {
58   assert(Numerator && Denominator && "Uninitialized SCEV");
59 
60   SCEVDivision D(SE, Numerator, Denominator);
61 
62   // Check for the trivial case here to avoid having to check for it in the
63   // rest of the code.
64   if (Numerator == Denominator) {
65     *Quotient = D.One;
66     *Remainder = D.Zero;
67     return;
68   }
69 
70   if (Numerator->isZero()) {
71     *Quotient = D.Zero;
72     *Remainder = D.Zero;
73     return;
74   }
75 
76   // A simple case when N/1. The quotient is N.
77   if (Denominator->isOne()) {
78     *Quotient = Numerator;
79     *Remainder = D.Zero;
80     return;
81   }
82 
83   // Split the Denominator when it is a product.
84   if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
85     const SCEV *Q, *R;
86     *Quotient = Numerator;
87     for (const SCEV *Op : T->operands()) {
88       divide(SE, *Quotient, Op, &Q, &R);
89       *Quotient = Q;
90 
91       // Bail out when the Numerator is not divisible by one of the terms of
92       // the Denominator.
93       if (!R->isZero()) {
94         *Quotient = D.Zero;
95         *Remainder = Numerator;
96         return;
97       }
98     }
99     *Remainder = D.Zero;
100     return;
101   }
102 
103   D.visit(Numerator);
104   *Quotient = D.Quotient;
105   *Remainder = D.Remainder;
106 }
107 
108 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
109   if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
110     APInt NumeratorVal = Numerator->getAPInt();
111     APInt DenominatorVal = D->getAPInt();
112     uint32_t NumeratorBW = NumeratorVal.getBitWidth();
113     uint32_t DenominatorBW = DenominatorVal.getBitWidth();
114 
115     if (NumeratorBW > DenominatorBW)
116       DenominatorVal = DenominatorVal.sext(NumeratorBW);
117     else if (NumeratorBW < DenominatorBW)
118       NumeratorVal = NumeratorVal.sext(DenominatorBW);
119 
120     APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
121     APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
122     APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
123     Quotient = SE.getConstant(QuotientVal);
124     Remainder = SE.getConstant(RemainderVal);
125     return;
126   }
127 }
128 
129 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
130   const SCEV *StartQ, *StartR, *StepQ, *StepR;
131   if (!Numerator->isAffine())
132     return cannotDivide(Numerator);
133   divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
134   divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
135   // Bail out if the types do not match.
136   Type *Ty = Denominator->getType();
137   if (Ty != StartQ->getType() || Ty != StartR->getType() ||
138       Ty != StepQ->getType() || Ty != StepR->getType())
139     return cannotDivide(Numerator);
140   Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
141                               Numerator->getNoWrapFlags());
142   Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
143                                Numerator->getNoWrapFlags());
144 }
145 
146 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
147   SmallVector<const SCEV *, 2> Qs, Rs;
148   Type *Ty = Denominator->getType();
149 
150   for (const SCEV *Op : Numerator->operands()) {
151     const SCEV *Q, *R;
152     divide(SE, Op, Denominator, &Q, &R);
153 
154     // Bail out if types do not match.
155     if (Ty != Q->getType() || Ty != R->getType())
156       return cannotDivide(Numerator);
157 
158     Qs.push_back(Q);
159     Rs.push_back(R);
160   }
161 
162   if (Qs.size() == 1) {
163     Quotient = Qs[0];
164     Remainder = Rs[0];
165     return;
166   }
167 
168   Quotient = SE.getAddExpr(Qs);
169   Remainder = SE.getAddExpr(Rs);
170 }
171 
172 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
173   SmallVector<const SCEV *, 2> Qs;
174   Type *Ty = Denominator->getType();
175 
176   bool FoundDenominatorTerm = false;
177   for (const SCEV *Op : Numerator->operands()) {
178     // Bail out if types do not match.
179     if (Ty != Op->getType())
180       return cannotDivide(Numerator);
181 
182     if (FoundDenominatorTerm) {
183       Qs.push_back(Op);
184       continue;
185     }
186 
187     // Check whether Denominator divides one of the product operands.
188     const SCEV *Q, *R;
189     divide(SE, Op, Denominator, &Q, &R);
190     if (!R->isZero()) {
191       Qs.push_back(Op);
192       continue;
193     }
194 
195     // Bail out if types do not match.
196     if (Ty != Q->getType())
197       return cannotDivide(Numerator);
198 
199     FoundDenominatorTerm = true;
200     Qs.push_back(Q);
201   }
202 
203   if (FoundDenominatorTerm) {
204     Remainder = Zero;
205     if (Qs.size() == 1)
206       Quotient = Qs[0];
207     else
208       Quotient = SE.getMulExpr(Qs);
209     return;
210   }
211 
212   if (!isa<SCEVUnknown>(Denominator))
213     return cannotDivide(Numerator);
214 
215   // The Remainder is obtained by replacing Denominator by 0 in Numerator.
216   ValueToSCEVMapTy RewriteMap;
217   RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
218   Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
219 
220   if (Remainder->isZero()) {
221     // The Quotient is obtained by replacing Denominator by 1 in Numerator.
222     RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
223     Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
224     return;
225   }
226 
227   // Quotient is (Numerator - Remainder) divided by Denominator.
228   const SCEV *Q, *R;
229   const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
230   // This SCEV does not seem to simplify: fail the division here.
231   if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
232     return cannotDivide(Numerator);
233   divide(SE, Diff, Denominator, &Q, &R);
234   if (R != Zero)
235     return cannotDivide(Numerator);
236   Quotient = Q;
237 }
238 
239 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
240                            const SCEV *Denominator)
241     : SE(S), Denominator(Denominator) {
242   Zero = SE.getZero(Denominator->getType());
243   One = SE.getOne(Denominator->getType());
244 
245   // We generally do not know how to divide Expr by Denominator. We initialize
246   // the division to a "cannot divide" state to simplify the rest of the code.
247   cannotDivide(Numerator);
248 }
249 
250 // Convenience function for giving up on the division. We set the quotient to
251 // be equal to zero and the remainder to be equal to the numerator.
252 void SCEVDivision::cannotDivide(const SCEV *Numerator) {
253   Quotient = Zero;
254   Remainder = Numerator;
255 }
256