1 //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the class that knows how to divide SCEV's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Analysis/ScalarEvolutionDivision.h" 14 #include "llvm/ADT/APInt.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/Analysis/ScalarEvolution.h" 18 #include "llvm/IR/Constants.h" 19 #include "llvm/Support/Casting.h" 20 #include "llvm/Support/ErrorHandling.h" 21 #include <cassert> 22 #include <cstdint> 23 24 namespace llvm { 25 class Type; 26 } 27 28 using namespace llvm; 29 30 namespace { 31 32 static inline int sizeOfSCEV(const SCEV *S) { 33 struct FindSCEVSize { 34 int Size = 0; 35 36 FindSCEVSize() = default; 37 38 bool follow(const SCEV *S) { 39 ++Size; 40 // Keep looking at all operands of S. 41 return true; 42 } 43 44 bool isDone() const { return false; } 45 }; 46 47 FindSCEVSize F; 48 SCEVTraversal<FindSCEVSize> ST(F); 49 ST.visitAll(S); 50 return F.Size; 51 } 52 53 } // namespace 54 55 // Computes the Quotient and Remainder of the division of Numerator by 56 // Denominator. 57 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, 58 const SCEV *Denominator, const SCEV **Quotient, 59 const SCEV **Remainder) { 60 assert(Numerator && Denominator && "Uninitialized SCEV"); 61 62 SCEVDivision D(SE, Numerator, Denominator); 63 64 // Check for the trivial case here to avoid having to check for it in the 65 // rest of the code. 66 if (Numerator == Denominator) { 67 *Quotient = D.One; 68 *Remainder = D.Zero; 69 return; 70 } 71 72 if (Numerator->isZero()) { 73 *Quotient = D.Zero; 74 *Remainder = D.Zero; 75 return; 76 } 77 78 // A simple case when N/1. The quotient is N. 79 if (Denominator->isOne()) { 80 *Quotient = Numerator; 81 *Remainder = D.Zero; 82 return; 83 } 84 85 // Split the Denominator when it is a product. 86 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { 87 const SCEV *Q, *R; 88 *Quotient = Numerator; 89 for (const SCEV *Op : T->operands()) { 90 divide(SE, *Quotient, Op, &Q, &R); 91 *Quotient = Q; 92 93 // Bail out when the Numerator is not divisible by one of the terms of 94 // the Denominator. 95 if (!R->isZero()) { 96 *Quotient = D.Zero; 97 *Remainder = Numerator; 98 return; 99 } 100 } 101 *Remainder = D.Zero; 102 return; 103 } 104 105 D.visit(Numerator); 106 *Quotient = D.Quotient; 107 *Remainder = D.Remainder; 108 } 109 110 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { 111 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { 112 APInt NumeratorVal = Numerator->getAPInt(); 113 APInt DenominatorVal = D->getAPInt(); 114 uint32_t NumeratorBW = NumeratorVal.getBitWidth(); 115 uint32_t DenominatorBW = DenominatorVal.getBitWidth(); 116 117 if (NumeratorBW > DenominatorBW) 118 DenominatorVal = DenominatorVal.sext(NumeratorBW); 119 else if (NumeratorBW < DenominatorBW) 120 NumeratorVal = NumeratorVal.sext(DenominatorBW); 121 122 APInt QuotientVal(NumeratorVal.getBitWidth(), 0); 123 APInt RemainderVal(NumeratorVal.getBitWidth(), 0); 124 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); 125 Quotient = SE.getConstant(QuotientVal); 126 Remainder = SE.getConstant(RemainderVal); 127 return; 128 } 129 } 130 131 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { 132 const SCEV *StartQ, *StartR, *StepQ, *StepR; 133 if (!Numerator->isAffine()) 134 return cannotDivide(Numerator); 135 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); 136 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); 137 // Bail out if the types do not match. 138 Type *Ty = Denominator->getType(); 139 if (Ty != StartQ->getType() || Ty != StartR->getType() || 140 Ty != StepQ->getType() || Ty != StepR->getType()) 141 return cannotDivide(Numerator); 142 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), 143 Numerator->getNoWrapFlags()); 144 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), 145 Numerator->getNoWrapFlags()); 146 } 147 148 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { 149 SmallVector<const SCEV *, 2> Qs, Rs; 150 Type *Ty = Denominator->getType(); 151 152 for (const SCEV *Op : Numerator->operands()) { 153 const SCEV *Q, *R; 154 divide(SE, Op, Denominator, &Q, &R); 155 156 // Bail out if types do not match. 157 if (Ty != Q->getType() || Ty != R->getType()) 158 return cannotDivide(Numerator); 159 160 Qs.push_back(Q); 161 Rs.push_back(R); 162 } 163 164 if (Qs.size() == 1) { 165 Quotient = Qs[0]; 166 Remainder = Rs[0]; 167 return; 168 } 169 170 Quotient = SE.getAddExpr(Qs); 171 Remainder = SE.getAddExpr(Rs); 172 } 173 174 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { 175 SmallVector<const SCEV *, 2> Qs; 176 Type *Ty = Denominator->getType(); 177 178 bool FoundDenominatorTerm = false; 179 for (const SCEV *Op : Numerator->operands()) { 180 // Bail out if types do not match. 181 if (Ty != Op->getType()) 182 return cannotDivide(Numerator); 183 184 if (FoundDenominatorTerm) { 185 Qs.push_back(Op); 186 continue; 187 } 188 189 // Check whether Denominator divides one of the product operands. 190 const SCEV *Q, *R; 191 divide(SE, Op, Denominator, &Q, &R); 192 if (!R->isZero()) { 193 Qs.push_back(Op); 194 continue; 195 } 196 197 // Bail out if types do not match. 198 if (Ty != Q->getType()) 199 return cannotDivide(Numerator); 200 201 FoundDenominatorTerm = true; 202 Qs.push_back(Q); 203 } 204 205 if (FoundDenominatorTerm) { 206 Remainder = Zero; 207 if (Qs.size() == 1) 208 Quotient = Qs[0]; 209 else 210 Quotient = SE.getMulExpr(Qs); 211 return; 212 } 213 214 if (!isa<SCEVUnknown>(Denominator)) 215 return cannotDivide(Numerator); 216 217 // The Remainder is obtained by replacing Denominator by 0 in Numerator. 218 ValueToValueMap RewriteMap; 219 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = 220 cast<SCEVConstant>(Zero)->getValue(); 221 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); 222 223 if (Remainder->isZero()) { 224 // The Quotient is obtained by replacing Denominator by 1 in Numerator. 225 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = 226 cast<SCEVConstant>(One)->getValue(); 227 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); 228 return; 229 } 230 231 // Quotient is (Numerator - Remainder) divided by Denominator. 232 const SCEV *Q, *R; 233 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); 234 // This SCEV does not seem to simplify: fail the division here. 235 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) 236 return cannotDivide(Numerator); 237 divide(SE, Diff, Denominator, &Q, &R); 238 if (R != Zero) 239 return cannotDivide(Numerator); 240 Quotient = Q; 241 } 242 243 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, 244 const SCEV *Denominator) 245 : SE(S), Denominator(Denominator) { 246 Zero = SE.getZero(Denominator->getType()); 247 One = SE.getOne(Denominator->getType()); 248 249 // We generally do not know how to divide Expr by Denominator. We initialize 250 // the division to a "cannot divide" state to simplify the rest of the code. 251 cannotDivide(Numerator); 252 } 253 254 // Convenience function for giving up on the division. We set the quotient to 255 // be equal to zero and the remainder to be equal to the numerator. 256 void SCEVDivision::cannotDivide(const SCEV *Numerator) { 257 Quotient = Zero; 258 Remainder = Numerator; 259 } 260