xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/Delinearization.cpp (revision 27ef5d48c729defb83a8822143dc71ab17f9d68b)
1  //===---- Delinearization.cpp - MultiDimensional Index Delinearization ----===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This implements an analysis pass that tries to delinearize all GEP
10  // instructions in all loops using the SCEV analysis functionality. This pass is
// only used for testing purposes: if your pass needs delinearization, please
// use the on-demand llvm::delinearize() function.
13  //
14  //===----------------------------------------------------------------------===//
15  
16  #include "llvm/Analysis/Delinearization.h"
17  #include "llvm/Analysis/LoopInfo.h"
18  #include "llvm/Analysis/Passes.h"
19  #include "llvm/Analysis/ScalarEvolution.h"
20  #include "llvm/Analysis/ScalarEvolutionDivision.h"
21  #include "llvm/Analysis/ScalarEvolutionExpressions.h"
22  #include "llvm/IR/Constants.h"
23  #include "llvm/IR/DerivedTypes.h"
24  #include "llvm/IR/Function.h"
25  #include "llvm/IR/InstIterator.h"
26  #include "llvm/IR/Instructions.h"
27  #include "llvm/IR/PassManager.h"
28  #include "llvm/Support/Debug.h"
29  #include "llvm/Support/raw_ostream.h"
30  
31  using namespace llvm;
32  
33  #define DL_NAME "delinearize"
34  #define DEBUG_TYPE DL_NAME
35  
36  // Return true when S contains at least an undef value.
37  static inline bool containsUndefs(const SCEV *S) {
38    return SCEVExprContains(S, [](const SCEV *S) {
39      if (const auto *SU = dyn_cast<SCEVUnknown>(S))
40        return isa<UndefValue>(SU->getValue());
41      return false;
42    });
43  }
44  
45  namespace {
46  
47  // Collect all steps of SCEV expressions.
48  struct SCEVCollectStrides {
49    ScalarEvolution &SE;
50    SmallVectorImpl<const SCEV *> &Strides;
51  
52    SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
53        : SE(SE), Strides(S) {}
54  
55    bool follow(const SCEV *S) {
56      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
57        Strides.push_back(AR->getStepRecurrence(SE));
58      return true;
59    }
60  
61    bool isDone() const { return false; }
62  };
63  
64  // Collect all SCEVUnknown and SCEVMulExpr expressions.
65  struct SCEVCollectTerms {
66    SmallVectorImpl<const SCEV *> &Terms;
67  
68    SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
69  
70    bool follow(const SCEV *S) {
71      if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
72          isa<SCEVSignExtendExpr>(S)) {
73        if (!containsUndefs(S))
74          Terms.push_back(S);
75  
76        // Stop recursion: once we collected a term, do not walk its operands.
77        return false;
78      }
79  
80      // Keep looking.
81      return true;
82    }
83  
84    bool isDone() const { return false; }
85  };
86  
87  // Check if a SCEV contains an AddRecExpr.
88  struct SCEVHasAddRec {
89    bool &ContainsAddRec;
90  
91    SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
92      ContainsAddRec = false;
93    }
94  
95    bool follow(const SCEV *S) {
96      if (isa<SCEVAddRecExpr>(S)) {
97        ContainsAddRec = true;
98  
99        // Stop recursion: once we collected a term, do not walk its operands.
100        return false;
101      }
102  
103      // Keep looking.
104      return true;
105    }
106  
107    bool isDone() const { return false; }
108  };
109  
110  // Find factors that are multiplied with an expression that (possibly as a
111  // subexpression) contains an AddRecExpr. In the expression:
112  //
113  //  8 * (100 +  %p * %q * (%a + {0, +, 1}_loop))
114  //
115  // "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
116  // that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
117  // parameters as they form a product with an induction variable.
118  //
119  // This collector expects all array size parameters to be in the same MulExpr.
120  // It might be necessary to later add support for collecting parameters that are
121  // spread over different nested MulExpr.
122  struct SCEVCollectAddRecMultiplies {
123    SmallVectorImpl<const SCEV *> &Terms;
124    ScalarEvolution &SE;
125  
126    SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T,
127                                ScalarEvolution &SE)
128        : Terms(T), SE(SE) {}
129  
130    bool follow(const SCEV *S) {
131      if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
132        bool HasAddRec = false;
133        SmallVector<const SCEV *, 0> Operands;
134        for (const auto *Op : Mul->operands()) {
135          const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op);
136          if (Unknown && !isa<CallInst>(Unknown->getValue())) {
137            Operands.push_back(Op);
138          } else if (Unknown) {
139            HasAddRec = true;
140          } else {
141            bool ContainsAddRec = false;
142            SCEVHasAddRec ContiansAddRec(ContainsAddRec);
143            visitAll(Op, ContiansAddRec);
144            HasAddRec |= ContainsAddRec;
145          }
146        }
147        if (Operands.size() == 0)
148          return true;
149  
150        if (!HasAddRec)
151          return false;
152  
153        Terms.push_back(SE.getMulExpr(Operands));
154        // Stop recursion: once we collected a term, do not walk its operands.
155        return false;
156      }
157  
158      // Keep looking.
159      return true;
160    }
161  
162    bool isDone() const { return false; }
163  };
164  
165  } // end anonymous namespace
166  
167  /// Find parametric terms in this SCEVAddRecExpr. We first for parameters in
168  /// two places:
169  ///   1) The strides of AddRec expressions.
170  ///   2) Unknowns that are multiplied with AddRec expressions.
171  void llvm::collectParametricTerms(ScalarEvolution &SE, const SCEV *Expr,
172                                    SmallVectorImpl<const SCEV *> &Terms) {
173    SmallVector<const SCEV *, 4> Strides;
174    SCEVCollectStrides StrideCollector(SE, Strides);
175    visitAll(Expr, StrideCollector);
176  
177    LLVM_DEBUG({
178      dbgs() << "Strides:\n";
179      for (const SCEV *S : Strides)
180        dbgs() << *S << "\n";
181    });
182  
183    for (const SCEV *S : Strides) {
184      SCEVCollectTerms TermCollector(Terms);
185      visitAll(S, TermCollector);
186    }
187  
188    LLVM_DEBUG({
189      dbgs() << "Terms:\n";
190      for (const SCEV *T : Terms)
191        dbgs() << *T << "\n";
192    });
193  
194    SCEVCollectAddRecMultiplies MulCollector(Terms, SE);
195    visitAll(Expr, MulCollector);
196  }
197  
198  static bool findArrayDimensionsRec(ScalarEvolution &SE,
199                                     SmallVectorImpl<const SCEV *> &Terms,
200                                     SmallVectorImpl<const SCEV *> &Sizes) {
201    int Last = Terms.size() - 1;
202    const SCEV *Step = Terms[Last];
203  
204    // End of recursion.
205    if (Last == 0) {
206      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
207        SmallVector<const SCEV *, 2> Qs;
208        for (const SCEV *Op : M->operands())
209          if (!isa<SCEVConstant>(Op))
210            Qs.push_back(Op);
211  
212        Step = SE.getMulExpr(Qs);
213      }
214  
215      Sizes.push_back(Step);
216      return true;
217    }
218  
219    for (const SCEV *&Term : Terms) {
220      // Normalize the terms before the next call to findArrayDimensionsRec.
221      const SCEV *Q, *R;
222      SCEVDivision::divide(SE, Term, Step, &Q, &R);
223  
224      // Bail out when GCD does not evenly divide one of the terms.
225      if (!R->isZero())
226        return false;
227  
228      Term = Q;
229    }
230  
231    // Remove all SCEVConstants.
232    erase_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); });
233  
234    if (Terms.size() > 0)
235      if (!findArrayDimensionsRec(SE, Terms, Sizes))
236        return false;
237  
238    Sizes.push_back(Step);
239    return true;
240  }
241  
242  // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
243  static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
244    for (const SCEV *T : Terms)
245      if (SCEVExprContains(T, [](const SCEV *S) { return isa<SCEVUnknown>(S); }))
246        return true;
247  
248    return false;
249  }
250  
251  // Return the number of product terms in S.
252  static inline int numberOfTerms(const SCEV *S) {
253    if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
254      return Expr->getNumOperands();
255    return 1;
256  }
257  
258  static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
259    if (isa<SCEVConstant>(T))
260      return nullptr;
261  
262    if (isa<SCEVUnknown>(T))
263      return T;
264  
265    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
266      SmallVector<const SCEV *, 2> Factors;
267      for (const SCEV *Op : M->operands())
268        if (!isa<SCEVConstant>(Op))
269          Factors.push_back(Op);
270  
271      return SE.getMulExpr(Factors);
272    }
273  
274    return T;
275  }
276  
277  void llvm::findArrayDimensions(ScalarEvolution &SE,
278                                 SmallVectorImpl<const SCEV *> &Terms,
279                                 SmallVectorImpl<const SCEV *> &Sizes,
280                                 const SCEV *ElementSize) {
281    if (Terms.size() < 1 || !ElementSize)
282      return;
283  
284    // Early return when Terms do not contain parameters: we do not delinearize
285    // non parametric SCEVs.
286    if (!containsParameters(Terms))
287      return;
288  
289    LLVM_DEBUG({
290      dbgs() << "Terms:\n";
291      for (const SCEV *T : Terms)
292        dbgs() << *T << "\n";
293    });
294  
295    // Remove duplicates.
296    array_pod_sort(Terms.begin(), Terms.end());
297    Terms.erase(llvm::unique(Terms), Terms.end());
298  
299    // Put larger terms first.
300    llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) {
301      return numberOfTerms(LHS) > numberOfTerms(RHS);
302    });
303  
304    // Try to divide all terms by the element size. If term is not divisible by
305    // element size, proceed with the original term.
306    for (const SCEV *&Term : Terms) {
307      const SCEV *Q, *R;
308      SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
309      if (!Q->isZero())
310        Term = Q;
311    }
312  
313    SmallVector<const SCEV *, 4> NewTerms;
314  
315    // Remove constant factors.
316    for (const SCEV *T : Terms)
317      if (const SCEV *NewT = removeConstantFactors(SE, T))
318        NewTerms.push_back(NewT);
319  
320    LLVM_DEBUG({
321      dbgs() << "Terms after sorting:\n";
322      for (const SCEV *T : NewTerms)
323        dbgs() << *T << "\n";
324    });
325  
326    if (NewTerms.empty() || !findArrayDimensionsRec(SE, NewTerms, Sizes)) {
327      Sizes.clear();
328      return;
329    }
330  
331    // The last element to be pushed into Sizes is the size of an element.
332    Sizes.push_back(ElementSize);
333  
334    LLVM_DEBUG({
335      dbgs() << "Sizes:\n";
336      for (const SCEV *S : Sizes)
337        dbgs() << *S << "\n";
338    });
339  }
340  
341  void llvm::computeAccessFunctions(ScalarEvolution &SE, const SCEV *Expr,
342                                    SmallVectorImpl<const SCEV *> &Subscripts,
343                                    SmallVectorImpl<const SCEV *> &Sizes) {
344    // Early exit in case this SCEV is not an affine multivariate function.
345    if (Sizes.empty())
346      return;
347  
348    if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr))
349      if (!AR->isAffine())
350        return;
351  
352    const SCEV *Res = Expr;
353    int Last = Sizes.size() - 1;
354    for (int i = Last; i >= 0; i--) {
355      const SCEV *Q, *R;
356      SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R);
357  
358      LLVM_DEBUG({
359        dbgs() << "Res: " << *Res << "\n";
360        dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
361        dbgs() << "Res divided by Sizes[i]:\n";
362        dbgs() << "Quotient: " << *Q << "\n";
363        dbgs() << "Remainder: " << *R << "\n";
364      });
365  
366      Res = Q;
367  
368      // Do not record the last subscript corresponding to the size of elements in
369      // the array.
370      if (i == Last) {
371  
372        // Bail out if the byte offset is non-zero.
373        if (!R->isZero()) {
374          Subscripts.clear();
375          Sizes.clear();
376          return;
377        }
378  
379        continue;
380      }
381  
382      // Record the access function for the current subscript.
383      Subscripts.push_back(R);
384    }
385  
386    // Also push in last position the remainder of the last division: it will be
387    // the access function of the innermost dimension.
388    Subscripts.push_back(Res);
389  
390    std::reverse(Subscripts.begin(), Subscripts.end());
391  
392    LLVM_DEBUG({
393      dbgs() << "Subscripts:\n";
394      for (const SCEV *S : Subscripts)
395        dbgs() << *S << "\n";
396    });
397  }
398  
399  /// Splits the SCEV into two vectors of SCEVs representing the subscripts and
400  /// sizes of an array access. Returns the remainder of the delinearization that
401  /// is the offset start of the array.  The SCEV->delinearize algorithm computes
402  /// the multiples of SCEV coefficients: that is a pattern matching of sub
403  /// expressions in the stride and base of a SCEV corresponding to the
404  /// computation of a GCD (greatest common divisor) of base and stride.  When
405  /// SCEV->delinearize fails, it returns the SCEV unchanged.
406  ///
407  /// For example: when analyzing the memory access A[i][j][k] in this loop nest
408  ///
409  ///  void foo(long n, long m, long o, double A[n][m][o]) {
410  ///
411  ///    for (long i = 0; i < n; i++)
412  ///      for (long j = 0; j < m; j++)
413  ///        for (long k = 0; k < o; k++)
414  ///          A[i][j][k] = 1.0;
415  ///  }
416  ///
417  /// the delinearization input is the following AddRec SCEV:
418  ///
419  ///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
420  ///
421  /// From this SCEV, we are able to say that the base offset of the access is %A
422  /// because it appears as an offset that does not divide any of the strides in
423  /// the loops:
424  ///
425  ///  CHECK: Base offset: %A
426  ///
427  /// and then SCEV->delinearize determines the size of some of the dimensions of
428  /// the array as these are the multiples by which the strides are happening:
429  ///
430  ///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double)
431  ///  bytes.
432  ///
433  /// Note that the outermost dimension remains of UnknownSize because there are
434  /// no strides that would help identifying the size of the last dimension: when
435  /// the array has been statically allocated, one could compute the size of that
436  /// dimension by dividing the overall size of the array by the size of the known
437  /// dimensions: %m * %o * 8.
438  ///
439  /// Finally delinearize provides the access functions for the array reference
440  /// that does correspond to A[i][j][k] of the above C testcase:
441  ///
442  ///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
443  ///
444  /// The testcases are checking the output of a function pass:
445  /// DelinearizationPass that walks through all loads and stores of a function
446  /// asking for the SCEV of the memory access with respect to all enclosing
447  /// loops, calling SCEV->delinearize on that and printing the results.
448  void llvm::delinearize(ScalarEvolution &SE, const SCEV *Expr,
449                         SmallVectorImpl<const SCEV *> &Subscripts,
450                         SmallVectorImpl<const SCEV *> &Sizes,
451                         const SCEV *ElementSize) {
452    // First step: collect parametric terms.
453    SmallVector<const SCEV *, 4> Terms;
454    collectParametricTerms(SE, Expr, Terms);
455  
456    if (Terms.empty())
457      return;
458  
459    // Second step: find subscript sizes.
460    findArrayDimensions(SE, Terms, Sizes, ElementSize);
461  
462    if (Sizes.empty())
463      return;
464  
465    // Third step: compute the access functions for each subscript.
466    computeAccessFunctions(SE, Expr, Subscripts, Sizes);
467  
468    if (Subscripts.empty())
469      return;
470  
471    LLVM_DEBUG({
472      dbgs() << "succeeded to delinearize " << *Expr << "\n";
473      dbgs() << "ArrayDecl[UnknownSize]";
474      for (const SCEV *S : Sizes)
475        dbgs() << "[" << *S << "]";
476  
477      dbgs() << "\nArrayRef";
478      for (const SCEV *S : Subscripts)
479        dbgs() << "[" << *S << "]";
480      dbgs() << "\n";
481    });
482  }
483  
484  bool llvm::getIndexExpressionsFromGEP(ScalarEvolution &SE,
485                                        const GetElementPtrInst *GEP,
486                                        SmallVectorImpl<const SCEV *> &Subscripts,
487                                        SmallVectorImpl<int> &Sizes) {
488    assert(Subscripts.empty() && Sizes.empty() &&
489           "Expected output lists to be empty on entry to this function.");
490    assert(GEP && "getIndexExpressionsFromGEP called with a null GEP");
491    Type *Ty = nullptr;
492    bool DroppedFirstDim = false;
493    for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
494      const SCEV *Expr = SE.getSCEV(GEP->getOperand(i));
495      if (i == 1) {
496        Ty = GEP->getSourceElementType();
497        if (auto *Const = dyn_cast<SCEVConstant>(Expr))
498          if (Const->getValue()->isZero()) {
499            DroppedFirstDim = true;
500            continue;
501          }
502        Subscripts.push_back(Expr);
503        continue;
504      }
505  
506      auto *ArrayTy = dyn_cast<ArrayType>(Ty);
507      if (!ArrayTy) {
508        Subscripts.clear();
509        Sizes.clear();
510        return false;
511      }
512  
513      Subscripts.push_back(Expr);
514      if (!(DroppedFirstDim && i == 2))
515        Sizes.push_back(ArrayTy->getNumElements());
516  
517      Ty = ArrayTy->getElementType();
518    }
519    return !Subscripts.empty();
520  }
521  
522  bool llvm::tryDelinearizeFixedSizeImpl(
523      ScalarEvolution *SE, Instruction *Inst, const SCEV *AccessFn,
524      SmallVectorImpl<const SCEV *> &Subscripts, SmallVectorImpl<int> &Sizes) {
525    Value *SrcPtr = getLoadStorePointerOperand(Inst);
526  
527    // Check the simple case where the array dimensions are fixed size.
528    auto *SrcGEP = dyn_cast<GetElementPtrInst>(SrcPtr);
529    if (!SrcGEP)
530      return false;
531  
532    getIndexExpressionsFromGEP(*SE, SrcGEP, Subscripts, Sizes);
533  
534    // Check that the two size arrays are non-empty and equal in length and
535    // value.
536    // TODO: it would be better to let the caller to clear Subscripts, similar
537    // to how we handle Sizes.
538    if (Sizes.empty() || Subscripts.size() <= 1) {
539      Subscripts.clear();
540      return false;
541    }
542  
543    // Check that for identical base pointers we do not miss index offsets
544    // that have been added before this GEP is applied.
545    Value *SrcBasePtr = SrcGEP->getOperand(0)->stripPointerCasts();
546    const SCEVUnknown *SrcBase =
547        dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFn));
548    if (!SrcBase || SrcBasePtr != SrcBase->getValue()) {
549      Subscripts.clear();
550      return false;
551    }
552  
553    assert(Subscripts.size() == Sizes.size() + 1 &&
554           "Expected equal number of entries in the list of size and "
555           "subscript.");
556  
557    return true;
558  }
559  
560  namespace {
561  
562  void printDelinearization(raw_ostream &O, Function *F, LoopInfo *LI,
563                            ScalarEvolution *SE) {
564    O << "Delinearization on function " << F->getName() << ":\n";
565    for (Instruction &Inst : instructions(F)) {
566      // Only analyze loads and stores.
567      if (!isa<StoreInst>(&Inst) && !isa<LoadInst>(&Inst) &&
568          !isa<GetElementPtrInst>(&Inst))
569        continue;
570  
571      const BasicBlock *BB = Inst.getParent();
572      // Delinearize the memory access as analyzed in all the surrounding loops.
573      // Do not analyze memory accesses outside loops.
574      for (Loop *L = LI->getLoopFor(BB); L != nullptr; L = L->getParentLoop()) {
575        const SCEV *AccessFn = SE->getSCEVAtScope(getPointerOperand(&Inst), L);
576  
577        const SCEVUnknown *BasePointer =
578            dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFn));
579        // Do not delinearize if we cannot find the base pointer.
580        if (!BasePointer)
581          break;
582        AccessFn = SE->getMinusSCEV(AccessFn, BasePointer);
583  
584        O << "\n";
585        O << "Inst:" << Inst << "\n";
586        O << "In Loop with Header: " << L->getHeader()->getName() << "\n";
587        O << "AccessFunction: " << *AccessFn << "\n";
588  
589        SmallVector<const SCEV *, 3> Subscripts, Sizes;
590        delinearize(*SE, AccessFn, Subscripts, Sizes, SE->getElementSize(&Inst));
591        if (Subscripts.size() == 0 || Sizes.size() == 0 ||
592            Subscripts.size() != Sizes.size()) {
593          O << "failed to delinearize\n";
594          continue;
595        }
596  
597        O << "Base offset: " << *BasePointer << "\n";
598        O << "ArrayDecl[UnknownSize]";
599        int Size = Subscripts.size();
600        for (int i = 0; i < Size - 1; i++)
601          O << "[" << *Sizes[i] << "]";
602        O << " with elements of " << *Sizes[Size - 1] << " bytes.\n";
603  
604        O << "ArrayRef";
605        for (int i = 0; i < Size; i++)
606          O << "[" << *Subscripts[i] << "]";
607        O << "\n";
608      }
609    }
610  }
611  
612  } // end anonymous namespace
613  
614  DelinearizationPrinterPass::DelinearizationPrinterPass(raw_ostream &OS)
615      : OS(OS) {}
616  PreservedAnalyses DelinearizationPrinterPass::run(Function &F,
617                                                    FunctionAnalysisManager &AM) {
618    printDelinearization(OS, &F, &AM.getResult<LoopAnalysis>(F),
619                         &AM.getResult<ScalarEvolutionAnalysis>(F));
620    return PreservedAnalyses::all();
621  }
622