1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for working with "normalized" expressions. 10 // See the comments at the top of ScalarEvolutionNormalization.h for details. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/ScalarEvolutionNormalization.h" 15 #include "llvm/Analysis/LoopInfo.h" 16 #include "llvm/Analysis/ScalarEvolution.h" 17 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 18 using namespace llvm; 19 20 /// TransformKind - Different types of transformations that 21 /// TransformForPostIncUse can do. 22 enum TransformKind { 23 /// Normalize - Normalize according to the given loops. 24 Normalize, 25 /// Denormalize - Perform the inverse transform on the expression with the 26 /// given loop set. 27 Denormalize 28 }; 29 30 namespace { 31 struct NormalizeDenormalizeRewriter 32 : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { 33 const TransformKind Kind; 34 35 // NB! Pred is a function_ref. Storing it here is okay only because 36 // we're careful about the lifetime of NormalizeDenormalizeRewriter. 37 const NormalizePredTy Pred; 38 39 NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, 40 ScalarEvolution &SE) 41 : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), 42 Pred(Pred) {} 43 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); 44 }; 45 } // namespace 46 47 const SCEV * 48 NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { 49 SmallVector<const SCEV *, 8> Operands; 50 51 transform(AR->operands(), std::back_inserter(Operands), 52 [&](const SCEV *Op) { return visit(Op); }); 53 54 if (!Pred(AR)) 55 return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); 56 57 // Normalization and denormalization are fancy names for decrementing and 58 // incrementing a SCEV expression with respect to a set of loops. Since 59 // Pred(AR) has returned true, we know we need to normalize or denormalize AR 60 // with respect to its loop. 61 62 if (Kind == Denormalize) { 63 // Denormalization / "partial increment" is essentially the same as \c 64 // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the 65 // symmetry with Normalization clear. 66 for (int i = 0, e = Operands.size() - 1; i < e; i++) 67 Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]); 68 } else { 69 assert(Kind == Normalize && "Only two possibilities!"); 70 71 // Normalization / "partial decrement" is a bit more subtle. Since 72 // incrementing a SCEV expression (in general) changes the step of the SCEV 73 // expression as well, we cannot use the step of the current expression. 74 // Instead, we have to use the step of the very expression we're trying to 75 // compute! 76 // 77 // We solve the issue by recursively building up the result, starting from 78 // the "least significant" operand in the add recurrence: 79 // 80 // Base case: 81 // Single operand add recurrence. It's its own normalization. 82 // 83 // N-operand case: 84 // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S 85 // 86 // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its 87 // normalization by induction. We subtract the normalized step 88 // recurrence from S_{N-1} to get the normalization of S. 89 90 for (int i = Operands.size() - 2; i >= 0; i--) 91 Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]); 92 } 93 94 return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); 95 } 96 97 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, 98 const PostIncLoopSet &Loops, 99 ScalarEvolution &SE, 100 bool CheckInvertible) { 101 if (Loops.empty()) 102 return S; 103 auto Pred = [&](const SCEVAddRecExpr *AR) { 104 return Loops.count(AR->getLoop()); 105 }; 106 const SCEV *Normalized = 107 NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); 108 const SCEV *Denormalized = denormalizeForPostIncUse(Normalized, Loops, SE); 109 // If the normalized expression isn't invertible. 110 if (CheckInvertible && Denormalized != S) 111 return nullptr; 112 return Normalized; 113 } 114 115 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, 116 ScalarEvolution &SE) { 117 return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); 118 } 119 120 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, 121 const PostIncLoopSet &Loops, 122 ScalarEvolution &SE) { 123 if (Loops.empty()) 124 return S; 125 auto Pred = [&](const SCEVAddRecExpr *AR) { 126 return Loops.count(AR->getLoop()); 127 }; 128 return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S); 129 } 130