1*5ffd83dbSDimitry Andric //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===// 2*5ffd83dbSDimitry Andric // 3*5ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*5ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*5ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*5ffd83dbSDimitry Andric // 7*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 8*5ffd83dbSDimitry Andric // 9*5ffd83dbSDimitry Andric // This file defines the class that knows how to divide SCEV's. 10*5ffd83dbSDimitry Andric // 11*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 12*5ffd83dbSDimitry Andric 13*5ffd83dbSDimitry Andric #include "llvm/Analysis/ScalarEvolutionDivision.h" 14*5ffd83dbSDimitry Andric #include "llvm/ADT/APInt.h" 15*5ffd83dbSDimitry Andric #include "llvm/ADT/DenseMap.h" 16*5ffd83dbSDimitry Andric #include "llvm/ADT/SmallVector.h" 17*5ffd83dbSDimitry Andric #include "llvm/Analysis/ScalarEvolution.h" 18*5ffd83dbSDimitry Andric #include "llvm/IR/Constants.h" 19*5ffd83dbSDimitry Andric #include "llvm/Support/Casting.h" 20*5ffd83dbSDimitry Andric #include "llvm/Support/ErrorHandling.h" 21*5ffd83dbSDimitry Andric #include <cassert> 22*5ffd83dbSDimitry Andric #include <cstdint> 23*5ffd83dbSDimitry Andric 24*5ffd83dbSDimitry Andric namespace llvm { 25*5ffd83dbSDimitry Andric class Type; 26*5ffd83dbSDimitry Andric } 27*5ffd83dbSDimitry Andric 28*5ffd83dbSDimitry Andric using namespace llvm; 29*5ffd83dbSDimitry Andric 30*5ffd83dbSDimitry Andric namespace { 31*5ffd83dbSDimitry Andric 32*5ffd83dbSDimitry Andric static inline int sizeOfSCEV(const SCEV *S) { 33*5ffd83dbSDimitry Andric struct FindSCEVSize { 34*5ffd83dbSDimitry Andric int Size = 0; 35*5ffd83dbSDimitry Andric 36*5ffd83dbSDimitry Andric FindSCEVSize() = default; 37*5ffd83dbSDimitry Andric 38*5ffd83dbSDimitry Andric bool follow(const SCEV *S) { 39*5ffd83dbSDimitry Andric ++Size; 40*5ffd83dbSDimitry Andric // Keep looking at all operands of S. 41*5ffd83dbSDimitry Andric return true; 42*5ffd83dbSDimitry Andric } 43*5ffd83dbSDimitry Andric 44*5ffd83dbSDimitry Andric bool isDone() const { return false; } 45*5ffd83dbSDimitry Andric }; 46*5ffd83dbSDimitry Andric 47*5ffd83dbSDimitry Andric FindSCEVSize F; 48*5ffd83dbSDimitry Andric SCEVTraversal<FindSCEVSize> ST(F); 49*5ffd83dbSDimitry Andric ST.visitAll(S); 50*5ffd83dbSDimitry Andric return F.Size; 51*5ffd83dbSDimitry Andric } 52*5ffd83dbSDimitry Andric 53*5ffd83dbSDimitry Andric } // namespace 54*5ffd83dbSDimitry Andric 55*5ffd83dbSDimitry Andric // Computes the Quotient and Remainder of the division of Numerator by 56*5ffd83dbSDimitry Andric // Denominator. 57*5ffd83dbSDimitry Andric void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, 58*5ffd83dbSDimitry Andric const SCEV *Denominator, const SCEV **Quotient, 59*5ffd83dbSDimitry Andric const SCEV **Remainder) { 60*5ffd83dbSDimitry Andric assert(Numerator && Denominator && "Uninitialized SCEV"); 61*5ffd83dbSDimitry Andric 62*5ffd83dbSDimitry Andric SCEVDivision D(SE, Numerator, Denominator); 63*5ffd83dbSDimitry Andric 64*5ffd83dbSDimitry Andric // Check for the trivial case here to avoid having to check for it in the 65*5ffd83dbSDimitry Andric // rest of the code. 66*5ffd83dbSDimitry Andric if (Numerator == Denominator) { 67*5ffd83dbSDimitry Andric *Quotient = D.One; 68*5ffd83dbSDimitry Andric *Remainder = D.Zero; 69*5ffd83dbSDimitry Andric return; 70*5ffd83dbSDimitry Andric } 71*5ffd83dbSDimitry Andric 72*5ffd83dbSDimitry Andric if (Numerator->isZero()) { 73*5ffd83dbSDimitry Andric *Quotient = D.Zero; 74*5ffd83dbSDimitry Andric *Remainder = D.Zero; 75*5ffd83dbSDimitry Andric return; 76*5ffd83dbSDimitry Andric } 77*5ffd83dbSDimitry Andric 78*5ffd83dbSDimitry Andric // A simple case when N/1. The quotient is N. 79*5ffd83dbSDimitry Andric if (Denominator->isOne()) { 80*5ffd83dbSDimitry Andric *Quotient = Numerator; 81*5ffd83dbSDimitry Andric *Remainder = D.Zero; 82*5ffd83dbSDimitry Andric return; 83*5ffd83dbSDimitry Andric } 84*5ffd83dbSDimitry Andric 85*5ffd83dbSDimitry Andric // Split the Denominator when it is a product. 86*5ffd83dbSDimitry Andric if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { 87*5ffd83dbSDimitry Andric const SCEV *Q, *R; 88*5ffd83dbSDimitry Andric *Quotient = Numerator; 89*5ffd83dbSDimitry Andric for (const SCEV *Op : T->operands()) { 90*5ffd83dbSDimitry Andric divide(SE, *Quotient, Op, &Q, &R); 91*5ffd83dbSDimitry Andric *Quotient = Q; 92*5ffd83dbSDimitry Andric 93*5ffd83dbSDimitry Andric // Bail out when the Numerator is not divisible by one of the terms of 94*5ffd83dbSDimitry Andric // the Denominator. 95*5ffd83dbSDimitry Andric if (!R->isZero()) { 96*5ffd83dbSDimitry Andric *Quotient = D.Zero; 97*5ffd83dbSDimitry Andric *Remainder = Numerator; 98*5ffd83dbSDimitry Andric return; 99*5ffd83dbSDimitry Andric } 100*5ffd83dbSDimitry Andric } 101*5ffd83dbSDimitry Andric *Remainder = D.Zero; 102*5ffd83dbSDimitry Andric return; 103*5ffd83dbSDimitry Andric } 104*5ffd83dbSDimitry Andric 105*5ffd83dbSDimitry Andric D.visit(Numerator); 106*5ffd83dbSDimitry Andric *Quotient = D.Quotient; 107*5ffd83dbSDimitry Andric *Remainder = D.Remainder; 108*5ffd83dbSDimitry Andric } 109*5ffd83dbSDimitry Andric 110*5ffd83dbSDimitry Andric void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { 111*5ffd83dbSDimitry Andric if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { 112*5ffd83dbSDimitry Andric APInt NumeratorVal = Numerator->getAPInt(); 113*5ffd83dbSDimitry Andric APInt DenominatorVal = D->getAPInt(); 114*5ffd83dbSDimitry Andric uint32_t NumeratorBW = NumeratorVal.getBitWidth(); 115*5ffd83dbSDimitry Andric uint32_t DenominatorBW = DenominatorVal.getBitWidth(); 116*5ffd83dbSDimitry Andric 117*5ffd83dbSDimitry Andric if (NumeratorBW > DenominatorBW) 118*5ffd83dbSDimitry Andric DenominatorVal = DenominatorVal.sext(NumeratorBW); 119*5ffd83dbSDimitry Andric else if (NumeratorBW < DenominatorBW) 120*5ffd83dbSDimitry Andric NumeratorVal = NumeratorVal.sext(DenominatorBW); 121*5ffd83dbSDimitry Andric 122*5ffd83dbSDimitry Andric APInt QuotientVal(NumeratorVal.getBitWidth(), 0); 123*5ffd83dbSDimitry Andric APInt RemainderVal(NumeratorVal.getBitWidth(), 0); 124*5ffd83dbSDimitry Andric APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); 125*5ffd83dbSDimitry Andric Quotient = SE.getConstant(QuotientVal); 126*5ffd83dbSDimitry Andric Remainder = SE.getConstant(RemainderVal); 127*5ffd83dbSDimitry Andric return; 128*5ffd83dbSDimitry Andric } 129*5ffd83dbSDimitry Andric } 130*5ffd83dbSDimitry Andric 131*5ffd83dbSDimitry Andric void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { 132*5ffd83dbSDimitry Andric const SCEV *StartQ, *StartR, *StepQ, *StepR; 133*5ffd83dbSDimitry Andric if (!Numerator->isAffine()) 134*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 135*5ffd83dbSDimitry Andric divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); 136*5ffd83dbSDimitry Andric divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); 137*5ffd83dbSDimitry Andric // Bail out if the types do not match. 138*5ffd83dbSDimitry Andric Type *Ty = Denominator->getType(); 139*5ffd83dbSDimitry Andric if (Ty != StartQ->getType() || Ty != StartR->getType() || 140*5ffd83dbSDimitry Andric Ty != StepQ->getType() || Ty != StepR->getType()) 141*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 142*5ffd83dbSDimitry Andric Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), 143*5ffd83dbSDimitry Andric Numerator->getNoWrapFlags()); 144*5ffd83dbSDimitry Andric Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), 145*5ffd83dbSDimitry Andric Numerator->getNoWrapFlags()); 146*5ffd83dbSDimitry Andric } 147*5ffd83dbSDimitry Andric 148*5ffd83dbSDimitry Andric void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { 149*5ffd83dbSDimitry Andric SmallVector<const SCEV *, 2> Qs, Rs; 150*5ffd83dbSDimitry Andric Type *Ty = Denominator->getType(); 151*5ffd83dbSDimitry Andric 152*5ffd83dbSDimitry Andric for (const SCEV *Op : Numerator->operands()) { 153*5ffd83dbSDimitry Andric const SCEV *Q, *R; 154*5ffd83dbSDimitry Andric divide(SE, Op, Denominator, &Q, &R); 155*5ffd83dbSDimitry Andric 156*5ffd83dbSDimitry Andric // Bail out if types do not match. 157*5ffd83dbSDimitry Andric if (Ty != Q->getType() || Ty != R->getType()) 158*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 159*5ffd83dbSDimitry Andric 160*5ffd83dbSDimitry Andric Qs.push_back(Q); 161*5ffd83dbSDimitry Andric Rs.push_back(R); 162*5ffd83dbSDimitry Andric } 163*5ffd83dbSDimitry Andric 164*5ffd83dbSDimitry Andric if (Qs.size() == 1) { 165*5ffd83dbSDimitry Andric Quotient = Qs[0]; 166*5ffd83dbSDimitry Andric Remainder = Rs[0]; 167*5ffd83dbSDimitry Andric return; 168*5ffd83dbSDimitry Andric } 169*5ffd83dbSDimitry Andric 170*5ffd83dbSDimitry Andric Quotient = SE.getAddExpr(Qs); 171*5ffd83dbSDimitry Andric Remainder = SE.getAddExpr(Rs); 172*5ffd83dbSDimitry Andric } 173*5ffd83dbSDimitry Andric 174*5ffd83dbSDimitry Andric void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { 175*5ffd83dbSDimitry Andric SmallVector<const SCEV *, 2> Qs; 176*5ffd83dbSDimitry Andric Type *Ty = Denominator->getType(); 177*5ffd83dbSDimitry Andric 178*5ffd83dbSDimitry Andric bool FoundDenominatorTerm = false; 179*5ffd83dbSDimitry Andric for (const SCEV *Op : Numerator->operands()) { 180*5ffd83dbSDimitry Andric // Bail out if types do not match. 181*5ffd83dbSDimitry Andric if (Ty != Op->getType()) 182*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 183*5ffd83dbSDimitry Andric 184*5ffd83dbSDimitry Andric if (FoundDenominatorTerm) { 185*5ffd83dbSDimitry Andric Qs.push_back(Op); 186*5ffd83dbSDimitry Andric continue; 187*5ffd83dbSDimitry Andric } 188*5ffd83dbSDimitry Andric 189*5ffd83dbSDimitry Andric // Check whether Denominator divides one of the product operands. 190*5ffd83dbSDimitry Andric const SCEV *Q, *R; 191*5ffd83dbSDimitry Andric divide(SE, Op, Denominator, &Q, &R); 192*5ffd83dbSDimitry Andric if (!R->isZero()) { 193*5ffd83dbSDimitry Andric Qs.push_back(Op); 194*5ffd83dbSDimitry Andric continue; 195*5ffd83dbSDimitry Andric } 196*5ffd83dbSDimitry Andric 197*5ffd83dbSDimitry Andric // Bail out if types do not match. 198*5ffd83dbSDimitry Andric if (Ty != Q->getType()) 199*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 200*5ffd83dbSDimitry Andric 201*5ffd83dbSDimitry Andric FoundDenominatorTerm = true; 202*5ffd83dbSDimitry Andric Qs.push_back(Q); 203*5ffd83dbSDimitry Andric } 204*5ffd83dbSDimitry Andric 205*5ffd83dbSDimitry Andric if (FoundDenominatorTerm) { 206*5ffd83dbSDimitry Andric Remainder = Zero; 207*5ffd83dbSDimitry Andric if (Qs.size() == 1) 208*5ffd83dbSDimitry Andric Quotient = Qs[0]; 209*5ffd83dbSDimitry Andric else 210*5ffd83dbSDimitry Andric Quotient = SE.getMulExpr(Qs); 211*5ffd83dbSDimitry Andric return; 212*5ffd83dbSDimitry Andric } 213*5ffd83dbSDimitry Andric 214*5ffd83dbSDimitry Andric if (!isa<SCEVUnknown>(Denominator)) 215*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 216*5ffd83dbSDimitry Andric 217*5ffd83dbSDimitry Andric // The Remainder is obtained by replacing Denominator by 0 in Numerator. 218*5ffd83dbSDimitry Andric ValueToValueMap RewriteMap; 219*5ffd83dbSDimitry Andric RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = 220*5ffd83dbSDimitry Andric cast<SCEVConstant>(Zero)->getValue(); 221*5ffd83dbSDimitry Andric Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); 222*5ffd83dbSDimitry Andric 223*5ffd83dbSDimitry Andric if (Remainder->isZero()) { 224*5ffd83dbSDimitry Andric // The Quotient is obtained by replacing Denominator by 1 in Numerator. 225*5ffd83dbSDimitry Andric RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = 226*5ffd83dbSDimitry Andric cast<SCEVConstant>(One)->getValue(); 227*5ffd83dbSDimitry Andric Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); 228*5ffd83dbSDimitry Andric return; 229*5ffd83dbSDimitry Andric } 230*5ffd83dbSDimitry Andric 231*5ffd83dbSDimitry Andric // Quotient is (Numerator - Remainder) divided by Denominator. 232*5ffd83dbSDimitry Andric const SCEV *Q, *R; 233*5ffd83dbSDimitry Andric const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); 234*5ffd83dbSDimitry Andric // This SCEV does not seem to simplify: fail the division here. 235*5ffd83dbSDimitry Andric if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) 236*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 237*5ffd83dbSDimitry Andric divide(SE, Diff, Denominator, &Q, &R); 238*5ffd83dbSDimitry Andric if (R != Zero) 239*5ffd83dbSDimitry Andric return cannotDivide(Numerator); 240*5ffd83dbSDimitry Andric Quotient = Q; 241*5ffd83dbSDimitry Andric } 242*5ffd83dbSDimitry Andric 243*5ffd83dbSDimitry Andric SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, 244*5ffd83dbSDimitry Andric const SCEV *Denominator) 245*5ffd83dbSDimitry Andric : SE(S), Denominator(Denominator) { 246*5ffd83dbSDimitry Andric Zero = SE.getZero(Denominator->getType()); 247*5ffd83dbSDimitry Andric One = SE.getOne(Denominator->getType()); 248*5ffd83dbSDimitry Andric 249*5ffd83dbSDimitry Andric // We generally do not know how to divide Expr by Denominator. We initialize 250*5ffd83dbSDimitry Andric // the division to a "cannot divide" state to simplify the rest of the code. 251*5ffd83dbSDimitry Andric cannotDivide(Numerator); 252*5ffd83dbSDimitry Andric } 253*5ffd83dbSDimitry Andric 254*5ffd83dbSDimitry Andric // Convenience function for giving up on the division. We set the quotient to 255*5ffd83dbSDimitry Andric // be equal to zero and the remainder to be equal to the numerator. 256*5ffd83dbSDimitry Andric void SCEVDivision::cannotDivide(const SCEV *Numerator) { 257*5ffd83dbSDimitry Andric Quotient = Zero; 258*5ffd83dbSDimitry Andric Remainder = Numerator; 259*5ffd83dbSDimitry Andric } 260