15ffd83dbSDimitry Andric //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===// 25ffd83dbSDimitry Andric // 35ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 55ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65ffd83dbSDimitry Andric // 75ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 85ffd83dbSDimitry Andric // 95ffd83dbSDimitry Andric // This file defines the class that knows how to divide SCEV's. 105ffd83dbSDimitry Andric // 115ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 125ffd83dbSDimitry Andric 135ffd83dbSDimitry Andric #include "llvm/Analysis/ScalarEvolutionDivision.h" 145ffd83dbSDimitry Andric #include "llvm/ADT/APInt.h" 155ffd83dbSDimitry Andric #include "llvm/ADT/DenseMap.h" 165ffd83dbSDimitry Andric #include "llvm/ADT/SmallVector.h" 175ffd83dbSDimitry Andric #include "llvm/Analysis/ScalarEvolution.h" 185ffd83dbSDimitry Andric #include "llvm/Support/Casting.h" 195ffd83dbSDimitry Andric #include <cassert> 205ffd83dbSDimitry Andric #include <cstdint> 215ffd83dbSDimitry Andric 225ffd83dbSDimitry Andric namespace llvm { 235ffd83dbSDimitry Andric class Type; 245ffd83dbSDimitry Andric } 255ffd83dbSDimitry Andric 265ffd83dbSDimitry Andric using namespace llvm; 275ffd83dbSDimitry Andric 285ffd83dbSDimitry Andric namespace { 295ffd83dbSDimitry Andric 305ffd83dbSDimitry Andric static inline int sizeOfSCEV(const SCEV *S) { 315ffd83dbSDimitry Andric struct FindSCEVSize { 325ffd83dbSDimitry Andric int Size = 0; 335ffd83dbSDimitry Andric 345ffd83dbSDimitry Andric FindSCEVSize() = default; 355ffd83dbSDimitry Andric 365ffd83dbSDimitry Andric bool follow(const SCEV *S) { 375ffd83dbSDimitry Andric ++Size; 385ffd83dbSDimitry Andric // Keep looking at all operands of S. 395ffd83dbSDimitry Andric return true; 405ffd83dbSDimitry Andric } 415ffd83dbSDimitry Andric 425ffd83dbSDimitry Andric bool isDone() const { return false; } 435ffd83dbSDimitry Andric }; 445ffd83dbSDimitry Andric 455ffd83dbSDimitry Andric FindSCEVSize F; 465ffd83dbSDimitry Andric SCEVTraversal<FindSCEVSize> ST(F); 475ffd83dbSDimitry Andric ST.visitAll(S); 485ffd83dbSDimitry Andric return F.Size; 495ffd83dbSDimitry Andric } 505ffd83dbSDimitry Andric 515ffd83dbSDimitry Andric } // namespace 525ffd83dbSDimitry Andric 535ffd83dbSDimitry Andric // Computes the Quotient and Remainder of the division of Numerator by 545ffd83dbSDimitry Andric // Denominator. 555ffd83dbSDimitry Andric void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, 565ffd83dbSDimitry Andric const SCEV *Denominator, const SCEV **Quotient, 575ffd83dbSDimitry Andric const SCEV **Remainder) { 585ffd83dbSDimitry Andric assert(Numerator && Denominator && "Uninitialized SCEV"); 595ffd83dbSDimitry Andric 605ffd83dbSDimitry Andric SCEVDivision D(SE, Numerator, Denominator); 615ffd83dbSDimitry Andric 625ffd83dbSDimitry Andric // Check for the trivial case here to avoid having to check for it in the 635ffd83dbSDimitry Andric // rest of the code. 645ffd83dbSDimitry Andric if (Numerator == Denominator) { 655ffd83dbSDimitry Andric *Quotient = D.One; 665ffd83dbSDimitry Andric *Remainder = D.Zero; 675ffd83dbSDimitry Andric return; 685ffd83dbSDimitry Andric } 695ffd83dbSDimitry Andric 705ffd83dbSDimitry Andric if (Numerator->isZero()) { 715ffd83dbSDimitry Andric *Quotient = D.Zero; 725ffd83dbSDimitry Andric *Remainder = D.Zero; 735ffd83dbSDimitry Andric return; 745ffd83dbSDimitry Andric } 755ffd83dbSDimitry Andric 765ffd83dbSDimitry Andric // A simple case when N/1. The quotient is N. 775ffd83dbSDimitry Andric if (Denominator->isOne()) { 785ffd83dbSDimitry Andric *Quotient = Numerator; 795ffd83dbSDimitry Andric *Remainder = D.Zero; 805ffd83dbSDimitry Andric return; 815ffd83dbSDimitry Andric } 825ffd83dbSDimitry Andric 835ffd83dbSDimitry Andric // Split the Denominator when it is a product. 845ffd83dbSDimitry Andric if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { 855ffd83dbSDimitry Andric const SCEV *Q, *R; 865ffd83dbSDimitry Andric *Quotient = Numerator; 875ffd83dbSDimitry Andric for (const SCEV *Op : T->operands()) { 885ffd83dbSDimitry Andric divide(SE, *Quotient, Op, &Q, &R); 895ffd83dbSDimitry Andric *Quotient = Q; 905ffd83dbSDimitry Andric 915ffd83dbSDimitry Andric // Bail out when the Numerator is not divisible by one of the terms of 925ffd83dbSDimitry Andric // the Denominator. 935ffd83dbSDimitry Andric if (!R->isZero()) { 945ffd83dbSDimitry Andric *Quotient = D.Zero; 955ffd83dbSDimitry Andric *Remainder = Numerator; 965ffd83dbSDimitry Andric return; 975ffd83dbSDimitry Andric } 985ffd83dbSDimitry Andric } 995ffd83dbSDimitry Andric *Remainder = D.Zero; 1005ffd83dbSDimitry Andric return; 1015ffd83dbSDimitry Andric } 1025ffd83dbSDimitry Andric 1035ffd83dbSDimitry Andric D.visit(Numerator); 1045ffd83dbSDimitry Andric *Quotient = D.Quotient; 1055ffd83dbSDimitry Andric *Remainder = D.Remainder; 1065ffd83dbSDimitry Andric } 1075ffd83dbSDimitry Andric 1085ffd83dbSDimitry Andric void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { 1095ffd83dbSDimitry Andric if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { 1105ffd83dbSDimitry Andric APInt NumeratorVal = Numerator->getAPInt(); 1115ffd83dbSDimitry Andric APInt DenominatorVal = D->getAPInt(); 1125ffd83dbSDimitry Andric uint32_t NumeratorBW = NumeratorVal.getBitWidth(); 1135ffd83dbSDimitry Andric uint32_t DenominatorBW = DenominatorVal.getBitWidth(); 1145ffd83dbSDimitry Andric 1155ffd83dbSDimitry Andric if (NumeratorBW > DenominatorBW) 1165ffd83dbSDimitry Andric DenominatorVal = DenominatorVal.sext(NumeratorBW); 1175ffd83dbSDimitry Andric else if (NumeratorBW < DenominatorBW) 1185ffd83dbSDimitry Andric NumeratorVal = NumeratorVal.sext(DenominatorBW); 1195ffd83dbSDimitry Andric 1205ffd83dbSDimitry Andric APInt QuotientVal(NumeratorVal.getBitWidth(), 0); 1215ffd83dbSDimitry Andric APInt RemainderVal(NumeratorVal.getBitWidth(), 0); 1225ffd83dbSDimitry Andric APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); 1235ffd83dbSDimitry Andric Quotient = SE.getConstant(QuotientVal); 1245ffd83dbSDimitry Andric Remainder = SE.getConstant(RemainderVal); 1255ffd83dbSDimitry Andric return; 1265ffd83dbSDimitry Andric } 1275ffd83dbSDimitry Andric } 1285ffd83dbSDimitry Andric 129*06c3fb27SDimitry Andric void SCEVDivision::visitVScale(const SCEVVScale *Numerator) { 130*06c3fb27SDimitry Andric return cannotDivide(Numerator); 131*06c3fb27SDimitry Andric } 132*06c3fb27SDimitry Andric 1335ffd83dbSDimitry Andric void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { 1345ffd83dbSDimitry Andric const SCEV *StartQ, *StartR, *StepQ, *StepR; 1355ffd83dbSDimitry Andric if (!Numerator->isAffine()) 1365ffd83dbSDimitry Andric return cannotDivide(Numerator); 1375ffd83dbSDimitry Andric divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); 1385ffd83dbSDimitry Andric divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); 1395ffd83dbSDimitry Andric // Bail out if the types do not match. 1405ffd83dbSDimitry Andric Type *Ty = Denominator->getType(); 1415ffd83dbSDimitry Andric if (Ty != StartQ->getType() || Ty != StartR->getType() || 1425ffd83dbSDimitry Andric Ty != StepQ->getType() || Ty != StepR->getType()) 1435ffd83dbSDimitry Andric return cannotDivide(Numerator); 1445ffd83dbSDimitry Andric Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), 1455ffd83dbSDimitry Andric Numerator->getNoWrapFlags()); 1465ffd83dbSDimitry Andric Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), 1475ffd83dbSDimitry Andric Numerator->getNoWrapFlags()); 1485ffd83dbSDimitry Andric } 1495ffd83dbSDimitry Andric 1505ffd83dbSDimitry Andric void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { 1515ffd83dbSDimitry Andric SmallVector<const SCEV *, 2> Qs, Rs; 1525ffd83dbSDimitry Andric Type *Ty = Denominator->getType(); 1535ffd83dbSDimitry Andric 1545ffd83dbSDimitry Andric for (const SCEV *Op : Numerator->operands()) { 1555ffd83dbSDimitry Andric const SCEV *Q, *R; 1565ffd83dbSDimitry Andric divide(SE, Op, Denominator, &Q, &R); 1575ffd83dbSDimitry Andric 1585ffd83dbSDimitry Andric // Bail out if types do not match. 1595ffd83dbSDimitry Andric if (Ty != Q->getType() || Ty != R->getType()) 1605ffd83dbSDimitry Andric return cannotDivide(Numerator); 1615ffd83dbSDimitry Andric 1625ffd83dbSDimitry Andric Qs.push_back(Q); 1635ffd83dbSDimitry Andric Rs.push_back(R); 1645ffd83dbSDimitry Andric } 1655ffd83dbSDimitry Andric 1665ffd83dbSDimitry Andric if (Qs.size() == 1) { 1675ffd83dbSDimitry Andric Quotient = Qs[0]; 1685ffd83dbSDimitry Andric Remainder = Rs[0]; 1695ffd83dbSDimitry Andric return; 1705ffd83dbSDimitry Andric } 1715ffd83dbSDimitry Andric 1725ffd83dbSDimitry Andric Quotient = SE.getAddExpr(Qs); 1735ffd83dbSDimitry Andric Remainder = SE.getAddExpr(Rs); 1745ffd83dbSDimitry Andric } 1755ffd83dbSDimitry Andric 1765ffd83dbSDimitry Andric void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { 1775ffd83dbSDimitry Andric SmallVector<const SCEV *, 2> Qs; 1785ffd83dbSDimitry Andric Type *Ty = Denominator->getType(); 1795ffd83dbSDimitry Andric 1805ffd83dbSDimitry Andric bool FoundDenominatorTerm = false; 1815ffd83dbSDimitry Andric for (const SCEV *Op : Numerator->operands()) { 1825ffd83dbSDimitry Andric // Bail out if types do not match. 1835ffd83dbSDimitry Andric if (Ty != Op->getType()) 1845ffd83dbSDimitry Andric return cannotDivide(Numerator); 1855ffd83dbSDimitry Andric 1865ffd83dbSDimitry Andric if (FoundDenominatorTerm) { 1875ffd83dbSDimitry Andric Qs.push_back(Op); 1885ffd83dbSDimitry Andric continue; 1895ffd83dbSDimitry Andric } 1905ffd83dbSDimitry Andric 1915ffd83dbSDimitry Andric // Check whether Denominator divides one of the product operands. 1925ffd83dbSDimitry Andric const SCEV *Q, *R; 1935ffd83dbSDimitry Andric divide(SE, Op, Denominator, &Q, &R); 1945ffd83dbSDimitry Andric if (!R->isZero()) { 1955ffd83dbSDimitry Andric Qs.push_back(Op); 1965ffd83dbSDimitry Andric continue; 1975ffd83dbSDimitry Andric } 1985ffd83dbSDimitry Andric 1995ffd83dbSDimitry Andric // Bail out if types do not match. 2005ffd83dbSDimitry Andric if (Ty != Q->getType()) 2015ffd83dbSDimitry Andric return cannotDivide(Numerator); 2025ffd83dbSDimitry Andric 2035ffd83dbSDimitry Andric FoundDenominatorTerm = true; 2045ffd83dbSDimitry Andric Qs.push_back(Q); 2055ffd83dbSDimitry Andric } 2065ffd83dbSDimitry Andric 2075ffd83dbSDimitry Andric if (FoundDenominatorTerm) { 2085ffd83dbSDimitry Andric Remainder = Zero; 2095ffd83dbSDimitry Andric if (Qs.size() == 1) 2105ffd83dbSDimitry Andric Quotient = Qs[0]; 2115ffd83dbSDimitry Andric else 2125ffd83dbSDimitry Andric Quotient = SE.getMulExpr(Qs); 2135ffd83dbSDimitry Andric return; 2145ffd83dbSDimitry Andric } 2155ffd83dbSDimitry Andric 2165ffd83dbSDimitry Andric if (!isa<SCEVUnknown>(Denominator)) 2175ffd83dbSDimitry Andric return cannotDivide(Numerator); 2185ffd83dbSDimitry Andric 2195ffd83dbSDimitry Andric // The Remainder is obtained by replacing Denominator by 0 in Numerator. 220e8d8bef9SDimitry Andric ValueToSCEVMapTy RewriteMap; 221e8d8bef9SDimitry Andric RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero; 222e8d8bef9SDimitry Andric Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); 2235ffd83dbSDimitry Andric 2245ffd83dbSDimitry Andric if (Remainder->isZero()) { 2255ffd83dbSDimitry Andric // The Quotient is obtained by replacing Denominator by 1 in Numerator. 226e8d8bef9SDimitry Andric RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One; 227e8d8bef9SDimitry Andric Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); 2285ffd83dbSDimitry Andric return; 2295ffd83dbSDimitry Andric } 2305ffd83dbSDimitry Andric 2315ffd83dbSDimitry Andric // Quotient is (Numerator - Remainder) divided by Denominator. 2325ffd83dbSDimitry Andric const SCEV *Q, *R; 2335ffd83dbSDimitry Andric const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); 2345ffd83dbSDimitry Andric // This SCEV does not seem to simplify: fail the division here. 2355ffd83dbSDimitry Andric if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) 2365ffd83dbSDimitry Andric return cannotDivide(Numerator); 2375ffd83dbSDimitry Andric divide(SE, Diff, Denominator, &Q, &R); 2385ffd83dbSDimitry Andric if (R != Zero) 2395ffd83dbSDimitry Andric return cannotDivide(Numerator); 2405ffd83dbSDimitry Andric Quotient = Q; 2415ffd83dbSDimitry Andric } 2425ffd83dbSDimitry Andric 2435ffd83dbSDimitry Andric SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, 2445ffd83dbSDimitry Andric const SCEV *Denominator) 2455ffd83dbSDimitry Andric : SE(S), Denominator(Denominator) { 2465ffd83dbSDimitry Andric Zero = SE.getZero(Denominator->getType()); 2475ffd83dbSDimitry Andric One = SE.getOne(Denominator->getType()); 2485ffd83dbSDimitry Andric 2495ffd83dbSDimitry Andric // We generally do not know how to divide Expr by Denominator. We initialize 2505ffd83dbSDimitry Andric // the division to a "cannot divide" state to simplify the rest of the code. 2515ffd83dbSDimitry Andric cannotDivide(Numerator); 2525ffd83dbSDimitry Andric } 2535ffd83dbSDimitry Andric 2545ffd83dbSDimitry Andric // Convenience function for giving up on the division. We set the quotient to 2555ffd83dbSDimitry Andric // be equal to zero and the remainder to be equal to the numerator. 2565ffd83dbSDimitry Andric void SCEVDivision::cannotDivide(const SCEV *Numerator) { 2575ffd83dbSDimitry Andric Quotient = Zero; 2585ffd83dbSDimitry Andric Remainder = Numerator; 2595ffd83dbSDimitry Andric } 260