xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionDivision.cpp (revision 5ffd83dbcc34f10e07f6d3e968ae6365869615f4)
//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2*5ffd83dbSDimitry Andric //
3*5ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*5ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*5ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*5ffd83dbSDimitry Andric //
7*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===//
8*5ffd83dbSDimitry Andric //
9*5ffd83dbSDimitry Andric // This file defines the class that knows how to divide SCEV's.
10*5ffd83dbSDimitry Andric //
11*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===//
12*5ffd83dbSDimitry Andric 
13*5ffd83dbSDimitry Andric #include "llvm/Analysis/ScalarEvolutionDivision.h"
14*5ffd83dbSDimitry Andric #include "llvm/ADT/APInt.h"
15*5ffd83dbSDimitry Andric #include "llvm/ADT/DenseMap.h"
16*5ffd83dbSDimitry Andric #include "llvm/ADT/SmallVector.h"
17*5ffd83dbSDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
18*5ffd83dbSDimitry Andric #include "llvm/IR/Constants.h"
19*5ffd83dbSDimitry Andric #include "llvm/Support/Casting.h"
20*5ffd83dbSDimitry Andric #include "llvm/Support/ErrorHandling.h"
21*5ffd83dbSDimitry Andric #include <cassert>
22*5ffd83dbSDimitry Andric #include <cstdint>
23*5ffd83dbSDimitry Andric 
24*5ffd83dbSDimitry Andric namespace llvm {
25*5ffd83dbSDimitry Andric class Type;
26*5ffd83dbSDimitry Andric }
27*5ffd83dbSDimitry Andric 
28*5ffd83dbSDimitry Andric using namespace llvm;
29*5ffd83dbSDimitry Andric 
30*5ffd83dbSDimitry Andric namespace {
31*5ffd83dbSDimitry Andric 
32*5ffd83dbSDimitry Andric static inline int sizeOfSCEV(const SCEV *S) {
33*5ffd83dbSDimitry Andric   struct FindSCEVSize {
34*5ffd83dbSDimitry Andric     int Size = 0;
35*5ffd83dbSDimitry Andric 
36*5ffd83dbSDimitry Andric     FindSCEVSize() = default;
37*5ffd83dbSDimitry Andric 
38*5ffd83dbSDimitry Andric     bool follow(const SCEV *S) {
39*5ffd83dbSDimitry Andric       ++Size;
40*5ffd83dbSDimitry Andric       // Keep looking at all operands of S.
41*5ffd83dbSDimitry Andric       return true;
42*5ffd83dbSDimitry Andric     }
43*5ffd83dbSDimitry Andric 
44*5ffd83dbSDimitry Andric     bool isDone() const { return false; }
45*5ffd83dbSDimitry Andric   };
46*5ffd83dbSDimitry Andric 
47*5ffd83dbSDimitry Andric   FindSCEVSize F;
48*5ffd83dbSDimitry Andric   SCEVTraversal<FindSCEVSize> ST(F);
49*5ffd83dbSDimitry Andric   ST.visitAll(S);
50*5ffd83dbSDimitry Andric   return F.Size;
51*5ffd83dbSDimitry Andric }
52*5ffd83dbSDimitry Andric 
53*5ffd83dbSDimitry Andric } // namespace
54*5ffd83dbSDimitry Andric 
55*5ffd83dbSDimitry Andric // Computes the Quotient and Remainder of the division of Numerator by
56*5ffd83dbSDimitry Andric // Denominator.
57*5ffd83dbSDimitry Andric void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
58*5ffd83dbSDimitry Andric                           const SCEV *Denominator, const SCEV **Quotient,
59*5ffd83dbSDimitry Andric                           const SCEV **Remainder) {
60*5ffd83dbSDimitry Andric   assert(Numerator && Denominator && "Uninitialized SCEV");
61*5ffd83dbSDimitry Andric 
62*5ffd83dbSDimitry Andric   SCEVDivision D(SE, Numerator, Denominator);
63*5ffd83dbSDimitry Andric 
64*5ffd83dbSDimitry Andric   // Check for the trivial case here to avoid having to check for it in the
65*5ffd83dbSDimitry Andric   // rest of the code.
66*5ffd83dbSDimitry Andric   if (Numerator == Denominator) {
67*5ffd83dbSDimitry Andric     *Quotient = D.One;
68*5ffd83dbSDimitry Andric     *Remainder = D.Zero;
69*5ffd83dbSDimitry Andric     return;
70*5ffd83dbSDimitry Andric   }
71*5ffd83dbSDimitry Andric 
72*5ffd83dbSDimitry Andric   if (Numerator->isZero()) {
73*5ffd83dbSDimitry Andric     *Quotient = D.Zero;
74*5ffd83dbSDimitry Andric     *Remainder = D.Zero;
75*5ffd83dbSDimitry Andric     return;
76*5ffd83dbSDimitry Andric   }
77*5ffd83dbSDimitry Andric 
78*5ffd83dbSDimitry Andric   // A simple case when N/1. The quotient is N.
79*5ffd83dbSDimitry Andric   if (Denominator->isOne()) {
80*5ffd83dbSDimitry Andric     *Quotient = Numerator;
81*5ffd83dbSDimitry Andric     *Remainder = D.Zero;
82*5ffd83dbSDimitry Andric     return;
83*5ffd83dbSDimitry Andric   }
84*5ffd83dbSDimitry Andric 
85*5ffd83dbSDimitry Andric   // Split the Denominator when it is a product.
86*5ffd83dbSDimitry Andric   if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
87*5ffd83dbSDimitry Andric     const SCEV *Q, *R;
88*5ffd83dbSDimitry Andric     *Quotient = Numerator;
89*5ffd83dbSDimitry Andric     for (const SCEV *Op : T->operands()) {
90*5ffd83dbSDimitry Andric       divide(SE, *Quotient, Op, &Q, &R);
91*5ffd83dbSDimitry Andric       *Quotient = Q;
92*5ffd83dbSDimitry Andric 
93*5ffd83dbSDimitry Andric       // Bail out when the Numerator is not divisible by one of the terms of
94*5ffd83dbSDimitry Andric       // the Denominator.
95*5ffd83dbSDimitry Andric       if (!R->isZero()) {
96*5ffd83dbSDimitry Andric         *Quotient = D.Zero;
97*5ffd83dbSDimitry Andric         *Remainder = Numerator;
98*5ffd83dbSDimitry Andric         return;
99*5ffd83dbSDimitry Andric       }
100*5ffd83dbSDimitry Andric     }
101*5ffd83dbSDimitry Andric     *Remainder = D.Zero;
102*5ffd83dbSDimitry Andric     return;
103*5ffd83dbSDimitry Andric   }
104*5ffd83dbSDimitry Andric 
105*5ffd83dbSDimitry Andric   D.visit(Numerator);
106*5ffd83dbSDimitry Andric   *Quotient = D.Quotient;
107*5ffd83dbSDimitry Andric   *Remainder = D.Remainder;
108*5ffd83dbSDimitry Andric }
109*5ffd83dbSDimitry Andric 
110*5ffd83dbSDimitry Andric void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
111*5ffd83dbSDimitry Andric   if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
112*5ffd83dbSDimitry Andric     APInt NumeratorVal = Numerator->getAPInt();
113*5ffd83dbSDimitry Andric     APInt DenominatorVal = D->getAPInt();
114*5ffd83dbSDimitry Andric     uint32_t NumeratorBW = NumeratorVal.getBitWidth();
115*5ffd83dbSDimitry Andric     uint32_t DenominatorBW = DenominatorVal.getBitWidth();
116*5ffd83dbSDimitry Andric 
117*5ffd83dbSDimitry Andric     if (NumeratorBW > DenominatorBW)
118*5ffd83dbSDimitry Andric       DenominatorVal = DenominatorVal.sext(NumeratorBW);
119*5ffd83dbSDimitry Andric     else if (NumeratorBW < DenominatorBW)
120*5ffd83dbSDimitry Andric       NumeratorVal = NumeratorVal.sext(DenominatorBW);
121*5ffd83dbSDimitry Andric 
122*5ffd83dbSDimitry Andric     APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
123*5ffd83dbSDimitry Andric     APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
124*5ffd83dbSDimitry Andric     APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
125*5ffd83dbSDimitry Andric     Quotient = SE.getConstant(QuotientVal);
126*5ffd83dbSDimitry Andric     Remainder = SE.getConstant(RemainderVal);
127*5ffd83dbSDimitry Andric     return;
128*5ffd83dbSDimitry Andric   }
129*5ffd83dbSDimitry Andric }
130*5ffd83dbSDimitry Andric 
131*5ffd83dbSDimitry Andric void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
132*5ffd83dbSDimitry Andric   const SCEV *StartQ, *StartR, *StepQ, *StepR;
133*5ffd83dbSDimitry Andric   if (!Numerator->isAffine())
134*5ffd83dbSDimitry Andric     return cannotDivide(Numerator);
135*5ffd83dbSDimitry Andric   divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
136*5ffd83dbSDimitry Andric   divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
137*5ffd83dbSDimitry Andric   // Bail out if the types do not match.
138*5ffd83dbSDimitry Andric   Type *Ty = Denominator->getType();
139*5ffd83dbSDimitry Andric   if (Ty != StartQ->getType() || Ty != StartR->getType() ||
140*5ffd83dbSDimitry Andric       Ty != StepQ->getType() || Ty != StepR->getType())
141*5ffd83dbSDimitry Andric     return cannotDivide(Numerator);
142*5ffd83dbSDimitry Andric   Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
143*5ffd83dbSDimitry Andric                               Numerator->getNoWrapFlags());
144*5ffd83dbSDimitry Andric   Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
145*5ffd83dbSDimitry Andric                                Numerator->getNoWrapFlags());
146*5ffd83dbSDimitry Andric }
147*5ffd83dbSDimitry Andric 
148*5ffd83dbSDimitry Andric void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
149*5ffd83dbSDimitry Andric   SmallVector<const SCEV *, 2> Qs, Rs;
150*5ffd83dbSDimitry Andric   Type *Ty = Denominator->getType();
151*5ffd83dbSDimitry Andric 
152*5ffd83dbSDimitry Andric   for (const SCEV *Op : Numerator->operands()) {
153*5ffd83dbSDimitry Andric     const SCEV *Q, *R;
154*5ffd83dbSDimitry Andric     divide(SE, Op, Denominator, &Q, &R);
155*5ffd83dbSDimitry Andric 
156*5ffd83dbSDimitry Andric     // Bail out if types do not match.
157*5ffd83dbSDimitry Andric     if (Ty != Q->getType() || Ty != R->getType())
158*5ffd83dbSDimitry Andric       return cannotDivide(Numerator);
159*5ffd83dbSDimitry Andric 
160*5ffd83dbSDimitry Andric     Qs.push_back(Q);
161*5ffd83dbSDimitry Andric     Rs.push_back(R);
162*5ffd83dbSDimitry Andric   }
163*5ffd83dbSDimitry Andric 
164*5ffd83dbSDimitry Andric   if (Qs.size() == 1) {
165*5ffd83dbSDimitry Andric     Quotient = Qs[0];
166*5ffd83dbSDimitry Andric     Remainder = Rs[0];
167*5ffd83dbSDimitry Andric     return;
168*5ffd83dbSDimitry Andric   }
169*5ffd83dbSDimitry Andric 
170*5ffd83dbSDimitry Andric   Quotient = SE.getAddExpr(Qs);
171*5ffd83dbSDimitry Andric   Remainder = SE.getAddExpr(Rs);
172*5ffd83dbSDimitry Andric }
173*5ffd83dbSDimitry Andric 
174*5ffd83dbSDimitry Andric void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
175*5ffd83dbSDimitry Andric   SmallVector<const SCEV *, 2> Qs;
176*5ffd83dbSDimitry Andric   Type *Ty = Denominator->getType();
177*5ffd83dbSDimitry Andric 
178*5ffd83dbSDimitry Andric   bool FoundDenominatorTerm = false;
179*5ffd83dbSDimitry Andric   for (const SCEV *Op : Numerator->operands()) {
180*5ffd83dbSDimitry Andric     // Bail out if types do not match.
181*5ffd83dbSDimitry Andric     if (Ty != Op->getType())
182*5ffd83dbSDimitry Andric       return cannotDivide(Numerator);
183*5ffd83dbSDimitry Andric 
184*5ffd83dbSDimitry Andric     if (FoundDenominatorTerm) {
185*5ffd83dbSDimitry Andric       Qs.push_back(Op);
186*5ffd83dbSDimitry Andric       continue;
187*5ffd83dbSDimitry Andric     }
188*5ffd83dbSDimitry Andric 
189*5ffd83dbSDimitry Andric     // Check whether Denominator divides one of the product operands.
190*5ffd83dbSDimitry Andric     const SCEV *Q, *R;
191*5ffd83dbSDimitry Andric     divide(SE, Op, Denominator, &Q, &R);
192*5ffd83dbSDimitry Andric     if (!R->isZero()) {
193*5ffd83dbSDimitry Andric       Qs.push_back(Op);
194*5ffd83dbSDimitry Andric       continue;
195*5ffd83dbSDimitry Andric     }
196*5ffd83dbSDimitry Andric 
197*5ffd83dbSDimitry Andric     // Bail out if types do not match.
198*5ffd83dbSDimitry Andric     if (Ty != Q->getType())
199*5ffd83dbSDimitry Andric       return cannotDivide(Numerator);
200*5ffd83dbSDimitry Andric 
201*5ffd83dbSDimitry Andric     FoundDenominatorTerm = true;
202*5ffd83dbSDimitry Andric     Qs.push_back(Q);
203*5ffd83dbSDimitry Andric   }
204*5ffd83dbSDimitry Andric 
205*5ffd83dbSDimitry Andric   if (FoundDenominatorTerm) {
206*5ffd83dbSDimitry Andric     Remainder = Zero;
207*5ffd83dbSDimitry Andric     if (Qs.size() == 1)
208*5ffd83dbSDimitry Andric       Quotient = Qs[0];
209*5ffd83dbSDimitry Andric     else
210*5ffd83dbSDimitry Andric       Quotient = SE.getMulExpr(Qs);
211*5ffd83dbSDimitry Andric     return;
212*5ffd83dbSDimitry Andric   }
213*5ffd83dbSDimitry Andric 
214*5ffd83dbSDimitry Andric   if (!isa<SCEVUnknown>(Denominator))
215*5ffd83dbSDimitry Andric     return cannotDivide(Numerator);
216*5ffd83dbSDimitry Andric 
217*5ffd83dbSDimitry Andric   // The Remainder is obtained by replacing Denominator by 0 in Numerator.
218*5ffd83dbSDimitry Andric   ValueToValueMap RewriteMap;
219*5ffd83dbSDimitry Andric   RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
220*5ffd83dbSDimitry Andric       cast<SCEVConstant>(Zero)->getValue();
221*5ffd83dbSDimitry Andric   Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
222*5ffd83dbSDimitry Andric 
223*5ffd83dbSDimitry Andric   if (Remainder->isZero()) {
224*5ffd83dbSDimitry Andric     // The Quotient is obtained by replacing Denominator by 1 in Numerator.
225*5ffd83dbSDimitry Andric     RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
226*5ffd83dbSDimitry Andric         cast<SCEVConstant>(One)->getValue();
227*5ffd83dbSDimitry Andric     Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
228*5ffd83dbSDimitry Andric     return;
229*5ffd83dbSDimitry Andric   }
230*5ffd83dbSDimitry Andric 
231*5ffd83dbSDimitry Andric   // Quotient is (Numerator - Remainder) divided by Denominator.
232*5ffd83dbSDimitry Andric   const SCEV *Q, *R;
233*5ffd83dbSDimitry Andric   const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
234*5ffd83dbSDimitry Andric   // This SCEV does not seem to simplify: fail the division here.
235*5ffd83dbSDimitry Andric   if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
236*5ffd83dbSDimitry Andric     return cannotDivide(Numerator);
237*5ffd83dbSDimitry Andric   divide(SE, Diff, Denominator, &Q, &R);
238*5ffd83dbSDimitry Andric   if (R != Zero)
239*5ffd83dbSDimitry Andric     return cannotDivide(Numerator);
240*5ffd83dbSDimitry Andric   Quotient = Q;
241*5ffd83dbSDimitry Andric }
242*5ffd83dbSDimitry Andric 
243*5ffd83dbSDimitry Andric SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
244*5ffd83dbSDimitry Andric                            const SCEV *Denominator)
245*5ffd83dbSDimitry Andric     : SE(S), Denominator(Denominator) {
246*5ffd83dbSDimitry Andric   Zero = SE.getZero(Denominator->getType());
247*5ffd83dbSDimitry Andric   One = SE.getOne(Denominator->getType());
248*5ffd83dbSDimitry Andric 
249*5ffd83dbSDimitry Andric   // We generally do not know how to divide Expr by Denominator. We initialize
250*5ffd83dbSDimitry Andric   // the division to a "cannot divide" state to simplify the rest of the code.
251*5ffd83dbSDimitry Andric   cannotDivide(Numerator);
252*5ffd83dbSDimitry Andric }
253*5ffd83dbSDimitry Andric 
254*5ffd83dbSDimitry Andric // Convenience function for giving up on the division. We set the quotient to
255*5ffd83dbSDimitry Andric // be equal to zero and the remainder to be equal to the numerator.
256*5ffd83dbSDimitry Andric void SCEVDivision::cannotDivide(const SCEV *Numerator) {
257*5ffd83dbSDimitry Andric   Quotient = Zero;
258*5ffd83dbSDimitry Andric   Remainder = Numerator;
259*5ffd83dbSDimitry Andric }
260