1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for working with "normalized" expressions. 10 // See the comments at the top of ScalarEvolutionNormalization.h for details. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/ScalarEvolutionNormalization.h" 15 #include "llvm/Analysis/LoopInfo.h" 16 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 17 using namespace llvm; 18 19 /// TransformKind - Different types of transformations that 20 /// TransformForPostIncUse can do. 21 enum TransformKind { 22 /// Normalize - Normalize according to the given loops. 23 Normalize, 24 /// Denormalize - Perform the inverse transform on the expression with the 25 /// given loop set. 26 Denormalize 27 }; 28 29 namespace { 30 struct NormalizeDenormalizeRewriter 31 : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { 32 const TransformKind Kind; 33 34 // NB! Pred is a function_ref. Storing it here is okay only because 35 // we're careful about the lifetime of NormalizeDenormalizeRewriter. 36 const NormalizePredTy Pred; 37 38 NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, 39 ScalarEvolution &SE) 40 : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), 41 Pred(Pred) {} 42 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); 43 }; 44 } // namespace 45 46 const SCEV * 47 NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { 48 SmallVector<const SCEV *, 8> Operands; 49 50 transform(AR->operands(), std::back_inserter(Operands), 51 [&](const SCEV *Op) { return visit(Op); }); 52 53 if (!Pred(AR)) 54 return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); 55 56 // Normalization and denormalization are fancy names for decrementing and 57 // incrementing a SCEV expression with respect to a set of loops. Since 58 // Pred(AR) has returned true, we know we need to normalize or denormalize AR 59 // with respect to its loop. 60 61 if (Kind == Denormalize) { 62 // Denormalization / "partial increment" is essentially the same as \c 63 // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the 64 // symmetry with Normalization clear. 65 for (int i = 0, e = Operands.size() - 1; i < e; i++) 66 Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]); 67 } else { 68 assert(Kind == Normalize && "Only two possibilities!"); 69 70 // Normalization / "partial decrement" is a bit more subtle. Since 71 // incrementing a SCEV expression (in general) changes the step of the SCEV 72 // expression as well, we cannot use the step of the current expression. 73 // Instead, we have to use the step of the very expression we're trying to 74 // compute! 75 // 76 // We solve the issue by recursively building up the result, starting from 77 // the "least significant" operand in the add recurrence: 78 // 79 // Base case: 80 // Single operand add recurrence. It's its own normalization. 81 // 82 // N-operand case: 83 // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S 84 // 85 // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its 86 // normalization by induction. We subtract the normalized step 87 // recurrence from S_{N-1} to get the normalization of S. 88 89 for (int i = Operands.size() - 2; i >= 0; i--) 90 Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]); 91 } 92 93 return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); 94 } 95 96 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, 97 const PostIncLoopSet &Loops, 98 ScalarEvolution &SE) { 99 auto Pred = [&](const SCEVAddRecExpr *AR) { 100 return Loops.count(AR->getLoop()); 101 }; 102 return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); 103 } 104 105 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, 106 ScalarEvolution &SE) { 107 return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); 108 } 109 110 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, 111 const PostIncLoopSet &Loops, 112 ScalarEvolution &SE) { 113 auto Pred = [&](const SCEVAddRecExpr *AR) { 114 return Loops.count(AR->getLoop()); 115 }; 116 return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S); 117 } 118