1 //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the class that knows how to divide SCEV's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Analysis/ScalarEvolutionDivision.h" 14 #include "llvm/ADT/APInt.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/Analysis/ScalarEvolution.h" 18 #include "llvm/Support/Casting.h" 19 #include <cassert> 20 #include <cstdint> 21 22 namespace llvm { 23 class Type; 24 } // namespace llvm 25 26 using namespace llvm; 27 28 namespace { 29 30 static inline int sizeOfSCEV(const SCEV *S) { 31 struct FindSCEVSize { 32 int Size = 0; 33 34 FindSCEVSize() = default; 35 36 bool follow(const SCEV *S) { 37 ++Size; 38 // Keep looking at all operands of S. 39 return true; 40 } 41 42 bool isDone() const { return false; } 43 }; 44 45 FindSCEVSize F; 46 SCEVTraversal<FindSCEVSize> ST(F); 47 ST.visitAll(S); 48 return F.Size; 49 } 50 51 } // namespace 52 53 // Computes the Quotient and Remainder of the division of Numerator by 54 // Denominator. 55 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, 56 const SCEV *Denominator, const SCEV **Quotient, 57 const SCEV **Remainder) { 58 assert(Numerator && Denominator && "Uninitialized SCEV"); 59 60 SCEVDivision D(SE, Numerator, Denominator); 61 62 // Check for the trivial case here to avoid having to check for it in the 63 // rest of the code. 64 if (Numerator == Denominator) { 65 *Quotient = D.One; 66 *Remainder = D.Zero; 67 return; 68 } 69 70 if (Numerator->isZero()) { 71 *Quotient = D.Zero; 72 *Remainder = D.Zero; 73 return; 74 } 75 76 // A simple case when N/1. The quotient is N. 77 if (Denominator->isOne()) { 78 *Quotient = Numerator; 79 *Remainder = D.Zero; 80 return; 81 } 82 83 // Split the Denominator when it is a product. 84 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { 85 const SCEV *Q, *R; 86 *Quotient = Numerator; 87 for (const SCEV *Op : T->operands()) { 88 divide(SE, *Quotient, Op, &Q, &R); 89 *Quotient = Q; 90 91 // Bail out when the Numerator is not divisible by one of the terms of 92 // the Denominator. 93 if (!R->isZero()) { 94 *Quotient = D.Zero; 95 *Remainder = Numerator; 96 return; 97 } 98 } 99 *Remainder = D.Zero; 100 return; 101 } 102 103 D.visit(Numerator); 104 *Quotient = D.Quotient; 105 *Remainder = D.Remainder; 106 } 107 108 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { 109 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { 110 APInt NumeratorVal = Numerator->getAPInt(); 111 APInt DenominatorVal = D->getAPInt(); 112 uint32_t NumeratorBW = NumeratorVal.getBitWidth(); 113 uint32_t DenominatorBW = DenominatorVal.getBitWidth(); 114 115 if (NumeratorBW > DenominatorBW) 116 DenominatorVal = DenominatorVal.sext(NumeratorBW); 117 else if (NumeratorBW < DenominatorBW) 118 NumeratorVal = NumeratorVal.sext(DenominatorBW); 119 120 APInt QuotientVal(NumeratorVal.getBitWidth(), 0); 121 APInt RemainderVal(NumeratorVal.getBitWidth(), 0); 122 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); 123 Quotient = SE.getConstant(QuotientVal); 124 Remainder = SE.getConstant(RemainderVal); 125 return; 126 } 127 } 128 129 void SCEVDivision::visitVScale(const SCEVVScale *Numerator) { 130 return cannotDivide(Numerator); 131 } 132 133 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { 134 const SCEV *StartQ, *StartR, *StepQ, *StepR; 135 if (!Numerator->isAffine()) 136 return cannotDivide(Numerator); 137 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); 138 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); 139 // Bail out if the types do not match. 140 Type *Ty = Denominator->getType(); 141 if (Ty != StartQ->getType() || Ty != StartR->getType() || 142 Ty != StepQ->getType() || Ty != StepR->getType()) 143 return cannotDivide(Numerator); 144 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), 145 Numerator->getNoWrapFlags()); 146 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), 147 Numerator->getNoWrapFlags()); 148 } 149 150 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { 151 SmallVector<const SCEV *, 2> Qs, Rs; 152 Type *Ty = Denominator->getType(); 153 154 for (const SCEV *Op : Numerator->operands()) { 155 const SCEV *Q, *R; 156 divide(SE, Op, Denominator, &Q, &R); 157 158 // Bail out if types do not match. 159 if (Ty != Q->getType() || Ty != R->getType()) 160 return cannotDivide(Numerator); 161 162 Qs.push_back(Q); 163 Rs.push_back(R); 164 } 165 166 if (Qs.size() == 1) { 167 Quotient = Qs[0]; 168 Remainder = Rs[0]; 169 return; 170 } 171 172 Quotient = SE.getAddExpr(Qs); 173 Remainder = SE.getAddExpr(Rs); 174 } 175 176 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { 177 SmallVector<const SCEV *, 2> Qs; 178 Type *Ty = Denominator->getType(); 179 180 bool FoundDenominatorTerm = false; 181 for (const SCEV *Op : Numerator->operands()) { 182 // Bail out if types do not match. 183 if (Ty != Op->getType()) 184 return cannotDivide(Numerator); 185 186 if (FoundDenominatorTerm) { 187 Qs.push_back(Op); 188 continue; 189 } 190 191 // Check whether Denominator divides one of the product operands. 192 const SCEV *Q, *R; 193 divide(SE, Op, Denominator, &Q, &R); 194 if (!R->isZero()) { 195 Qs.push_back(Op); 196 continue; 197 } 198 199 // Bail out if types do not match. 200 if (Ty != Q->getType()) 201 return cannotDivide(Numerator); 202 203 FoundDenominatorTerm = true; 204 Qs.push_back(Q); 205 } 206 207 if (FoundDenominatorTerm) { 208 Remainder = Zero; 209 if (Qs.size() == 1) 210 Quotient = Qs[0]; 211 else 212 Quotient = SE.getMulExpr(Qs); 213 return; 214 } 215 216 if (!isa<SCEVUnknown>(Denominator)) 217 return cannotDivide(Numerator); 218 219 // The Remainder is obtained by replacing Denominator by 0 in Numerator. 220 ValueToSCEVMapTy RewriteMap; 221 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero; 222 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); 223 224 if (Remainder->isZero()) { 225 // The Quotient is obtained by replacing Denominator by 1 in Numerator. 226 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One; 227 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); 228 return; 229 } 230 231 // Quotient is (Numerator - Remainder) divided by Denominator. 232 const SCEV *Q, *R; 233 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); 234 // This SCEV does not seem to simplify: fail the division here. 235 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) 236 return cannotDivide(Numerator); 237 divide(SE, Diff, Denominator, &Q, &R); 238 if (R != Zero) 239 return cannotDivide(Numerator); 240 Quotient = Q; 241 } 242 243 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, 244 const SCEV *Denominator) 245 : SE(S), Denominator(Denominator) { 246 Zero = SE.getZero(Denominator->getType()); 247 One = SE.getOne(Denominator->getType()); 248 249 // We generally do not know how to divide Expr by Denominator. We initialize 250 // the division to a "cannot divide" state to simplify the rest of the code. 251 cannotDivide(Numerator); 252 } 253 254 // Convenience function for giving up on the division. We set the quotient to 255 // be equal to zero and the remainder to be equal to the numerator. 256 void SCEVDivision::cannotDivide(const SCEV *Numerator) { 257 Quotient = Zero; 258 Remainder = Numerator; 259 } 260