xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/LoopCacheAnalysis.cpp (revision f7f4bd06a8d4e5d1e92d0d2905a68a2a03ed9c0c)
1  //===- LoopCacheAnalysis.cpp - Loop Cache Analysis -------------------------==//
2  //
3  //                     The LLVM Compiler Infrastructure
4  //
5  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6  // See https://llvm.org/LICENSE.txt for license information.
7  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8  //
9  //===----------------------------------------------------------------------===//
10  ///
11  /// \file
12  /// This file defines the implementation for the loop cache analysis.
13  /// The implementation is largely based on the following paper:
14  ///
15  ///       Compiler Optimizations for Improving Data Locality
16  ///       By: Steve Carr, Katherine S. McKinley, Chau-Wen Tseng
17  ///       http://www.cs.utexas.edu/users/mckinley/papers/asplos-1994.pdf
18  ///
19  /// The general approach taken to estimate the number of cache lines used by the
20  /// memory references in an inner loop is:
///    1. Partition memory references that exhibit temporal or spatial reuse
///       into reference groups.
///    2. For each loop L in a loop nest LN:
24  ///       a. Compute the cost of the reference group
25  ///       b. Compute the loop cost by summing up the reference groups costs
26  //===----------------------------------------------------------------------===//
27  
#include "llvm/Analysis/LoopCacheAnalysis.h"
#include "llvm/ADT/BreadthFirstIterator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/Delinearization.h"
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include <limits>
39  #include "llvm/Support/Debug.h"
40  
41  using namespace llvm;
42  
43  #define DEBUG_TYPE "loop-cache-cost"
44  
45  static cl::opt<unsigned> DefaultTripCount(
46      "default-trip-count", cl::init(100), cl::Hidden,
47      cl::desc("Use this to specify the default trip count of a loop"));
48  
49  // In this analysis two array references are considered to exhibit temporal
50  // reuse if they access either the same memory location, or a memory location
51  // with distance smaller than a configurable threshold.
52  static cl::opt<unsigned> TemporalReuseThreshold(
53      "temporal-reuse-threshold", cl::init(2), cl::Hidden,
54      cl::desc("Use this to specify the max. distance between array elements "
55               "accessed in a loop so that the elements are classified to have "
56               "temporal reuse"));
57  
58  /// Retrieve the innermost loop in the given loop nest \p Loops. It returns a
59  /// nullptr if any loops in the loop vector supplied has more than one sibling.
60  /// The loop vector is expected to contain loops collected in breadth-first
61  /// order.
62  static Loop *getInnerMostLoop(const LoopVectorTy &Loops) {
63    assert(!Loops.empty() && "Expecting a non-empy loop vector");
64  
65    Loop *LastLoop = Loops.back();
66    Loop *ParentLoop = LastLoop->getParentLoop();
67  
68    if (ParentLoop == nullptr) {
69      assert(Loops.size() == 1 && "Expecting a single loop");
70      return LastLoop;
71    }
72  
73    return (llvm::is_sorted(Loops,
74                            [](const Loop *L1, const Loop *L2) {
75                              return L1->getLoopDepth() < L2->getLoopDepth();
76                            }))
77               ? LastLoop
78               : nullptr;
79  }
80  
81  static bool isOneDimensionalArray(const SCEV &AccessFn, const SCEV &ElemSize,
82                                    const Loop &L, ScalarEvolution &SE) {
83    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(&AccessFn);
84    if (!AR || !AR->isAffine())
85      return false;
86  
87    assert(AR->getLoop() && "AR should have a loop");
88  
89    // Check that start and increment are not add recurrences.
90    const SCEV *Start = AR->getStart();
91    const SCEV *Step = AR->getStepRecurrence(SE);
92    if (isa<SCEVAddRecExpr>(Start) || isa<SCEVAddRecExpr>(Step))
93      return false;
94  
95    // Check that start and increment are both invariant in the loop.
96    if (!SE.isLoopInvariant(Start, &L) || !SE.isLoopInvariant(Step, &L))
97      return false;
98  
99    const SCEV *StepRec = AR->getStepRecurrence(SE);
100    if (StepRec && SE.isKnownNegative(StepRec))
101      StepRec = SE.getNegativeSCEV(StepRec);
102  
103    return StepRec == &ElemSize;
104  }
105  
106  /// Compute the trip count for the given loop \p L or assume a default value if
107  /// it is not a compile time constant. Return the SCEV expression for the trip
108  /// count.
109  static const SCEV *computeTripCount(const Loop &L, const SCEV &ElemSize,
110                                      ScalarEvolution &SE) {
111    const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(&L);
112    const SCEV *TripCount = (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
113                             isa<SCEVConstant>(BackedgeTakenCount))
114                                ? SE.getTripCountFromExitCount(BackedgeTakenCount)
115                                : nullptr;
116  
117    if (!TripCount) {
118      LLVM_DEBUG(dbgs() << "Trip count of loop " << L.getName()
119                 << " could not be computed, using DefaultTripCount\n");
120      TripCount = SE.getConstant(ElemSize.getType(), DefaultTripCount);
121    }
122  
123    return TripCount;
124  }
125  
126  //===----------------------------------------------------------------------===//
127  // IndexedReference implementation
128  //
129  raw_ostream &llvm::operator<<(raw_ostream &OS, const IndexedReference &R) {
130    if (!R.IsValid) {
131      OS << R.StoreOrLoadInst;
132      OS << ", IsValid=false.";
133      return OS;
134    }
135  
136    OS << *R.BasePointer;
137    for (const SCEV *Subscript : R.Subscripts)
138      OS << "[" << *Subscript << "]";
139  
140    OS << ", Sizes: ";
141    for (const SCEV *Size : R.Sizes)
142      OS << "[" << *Size << "]";
143  
144    return OS;
145  }
146  
147  IndexedReference::IndexedReference(Instruction &StoreOrLoadInst,
148                                     const LoopInfo &LI, ScalarEvolution &SE)
149      : StoreOrLoadInst(StoreOrLoadInst), SE(SE) {
150    assert((isa<StoreInst>(StoreOrLoadInst) || isa<LoadInst>(StoreOrLoadInst)) &&
151           "Expecting a load or store instruction");
152  
153    IsValid = delinearize(LI);
154    if (IsValid)
155      LLVM_DEBUG(dbgs().indent(2) << "Succesfully delinearized: " << *this
156                                  << "\n");
157  }
158  
159  std::optional<bool>
160  IndexedReference::hasSpacialReuse(const IndexedReference &Other, unsigned CLS,
161                                    AAResults &AA) const {
162    assert(IsValid && "Expecting a valid reference");
163  
164    if (BasePointer != Other.getBasePointer() && !isAliased(Other, AA)) {
165      LLVM_DEBUG(dbgs().indent(2)
166                 << "No spacial reuse: different base pointers\n");
167      return false;
168    }
169  
170    unsigned NumSubscripts = getNumSubscripts();
171    if (NumSubscripts != Other.getNumSubscripts()) {
172      LLVM_DEBUG(dbgs().indent(2)
173                 << "No spacial reuse: different number of subscripts\n");
174      return false;
175    }
176  
177    // all subscripts must be equal, except the leftmost one (the last one).
178    for (auto SubNum : seq<unsigned>(0, NumSubscripts - 1)) {
179      if (getSubscript(SubNum) != Other.getSubscript(SubNum)) {
180        LLVM_DEBUG(dbgs().indent(2) << "No spacial reuse, different subscripts: "
181                                    << "\n\t" << *getSubscript(SubNum) << "\n\t"
182                                    << *Other.getSubscript(SubNum) << "\n");
183        return false;
184      }
185    }
186  
187    // the difference between the last subscripts must be less than the cache line
188    // size.
189    const SCEV *LastSubscript = getLastSubscript();
190    const SCEV *OtherLastSubscript = Other.getLastSubscript();
191    const SCEVConstant *Diff = dyn_cast<SCEVConstant>(
192        SE.getMinusSCEV(LastSubscript, OtherLastSubscript));
193  
194    if (Diff == nullptr) {
195      LLVM_DEBUG(dbgs().indent(2)
196                 << "No spacial reuse, difference between subscript:\n\t"
197                 << *LastSubscript << "\n\t" << OtherLastSubscript
198                 << "\nis not constant.\n");
199      return std::nullopt;
200    }
201  
202    bool InSameCacheLine = (Diff->getValue()->getSExtValue() < CLS);
203  
204    LLVM_DEBUG({
205      if (InSameCacheLine)
206        dbgs().indent(2) << "Found spacial reuse.\n";
207      else
208        dbgs().indent(2) << "No spacial reuse.\n";
209    });
210  
211    return InSameCacheLine;
212  }
213  
214  std::optional<bool>
215  IndexedReference::hasTemporalReuse(const IndexedReference &Other,
216                                     unsigned MaxDistance, const Loop &L,
217                                     DependenceInfo &DI, AAResults &AA) const {
218    assert(IsValid && "Expecting a valid reference");
219  
220    if (BasePointer != Other.getBasePointer() && !isAliased(Other, AA)) {
221      LLVM_DEBUG(dbgs().indent(2)
222                 << "No temporal reuse: different base pointer\n");
223      return false;
224    }
225  
226    std::unique_ptr<Dependence> D =
227        DI.depends(&StoreOrLoadInst, &Other.StoreOrLoadInst, true);
228  
229    if (D == nullptr) {
230      LLVM_DEBUG(dbgs().indent(2) << "No temporal reuse: no dependence\n");
231      return false;
232    }
233  
234    if (D->isLoopIndependent()) {
235      LLVM_DEBUG(dbgs().indent(2) << "Found temporal reuse\n");
236      return true;
237    }
238  
239    // Check the dependence distance at every loop level. There is temporal reuse
240    // if the distance at the given loop's depth is small (|d| <= MaxDistance) and
241    // it is zero at every other loop level.
242    int LoopDepth = L.getLoopDepth();
243    int Levels = D->getLevels();
244    for (int Level = 1; Level <= Levels; ++Level) {
245      const SCEV *Distance = D->getDistance(Level);
246      const SCEVConstant *SCEVConst = dyn_cast_or_null<SCEVConstant>(Distance);
247  
248      if (SCEVConst == nullptr) {
249        LLVM_DEBUG(dbgs().indent(2) << "No temporal reuse: distance unknown\n");
250        return std::nullopt;
251      }
252  
253      const ConstantInt &CI = *SCEVConst->getValue();
254      if (Level != LoopDepth && !CI.isZero()) {
255        LLVM_DEBUG(dbgs().indent(2)
256                   << "No temporal reuse: distance is not zero at depth=" << Level
257                   << "\n");
258        return false;
259      } else if (Level == LoopDepth && CI.getSExtValue() > MaxDistance) {
260        LLVM_DEBUG(
261            dbgs().indent(2)
262            << "No temporal reuse: distance is greater than MaxDistance at depth="
263            << Level << "\n");
264        return false;
265      }
266    }
267  
268    LLVM_DEBUG(dbgs().indent(2) << "Found temporal reuse\n");
269    return true;
270  }
271  
272  CacheCostTy IndexedReference::computeRefCost(const Loop &L,
273                                               unsigned CLS) const {
274    assert(IsValid && "Expecting a valid reference");
275    LLVM_DEBUG({
276      dbgs().indent(2) << "Computing cache cost for:\n";
277      dbgs().indent(4) << *this << "\n";
278    });
279  
280    // If the indexed reference is loop invariant the cost is one.
281    if (isLoopInvariant(L)) {
282      LLVM_DEBUG(dbgs().indent(4) << "Reference is loop invariant: RefCost=1\n");
283      return 1;
284    }
285  
286    const SCEV *TripCount = computeTripCount(L, *Sizes.back(), SE);
287    assert(TripCount && "Expecting valid TripCount");
288    LLVM_DEBUG(dbgs() << "TripCount=" << *TripCount << "\n");
289  
290    const SCEV *RefCost = nullptr;
291    const SCEV *Stride = nullptr;
292    if (isConsecutive(L, Stride, CLS)) {
293      // If the indexed reference is 'consecutive' the cost is
294      // (TripCount*Stride)/CLS.
295      assert(Stride != nullptr &&
296             "Stride should not be null for consecutive access!");
297      Type *WiderType = SE.getWiderType(Stride->getType(), TripCount->getType());
298      const SCEV *CacheLineSize = SE.getConstant(WiderType, CLS);
299      Stride = SE.getNoopOrAnyExtend(Stride, WiderType);
300      TripCount = SE.getNoopOrAnyExtend(TripCount, WiderType);
301      const SCEV *Numerator = SE.getMulExpr(Stride, TripCount);
302      RefCost = SE.getUDivExpr(Numerator, CacheLineSize);
303  
304      LLVM_DEBUG(dbgs().indent(4)
305                 << "Access is consecutive: RefCost=(TripCount*Stride)/CLS="
306                 << *RefCost << "\n");
307    } else {
308      // If the indexed reference is not 'consecutive' the cost is proportional to
309      // the trip count and the depth of the dimension which the subject loop
310      // subscript is accessing. We try to estimate this by multiplying the cost
311      // by the trip counts of loops corresponding to the inner dimensions. For
312      // example, given the indexed reference 'A[i][j][k]', and assuming the
313      // i-loop is in the innermost position, the cost would be equal to the
314      // iterations of the i-loop multiplied by iterations of the j-loop.
315      RefCost = TripCount;
316  
317      int Index = getSubscriptIndex(L);
318      assert(Index >= 0 && "Cound not locate a valid Index");
319  
320      for (unsigned I = Index + 1; I < getNumSubscripts() - 1; ++I) {
321        const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(getSubscript(I));
322        assert(AR && AR->getLoop() && "Expecting valid loop");
323        const SCEV *TripCount =
324            computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
325        Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
326        RefCost = SE.getMulExpr(SE.getNoopOrAnyExtend(RefCost, WiderType),
327                                SE.getNoopOrAnyExtend(TripCount, WiderType));
328      }
329  
330      LLVM_DEBUG(dbgs().indent(4)
331                 << "Access is not consecutive: RefCost=" << *RefCost << "\n");
332    }
333    assert(RefCost && "Expecting a valid RefCost");
334  
335    // Attempt to fold RefCost into a constant.
336    if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
337      return ConstantCost->getValue()->getSExtValue();
338  
339    LLVM_DEBUG(dbgs().indent(4)
340               << "RefCost is not a constant! Setting to RefCost=InvalidCost "
341                  "(invalid value).\n");
342  
343    return CacheCost::InvalidCost;
344  }
345  
346  bool IndexedReference::tryDelinearizeFixedSize(
347      const SCEV *AccessFn, SmallVectorImpl<const SCEV *> &Subscripts) {
348    SmallVector<int, 4> ArraySizes;
349    if (!tryDelinearizeFixedSizeImpl(&SE, &StoreOrLoadInst, AccessFn, Subscripts,
350                                     ArraySizes))
351      return false;
352  
353    // Populate Sizes with scev expressions to be used in calculations later.
354    for (auto Idx : seq<unsigned>(1, Subscripts.size()))
355      Sizes.push_back(
356          SE.getConstant(Subscripts[Idx]->getType(), ArraySizes[Idx - 1]));
357  
358    LLVM_DEBUG({
359      dbgs() << "Delinearized subscripts of fixed-size array\n"
360             << "GEP:" << *getLoadStorePointerOperand(&StoreOrLoadInst)
361             << "\n";
362    });
363    return true;
364  }
365  
366  bool IndexedReference::delinearize(const LoopInfo &LI) {
367    assert(Subscripts.empty() && "Subscripts should be empty");
368    assert(Sizes.empty() && "Sizes should be empty");
369    assert(!IsValid && "Should be called once from the constructor");
370    LLVM_DEBUG(dbgs() << "Delinearizing: " << StoreOrLoadInst << "\n");
371  
372    const SCEV *ElemSize = SE.getElementSize(&StoreOrLoadInst);
373    const BasicBlock *BB = StoreOrLoadInst.getParent();
374  
375    if (Loop *L = LI.getLoopFor(BB)) {
376      const SCEV *AccessFn =
377          SE.getSCEVAtScope(getPointerOperand(&StoreOrLoadInst), L);
378  
379      BasePointer = dyn_cast<SCEVUnknown>(SE.getPointerBase(AccessFn));
380      if (BasePointer == nullptr) {
381        LLVM_DEBUG(
382            dbgs().indent(2)
383            << "ERROR: failed to delinearize, can't identify base pointer\n");
384        return false;
385      }
386  
387      bool IsFixedSize = false;
388      // Try to delinearize fixed-size arrays.
389      if (tryDelinearizeFixedSize(AccessFn, Subscripts)) {
390        IsFixedSize = true;
391        // The last element of Sizes is the element size.
392        Sizes.push_back(ElemSize);
393        LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName()
394                                    << "', AccessFn: " << *AccessFn << "\n");
395      }
396  
397      AccessFn = SE.getMinusSCEV(AccessFn, BasePointer);
398  
399      // Try to delinearize parametric-size arrays.
400      if (!IsFixedSize) {
401        LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName()
402                                    << "', AccessFn: " << *AccessFn << "\n");
403        llvm::delinearize(SE, AccessFn, Subscripts, Sizes,
404                          SE.getElementSize(&StoreOrLoadInst));
405      }
406  
407      if (Subscripts.empty() || Sizes.empty() ||
408          Subscripts.size() != Sizes.size()) {
409        // Attempt to determine whether we have a single dimensional array access.
410        // before giving up.
411        if (!isOneDimensionalArray(*AccessFn, *ElemSize, *L, SE)) {
412          LLVM_DEBUG(dbgs().indent(2)
413                     << "ERROR: failed to delinearize reference\n");
414          Subscripts.clear();
415          Sizes.clear();
416          return false;
417        }
418  
419        // The array may be accessed in reverse, for example:
420        //   for (i = N; i > 0; i--)
421        //     A[i] = 0;
422        // In this case, reconstruct the access function using the absolute value
423        // of the step recurrence.
424        const SCEVAddRecExpr *AccessFnAR = dyn_cast<SCEVAddRecExpr>(AccessFn);
425        const SCEV *StepRec = AccessFnAR ? AccessFnAR->getStepRecurrence(SE) : nullptr;
426  
427        if (StepRec && SE.isKnownNegative(StepRec))
428          AccessFn = SE.getAddRecExpr(AccessFnAR->getStart(),
429                                      SE.getNegativeSCEV(StepRec),
430                                      AccessFnAR->getLoop(),
431                                      AccessFnAR->getNoWrapFlags());
432        const SCEV *Div = SE.getUDivExactExpr(AccessFn, ElemSize);
433        Subscripts.push_back(Div);
434        Sizes.push_back(ElemSize);
435      }
436  
437      return all_of(Subscripts, [&](const SCEV *Subscript) {
438        return isSimpleAddRecurrence(*Subscript, *L);
439      });
440    }
441  
442    return false;
443  }
444  
445  bool IndexedReference::isLoopInvariant(const Loop &L) const {
446    Value *Addr = getPointerOperand(&StoreOrLoadInst);
447    assert(Addr != nullptr && "Expecting either a load or a store instruction");
448    assert(SE.isSCEVable(Addr->getType()) && "Addr should be SCEVable");
449  
450    if (SE.isLoopInvariant(SE.getSCEV(Addr), &L))
451      return true;
452  
453    // The indexed reference is loop invariant if none of the coefficients use
454    // the loop induction variable.
455    bool allCoeffForLoopAreZero = all_of(Subscripts, [&](const SCEV *Subscript) {
456      return isCoeffForLoopZeroOrInvariant(*Subscript, L);
457    });
458  
459    return allCoeffForLoopAreZero;
460  }
461  
462  bool IndexedReference::isConsecutive(const Loop &L, const SCEV *&Stride,
463                                       unsigned CLS) const {
464    // The indexed reference is 'consecutive' if the only coefficient that uses
465    // the loop induction variable is the last one...
466    const SCEV *LastSubscript = Subscripts.back();
467    for (const SCEV *Subscript : Subscripts) {
468      if (Subscript == LastSubscript)
469        continue;
470      if (!isCoeffForLoopZeroOrInvariant(*Subscript, L))
471        return false;
472    }
473  
474    // ...and the access stride is less than the cache line size.
475    const SCEV *Coeff = getLastCoefficient();
476    const SCEV *ElemSize = Sizes.back();
477    Type *WiderType = SE.getWiderType(Coeff->getType(), ElemSize->getType());
478    // FIXME: This assumes that all values are signed integers which may
479    // be incorrect in unusual codes and incorrectly use sext instead of zext.
480    // for (uint32_t i = 0; i < 512; ++i) {
481    //   uint8_t trunc = i;
482    //   A[trunc] = 42;
483    // }
484    // This consecutively iterates twice over A. If `trunc` is sign-extended,
485    // we would conclude that this may iterate backwards over the array.
486    // However, LoopCacheAnalysis is heuristic anyway and transformations must
487    // not result in wrong optimizations if the heuristic was incorrect.
488    Stride = SE.getMulExpr(SE.getNoopOrSignExtend(Coeff, WiderType),
489                           SE.getNoopOrSignExtend(ElemSize, WiderType));
490    const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS);
491  
492    Stride = SE.isKnownNegative(Stride) ? SE.getNegativeSCEV(Stride) : Stride;
493    return SE.isKnownPredicate(ICmpInst::ICMP_ULT, Stride, CacheLineSize);
494  }
495  
496  int IndexedReference::getSubscriptIndex(const Loop &L) const {
497    for (auto Idx : seq<int>(0, getNumSubscripts())) {
498      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(getSubscript(Idx));
499      if (AR && AR->getLoop() == &L) {
500        return Idx;
501      }
502    }
503    return -1;
504  }
505  
506  const SCEV *IndexedReference::getLastCoefficient() const {
507    const SCEV *LastSubscript = getLastSubscript();
508    auto *AR = cast<SCEVAddRecExpr>(LastSubscript);
509    return AR->getStepRecurrence(SE);
510  }
511  
512  bool IndexedReference::isCoeffForLoopZeroOrInvariant(const SCEV &Subscript,
513                                                       const Loop &L) const {
514    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(&Subscript);
515    return (AR != nullptr) ? AR->getLoop() != &L
516                           : SE.isLoopInvariant(&Subscript, &L);
517  }
518  
519  bool IndexedReference::isSimpleAddRecurrence(const SCEV &Subscript,
520                                               const Loop &L) const {
521    if (!isa<SCEVAddRecExpr>(Subscript))
522      return false;
523  
524    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(&Subscript);
525    assert(AR->getLoop() && "AR should have a loop");
526  
527    if (!AR->isAffine())
528      return false;
529  
530    const SCEV *Start = AR->getStart();
531    const SCEV *Step = AR->getStepRecurrence(SE);
532  
533    if (!SE.isLoopInvariant(Start, &L) || !SE.isLoopInvariant(Step, &L))
534      return false;
535  
536    return true;
537  }
538  
539  bool IndexedReference::isAliased(const IndexedReference &Other,
540                                   AAResults &AA) const {
541    const auto &Loc1 = MemoryLocation::get(&StoreOrLoadInst);
542    const auto &Loc2 = MemoryLocation::get(&Other.StoreOrLoadInst);
543    return AA.isMustAlias(Loc1, Loc2);
544  }
545  
546  //===----------------------------------------------------------------------===//
547  // CacheCost implementation
548  //
549  raw_ostream &llvm::operator<<(raw_ostream &OS, const CacheCost &CC) {
550    for (const auto &LC : CC.LoopCosts) {
551      const Loop *L = LC.first;
552      OS << "Loop '" << L->getName() << "' has cost = " << LC.second << "\n";
553    }
554    return OS;
555  }
556  
557  CacheCost::CacheCost(const LoopVectorTy &Loops, const LoopInfo &LI,
558                       ScalarEvolution &SE, TargetTransformInfo &TTI,
559                       AAResults &AA, DependenceInfo &DI,
560                       std::optional<unsigned> TRT)
561      : Loops(Loops), TRT(TRT.value_or(TemporalReuseThreshold)), LI(LI), SE(SE),
562        TTI(TTI), AA(AA), DI(DI) {
563    assert(!Loops.empty() && "Expecting a non-empty loop vector.");
564  
565    for (const Loop *L : Loops) {
566      unsigned TripCount = SE.getSmallConstantTripCount(L);
567      TripCount = (TripCount == 0) ? DefaultTripCount : TripCount;
568      TripCounts.push_back({L, TripCount});
569    }
570  
571    calculateCacheFootprint();
572  }
573  
574  std::unique_ptr<CacheCost>
575  CacheCost::getCacheCost(Loop &Root, LoopStandardAnalysisResults &AR,
576                          DependenceInfo &DI, std::optional<unsigned> TRT) {
577    if (!Root.isOutermost()) {
578      LLVM_DEBUG(dbgs() << "Expecting the outermost loop in a loop nest\n");
579      return nullptr;
580    }
581  
582    LoopVectorTy Loops;
583    append_range(Loops, breadth_first(&Root));
584  
585    if (!getInnerMostLoop(Loops)) {
586      LLVM_DEBUG(dbgs() << "Cannot compute cache cost of loop nest with more "
587                           "than one innermost loop\n");
588      return nullptr;
589    }
590  
591    return std::make_unique<CacheCost>(Loops, AR.LI, AR.SE, AR.TTI, AR.AA, DI, TRT);
592  }
593  
594  void CacheCost::calculateCacheFootprint() {
595    LLVM_DEBUG(dbgs() << "POPULATING REFERENCE GROUPS\n");
596    ReferenceGroupsTy RefGroups;
597    if (!populateReferenceGroups(RefGroups))
598      return;
599  
600    LLVM_DEBUG(dbgs() << "COMPUTING LOOP CACHE COSTS\n");
601    for (const Loop *L : Loops) {
602      assert(llvm::none_of(
603                 LoopCosts,
604                 [L](const LoopCacheCostTy &LCC) { return LCC.first == L; }) &&
605             "Should not add duplicate element");
606      CacheCostTy LoopCost = computeLoopCacheCost(*L, RefGroups);
607      LoopCosts.push_back(std::make_pair(L, LoopCost));
608    }
609  
610    sortLoopCosts();
611    RefGroups.clear();
612  }
613  
614  bool CacheCost::populateReferenceGroups(ReferenceGroupsTy &RefGroups) const {
615    assert(RefGroups.empty() && "Reference groups should be empty");
616  
617    unsigned CLS = TTI.getCacheLineSize();
618    Loop *InnerMostLoop = getInnerMostLoop(Loops);
619    assert(InnerMostLoop != nullptr && "Expecting a valid innermost loop");
620  
621    for (BasicBlock *BB : InnerMostLoop->getBlocks()) {
622      for (Instruction &I : *BB) {
623        if (!isa<StoreInst>(I) && !isa<LoadInst>(I))
624          continue;
625  
626        std::unique_ptr<IndexedReference> R(new IndexedReference(I, LI, SE));
627        if (!R->isValid())
628          continue;
629  
630        bool Added = false;
631        for (ReferenceGroupTy &RefGroup : RefGroups) {
632          const IndexedReference &Representative = *RefGroup.front();
633          LLVM_DEBUG({
634            dbgs() << "References:\n";
635            dbgs().indent(2) << *R << "\n";
636            dbgs().indent(2) << Representative << "\n";
637          });
638  
639  
640         // FIXME: Both positive and negative access functions will be placed
641         // into the same reference group, resulting in a bi-directional array
642         // access such as:
643         //   for (i = N; i > 0; i--)
644         //     A[i] = A[N - i];
645         // having the same cost calculation as a single dimention access pattern
646         //   for (i = 0; i < N; i++)
647         //     A[i] = A[i];
648         // when in actuality, depending on the array size, the first example
649         // should have a cost closer to 2x the second due to the two cache
650         // access per iteration from opposite ends of the array
651          std::optional<bool> HasTemporalReuse =
652              R->hasTemporalReuse(Representative, *TRT, *InnerMostLoop, DI, AA);
653          std::optional<bool> HasSpacialReuse =
654              R->hasSpacialReuse(Representative, CLS, AA);
655  
656          if ((HasTemporalReuse && *HasTemporalReuse) ||
657              (HasSpacialReuse && *HasSpacialReuse)) {
658            RefGroup.push_back(std::move(R));
659            Added = true;
660            break;
661          }
662        }
663  
664        if (!Added) {
665          ReferenceGroupTy RG;
666          RG.push_back(std::move(R));
667          RefGroups.push_back(std::move(RG));
668        }
669      }
670    }
671  
672    if (RefGroups.empty())
673      return false;
674  
675    LLVM_DEBUG({
676      dbgs() << "\nIDENTIFIED REFERENCE GROUPS:\n";
677      int n = 1;
678      for (const ReferenceGroupTy &RG : RefGroups) {
679        dbgs().indent(2) << "RefGroup " << n << ":\n";
680        for (const auto &IR : RG)
681          dbgs().indent(4) << *IR << "\n";
682        n++;
683      }
684      dbgs() << "\n";
685    });
686  
687    return true;
688  }
689  
690  CacheCostTy
691  CacheCost::computeLoopCacheCost(const Loop &L,
692                                  const ReferenceGroupsTy &RefGroups) const {
693    if (!L.isLoopSimplifyForm())
694      return InvalidCost;
695  
696    LLVM_DEBUG(dbgs() << "Considering loop '" << L.getName()
697                      << "' as innermost loop.\n");
698  
699    // Compute the product of the trip counts of each other loop in the nest.
700    CacheCostTy TripCountsProduct = 1;
701    for (const auto &TC : TripCounts) {
702      if (TC.first == &L)
703        continue;
704      TripCountsProduct *= TC.second;
705    }
706  
707    CacheCostTy LoopCost = 0;
708    for (const ReferenceGroupTy &RG : RefGroups) {
709      CacheCostTy RefGroupCost = computeRefGroupCacheCost(RG, L);
710      LoopCost += RefGroupCost * TripCountsProduct;
711    }
712  
713    LLVM_DEBUG(dbgs().indent(2) << "Loop '" << L.getName()
714                                << "' has cost=" << LoopCost << "\n");
715  
716    return LoopCost;
717  }
718  
719  CacheCostTy CacheCost::computeRefGroupCacheCost(const ReferenceGroupTy &RG,
720                                                  const Loop &L) const {
721    assert(!RG.empty() && "Reference group should have at least one member.");
722  
723    const IndexedReference *Representative = RG.front().get();
724    return Representative->computeRefCost(L, TTI.getCacheLineSize());
725  }
726  
727  //===----------------------------------------------------------------------===//
728  // LoopCachePrinterPass implementation
729  //
730  PreservedAnalyses LoopCachePrinterPass::run(Loop &L, LoopAnalysisManager &AM,
731                                              LoopStandardAnalysisResults &AR,
732                                              LPMUpdater &U) {
733    Function *F = L.getHeader()->getParent();
734    DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI);
735  
736    if (auto CC = CacheCost::getCacheCost(L, AR, DI))
737      OS << *CC;
738  
739    return PreservedAnalyses::all();
740  }
741