1 //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the class that knows how to divide SCEV's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Analysis/ScalarEvolutionDivision.h" 14 #include "llvm/ADT/APInt.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/Analysis/ScalarEvolution.h" 18 #include "llvm/Support/Casting.h" 19 #include <cassert> 20 #include <cstdint> 21 22 namespace llvm { 23 class Type; 24 } 25 26 using namespace llvm; 27 28 namespace { 29 30 static inline int sizeOfSCEV(const SCEV *S) { 31 struct FindSCEVSize { 32 int Size = 0; 33 34 FindSCEVSize() = default; 35 36 bool follow(const SCEV *S) { 37 ++Size; 38 // Keep looking at all operands of S. 39 return true; 40 } 41 42 bool isDone() const { return false; } 43 }; 44 45 FindSCEVSize F; 46 SCEVTraversal<FindSCEVSize> ST(F); 47 ST.visitAll(S); 48 return F.Size; 49 } 50 51 } // namespace 52 53 // Computes the Quotient and Remainder of the division of Numerator by 54 // Denominator. 55 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, 56 const SCEV *Denominator, const SCEV **Quotient, 57 const SCEV **Remainder) { 58 assert(Numerator && Denominator && "Uninitialized SCEV"); 59 60 SCEVDivision D(SE, Numerator, Denominator); 61 62 // Check for the trivial case here to avoid having to check for it in the 63 // rest of the code. 64 if (Numerator == Denominator) { 65 *Quotient = D.One; 66 *Remainder = D.Zero; 67 return; 68 } 69 70 if (Numerator->isZero()) { 71 *Quotient = D.Zero; 72 *Remainder = D.Zero; 73 return; 74 } 75 76 // A simple case when N/1. The quotient is N. 77 if (Denominator->isOne()) { 78 *Quotient = Numerator; 79 *Remainder = D.Zero; 80 return; 81 } 82 83 // Split the Denominator when it is a product. 84 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { 85 const SCEV *Q, *R; 86 *Quotient = Numerator; 87 for (const SCEV *Op : T->operands()) { 88 divide(SE, *Quotient, Op, &Q, &R); 89 *Quotient = Q; 90 91 // Bail out when the Numerator is not divisible by one of the terms of 92 // the Denominator. 93 if (!R->isZero()) { 94 *Quotient = D.Zero; 95 *Remainder = Numerator; 96 return; 97 } 98 } 99 *Remainder = D.Zero; 100 return; 101 } 102 103 D.visit(Numerator); 104 *Quotient = D.Quotient; 105 *Remainder = D.Remainder; 106 } 107 108 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { 109 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { 110 APInt NumeratorVal = Numerator->getAPInt(); 111 APInt DenominatorVal = D->getAPInt(); 112 uint32_t NumeratorBW = NumeratorVal.getBitWidth(); 113 uint32_t DenominatorBW = DenominatorVal.getBitWidth(); 114 115 if (NumeratorBW > DenominatorBW) 116 DenominatorVal = DenominatorVal.sext(NumeratorBW); 117 else if (NumeratorBW < DenominatorBW) 118 NumeratorVal = NumeratorVal.sext(DenominatorBW); 119 120 APInt QuotientVal(NumeratorVal.getBitWidth(), 0); 121 APInt RemainderVal(NumeratorVal.getBitWidth(), 0); 122 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); 123 Quotient = SE.getConstant(QuotientVal); 124 Remainder = SE.getConstant(RemainderVal); 125 return; 126 } 127 } 128 129 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { 130 const SCEV *StartQ, *StartR, *StepQ, *StepR; 131 if (!Numerator->isAffine()) 132 return cannotDivide(Numerator); 133 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); 134 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); 135 // Bail out if the types do not match. 136 Type *Ty = Denominator->getType(); 137 if (Ty != StartQ->getType() || Ty != StartR->getType() || 138 Ty != StepQ->getType() || Ty != StepR->getType()) 139 return cannotDivide(Numerator); 140 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), 141 Numerator->getNoWrapFlags()); 142 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), 143 Numerator->getNoWrapFlags()); 144 } 145 146 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { 147 SmallVector<const SCEV *, 2> Qs, Rs; 148 Type *Ty = Denominator->getType(); 149 150 for (const SCEV *Op : Numerator->operands()) { 151 const SCEV *Q, *R; 152 divide(SE, Op, Denominator, &Q, &R); 153 154 // Bail out if types do not match. 155 if (Ty != Q->getType() || Ty != R->getType()) 156 return cannotDivide(Numerator); 157 158 Qs.push_back(Q); 159 Rs.push_back(R); 160 } 161 162 if (Qs.size() == 1) { 163 Quotient = Qs[0]; 164 Remainder = Rs[0]; 165 return; 166 } 167 168 Quotient = SE.getAddExpr(Qs); 169 Remainder = SE.getAddExpr(Rs); 170 } 171 172 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { 173 SmallVector<const SCEV *, 2> Qs; 174 Type *Ty = Denominator->getType(); 175 176 bool FoundDenominatorTerm = false; 177 for (const SCEV *Op : Numerator->operands()) { 178 // Bail out if types do not match. 179 if (Ty != Op->getType()) 180 return cannotDivide(Numerator); 181 182 if (FoundDenominatorTerm) { 183 Qs.push_back(Op); 184 continue; 185 } 186 187 // Check whether Denominator divides one of the product operands. 188 const SCEV *Q, *R; 189 divide(SE, Op, Denominator, &Q, &R); 190 if (!R->isZero()) { 191 Qs.push_back(Op); 192 continue; 193 } 194 195 // Bail out if types do not match. 196 if (Ty != Q->getType()) 197 return cannotDivide(Numerator); 198 199 FoundDenominatorTerm = true; 200 Qs.push_back(Q); 201 } 202 203 if (FoundDenominatorTerm) { 204 Remainder = Zero; 205 if (Qs.size() == 1) 206 Quotient = Qs[0]; 207 else 208 Quotient = SE.getMulExpr(Qs); 209 return; 210 } 211 212 if (!isa<SCEVUnknown>(Denominator)) 213 return cannotDivide(Numerator); 214 215 // The Remainder is obtained by replacing Denominator by 0 in Numerator. 216 ValueToSCEVMapTy RewriteMap; 217 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero; 218 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); 219 220 if (Remainder->isZero()) { 221 // The Quotient is obtained by replacing Denominator by 1 in Numerator. 222 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One; 223 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); 224 return; 225 } 226 227 // Quotient is (Numerator - Remainder) divided by Denominator. 228 const SCEV *Q, *R; 229 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); 230 // This SCEV does not seem to simplify: fail the division here. 231 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) 232 return cannotDivide(Numerator); 233 divide(SE, Diff, Denominator, &Q, &R); 234 if (R != Zero) 235 return cannotDivide(Numerator); 236 Quotient = Q; 237 } 238 239 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, 240 const SCEV *Denominator) 241 : SE(S), Denominator(Denominator) { 242 Zero = SE.getZero(Denominator->getType()); 243 One = SE.getOne(Denominator->getType()); 244 245 // We generally do not know how to divide Expr by Denominator. We initialize 246 // the division to a "cannot divide" state to simplify the rest of the code. 247 cannotDivide(Numerator); 248 } 249 250 // Convenience function for giving up on the division. We set the quotient to 251 // be equal to zero and the remainder to be equal to the numerator. 252 void SCEVDivision::cannotDivide(const SCEV *Numerator) { 253 Quotient = Zero; 254 Remainder = Numerator; 255 } 256