xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp (revision 62987288060ff68c817b7056815aa9fb8ba8ecd7)
1 //===- LoopAccessAnalysis.cpp - Loop Access Analysis Implementation --------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The implementation of the loop memory dependence analysis that was
10 // originally developed for the loop vectorizer.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/LoopAccessAnalysis.h"
15 #include "llvm/ADT/APInt.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/EquivalenceClasses.h"
18 #include "llvm/ADT/PointerIntPair.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/SmallSet.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/AliasSetTracker.h"
26 #include "llvm/Analysis/LoopAnalysisManager.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/LoopIterator.h"
29 #include "llvm/Analysis/MemoryLocation.h"
30 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
31 #include "llvm/Analysis/ScalarEvolution.h"
32 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
33 #include "llvm/Analysis/TargetLibraryInfo.h"
34 #include "llvm/Analysis/TargetTransformInfo.h"
35 #include "llvm/Analysis/ValueTracking.h"
36 #include "llvm/Analysis/VectorUtils.h"
37 #include "llvm/IR/BasicBlock.h"
38 #include "llvm/IR/Constants.h"
39 #include "llvm/IR/DataLayout.h"
40 #include "llvm/IR/DebugLoc.h"
41 #include "llvm/IR/DerivedTypes.h"
42 #include "llvm/IR/DiagnosticInfo.h"
43 #include "llvm/IR/Dominators.h"
44 #include "llvm/IR/Function.h"
45 #include "llvm/IR/GetElementPtrTypeIterator.h"
46 #include "llvm/IR/InstrTypes.h"
47 #include "llvm/IR/Instruction.h"
48 #include "llvm/IR/Instructions.h"
49 #include "llvm/IR/Operator.h"
50 #include "llvm/IR/PassManager.h"
51 #include "llvm/IR/PatternMatch.h"
52 #include "llvm/IR/Type.h"
53 #include "llvm/IR/Value.h"
54 #include "llvm/IR/ValueHandle.h"
55 #include "llvm/Support/Casting.h"
56 #include "llvm/Support/CommandLine.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/ErrorHandling.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include <algorithm>
61 #include <cassert>
62 #include <cstdint>
63 #include <iterator>
64 #include <utility>
65 #include <variant>
66 #include <vector>
67 
68 using namespace llvm;
69 using namespace llvm::PatternMatch;
70 
71 #define DEBUG_TYPE "loop-accesses"
72 
73 static cl::opt<unsigned, true>
74 VectorizationFactor("force-vector-width", cl::Hidden,
75                     cl::desc("Sets the SIMD width. Zero is autoselect."),
76                     cl::location(VectorizerParams::VectorizationFactor));
77 unsigned VectorizerParams::VectorizationFactor;
78 
79 static cl::opt<unsigned, true>
80 VectorizationInterleave("force-vector-interleave", cl::Hidden,
81                         cl::desc("Sets the vectorization interleave count. "
82                                  "Zero is autoselect."),
83                         cl::location(
84                             VectorizerParams::VectorizationInterleave));
85 unsigned VectorizerParams::VectorizationInterleave;
86 
87 static cl::opt<unsigned, true> RuntimeMemoryCheckThreshold(
88     "runtime-memory-check-threshold", cl::Hidden,
89     cl::desc("When performing memory disambiguation checks at runtime do not "
90              "generate more than this number of comparisons (default = 8)."),
91     cl::location(VectorizerParams::RuntimeMemoryCheckThreshold), cl::init(8));
92 unsigned VectorizerParams::RuntimeMemoryCheckThreshold;
93 
94 /// The maximum number of iterations used to merge memory checks.
95 static cl::opt<unsigned> MemoryCheckMergeThreshold(
96     "memory-check-merge-threshold", cl::Hidden,
97     cl::desc("Maximum number of comparisons done when trying to merge "
98              "runtime memory checks. (default = 100)"),
99     cl::init(100));
100 
101 /// Maximum SIMD width.
102 const unsigned VectorizerParams::MaxVectorWidth = 64;
103 
104 /// We collect dependences up to this threshold.
105 static cl::opt<unsigned>
106     MaxDependences("max-dependences", cl::Hidden,
107                    cl::desc("Maximum number of dependences collected by "
108                             "loop-access analysis (default = 100)"),
109                    cl::init(100));
110 
111 /// This enables versioning on the strides of symbolically striding memory
112 /// accesses in code like the following.
113 ///   for (i = 0; i < N; ++i)
114 ///     A[i * Stride1] += B[i * Stride2] ...
115 ///
116 /// Will be roughly translated to
117 ///    if (Stride1 == 1 && Stride2 == 1) {
118 ///      for (i = 0; i < N; i+=4)
119 ///       A[i:i+3] += ...
120 ///    } else
121 ///      ...
122 static cl::opt<bool> EnableMemAccessVersioning(
123     "enable-mem-access-versioning", cl::init(true), cl::Hidden,
124     cl::desc("Enable symbolic stride memory access versioning"));
125 
126 /// Enable store-to-load forwarding conflict detection. This option can
127 /// be disabled for correctness testing.
128 static cl::opt<bool> EnableForwardingConflictDetection(
129     "store-to-load-forwarding-conflict-detection", cl::Hidden,
130     cl::desc("Enable conflict detection in loop-access analysis"),
131     cl::init(true));
132 
133 static cl::opt<unsigned> MaxForkedSCEVDepth(
134     "max-forked-scev-depth", cl::Hidden,
135     cl::desc("Maximum recursion depth when finding forked SCEVs (default = 5)"),
136     cl::init(5));
137 
138 static cl::opt<bool> SpeculateUnitStride(
139     "laa-speculate-unit-stride", cl::Hidden,
140     cl::desc("Speculate that non-constant strides are unit in LAA"),
141     cl::init(true));
142 
143 static cl::opt<bool, true> HoistRuntimeChecks(
144     "hoist-runtime-checks", cl::Hidden,
145     cl::desc(
146         "Hoist inner loop runtime memory checks to outer loop if possible"),
147     cl::location(VectorizerParams::HoistRuntimeChecks), cl::init(true));
148 bool VectorizerParams::HoistRuntimeChecks;
149 
150 bool VectorizerParams::isInterleaveForced() {
151   return ::VectorizationInterleave.getNumOccurrences() > 0;
152 }
153 
154 const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
155                                             const DenseMap<Value *, const SCEV *> &PtrToStride,
156                                             Value *Ptr) {
157   const SCEV *OrigSCEV = PSE.getSCEV(Ptr);
158 
159   // If there is an entry in the map, return the SCEV of the pointer with the
160   // symbolic stride replaced by one.
161   DenseMap<Value *, const SCEV *>::const_iterator SI = PtrToStride.find(Ptr);
162   if (SI == PtrToStride.end())
163     // For a non-symbolic stride, just return the original expression.
164     return OrigSCEV;
165 
166   const SCEV *StrideSCEV = SI->second;
167   // Note: This assert is both overly strong and overly weak.  The actual
168   // invariant here is that StrideSCEV should be loop invariant.  The only
169   // such invariant strides we happen to speculate right now are unknowns
170   // and thus this is a reasonable proxy of the actual invariant.
171   assert(isa<SCEVUnknown>(StrideSCEV) && "shouldn't be in map");
172 
173   ScalarEvolution *SE = PSE.getSE();
174   const auto *CT = SE->getOne(StrideSCEV->getType());
175   PSE.addPredicate(*SE->getEqualPredicate(StrideSCEV, CT));
176   auto *Expr = PSE.getSCEV(Ptr);
177 
178   LLVM_DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV
179                     << " by: " << *Expr << "\n");
180   return Expr;
181 }
182 
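// Worked example for replaceSymbolicStrideSCEV (illustrative; exact SCEV
// rendering may differ): if the SCEV for Ptr is {%A,+,(4 * %Stride)}<%loop>
// and PtrToStride maps Ptr to the symbolic stride %Stride, the function adds
// the predicate "%Stride == 1" to PSE and returns the simplified recurrence
// {%A,+,4}<%loop>, which the rest of the analysis can treat as a unit-stride
// access.
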
183 RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
184     unsigned Index, RuntimePointerChecking &RtCheck)
185     : High(RtCheck.Pointers[Index].End), Low(RtCheck.Pointers[Index].Start),
186       AddressSpace(RtCheck.Pointers[Index]
187                        .PointerValue->getType()
188                        ->getPointerAddressSpace()),
189       NeedsFreeze(RtCheck.Pointers[Index].NeedsFreeze) {
190   Members.push_back(Index);
191 }
192 
193 /// Calculate Start and End points of memory access.
194 /// Let's assume A is the first access and B is a memory access on N-th loop
195 /// iteration. Then B is calculated as:
196 ///   B = A + Step*N .
197 /// Step value may be positive or negative.
198 /// N is a calculated back-edge taken count:
199 ///     N = (TripCount > 0) ? RoundDown(TripCount - 1, VF) : 0
200 /// Start and End points are calculated in the following way:
201 /// Start = UMIN(A, B) ; End = UMAX(A, B) + SizeOfElt,
202 /// where SizeOfElt is the size of single memory access in bytes.
203 ///
204 /// There is no conflict when the intervals are disjoint:
205 /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
206 static std::pair<const SCEV *, const SCEV *> getStartAndEndForAccess(
207     const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy,
208     PredicatedScalarEvolution &PSE,
209     DenseMap<std::pair<const SCEV *, Type *>,
210              std::pair<const SCEV *, const SCEV *>> &PointerBounds) {
211   ScalarEvolution *SE = PSE.getSE();
212 
213   auto [Iter, Ins] = PointerBounds.insert(
214       {{PtrExpr, AccessTy},
215        {SE->getCouldNotCompute(), SE->getCouldNotCompute()}});
216   if (!Ins)
217     return Iter->second;
218 
219   const SCEV *ScStart;
220   const SCEV *ScEnd;
221 
222   if (SE->isLoopInvariant(PtrExpr, Lp)) {
223     ScStart = ScEnd = PtrExpr;
224   } else if (auto *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr)) {
225     const SCEV *Ex = PSE.getSymbolicMaxBackedgeTakenCount();
226 
227     ScStart = AR->getStart();
228     ScEnd = AR->evaluateAtIteration(Ex, *SE);
229     const SCEV *Step = AR->getStepRecurrence(*SE);
230 
231     // For expressions with negative step, the upper bound is ScStart and the
232     // lower bound is ScEnd.
233     if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) {
234       if (CStep->getValue()->isNegative())
235         std::swap(ScStart, ScEnd);
236     } else {
237       // Fallback case: the step is not constant, but we can still
238       // get the upper and lower bounds of the interval by using min/max
239       // expressions.
240       ScStart = SE->getUMinExpr(ScStart, ScEnd);
241       ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd);
242     }
243   } else
244     return {SE->getCouldNotCompute(), SE->getCouldNotCompute()};
245 
246   assert(SE->isLoopInvariant(ScStart, Lp) && "ScStart needs to be invariant");
247   assert(SE->isLoopInvariant(ScEnd, Lp) && "ScEnd needs to be invariant");
248 
249   // Add the size of the pointed element to ScEnd.
250   auto &DL = Lp->getHeader()->getDataLayout();
251   Type *IdxTy = DL.getIndexType(PtrExpr->getType());
252   const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
253   ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
254 
255   Iter->second = {ScStart, ScEnd};
256   return Iter->second;
257 }
258 
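// A standalone sketch (not used by LAA itself) of the Start/End formula
// documented above getStartAndEndForAccess, restated over plain integers in
// place of SCEVs. All names below are hypothetical; the real computation
// works on SCEV expressions and falls back to umin/umax for unknown steps.
namespace laa_examples {
struct AccessInterval {
  int64_t Start;
  int64_t End; // one past the last accessed byte
};

// Bounds of the accesses A, A + Step, ..., A + Step * BTC, each touching
// SizeOfElt bytes, where BTC is the backedge-taken count and Step may be
// negative.
inline AccessInterval accessBounds(int64_t A, int64_t Step, int64_t BTC,
                                   int64_t SizeOfElt) {
  int64_t B = A + Step * BTC;
  return {std::min(A, B), std::max(A, B) + SizeOfElt};
}

// Two accesses cannot conflict when their intervals are disjoint:
// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End).
inline bool noConflict(const AccessInterval &P1, const AccessInterval &P2) {
  return P2.Start >= P1.End || P1.Start >= P2.End;
}
} // namespace laa_examples
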
259 /// Calculate Start and End points of memory access using
260 /// getStartAndEndForAccess.
261 void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
262                                     Type *AccessTy, bool WritePtr,
263                                     unsigned DepSetId, unsigned ASId,
264                                     PredicatedScalarEvolution &PSE,
265                                     bool NeedsFreeze) {
266   const auto &[ScStart, ScEnd] = getStartAndEndForAccess(
267       Lp, PtrExpr, AccessTy, PSE, DC.getPointerBounds());
268   assert(!isa<SCEVCouldNotCompute>(ScStart) &&
269          !isa<SCEVCouldNotCompute>(ScEnd) &&
270          "must be able to compute both start and end expressions");
271   Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr,
272                         NeedsFreeze);
273 }
274 
275 bool RuntimePointerChecking::tryToCreateDiffCheck(
276     const RuntimeCheckingPtrGroup &CGI, const RuntimeCheckingPtrGroup &CGJ) {
277   // If either group contains multiple different pointers, bail out.
278   // TODO: Support multiple pointers by using the minimum or maximum pointer,
279   // depending on src & sink.
280   if (CGI.Members.size() != 1 || CGJ.Members.size() != 1)
281     return false;
282 
283   PointerInfo *Src = &Pointers[CGI.Members[0]];
284   PointerInfo *Sink = &Pointers[CGJ.Members[0]];
285 
286   // If either pointer is read and written, multiple checks may be needed. Bail
287   // out.
288   if (!DC.getOrderForAccess(Src->PointerValue, !Src->IsWritePtr).empty() ||
289       !DC.getOrderForAccess(Sink->PointerValue, !Sink->IsWritePtr).empty())
290     return false;
291 
292   ArrayRef<unsigned> AccSrc =
293       DC.getOrderForAccess(Src->PointerValue, Src->IsWritePtr);
294   ArrayRef<unsigned> AccSink =
295       DC.getOrderForAccess(Sink->PointerValue, Sink->IsWritePtr);
296   // If either pointer is accessed multiple times, there may not be a clear
297   // src/sink relation. Bail out for now.
298   if (AccSrc.size() != 1 || AccSink.size() != 1)
299     return false;
300 
301   // If the sink is accessed before src, swap src/sink.
302   if (AccSink[0] < AccSrc[0])
303     std::swap(Src, Sink);
304 
305   auto *SrcAR = dyn_cast<SCEVAddRecExpr>(Src->Expr);
306   auto *SinkAR = dyn_cast<SCEVAddRecExpr>(Sink->Expr);
307   if (!SrcAR || !SinkAR || SrcAR->getLoop() != DC.getInnermostLoop() ||
308       SinkAR->getLoop() != DC.getInnermostLoop())
309     return false;
310 
311   SmallVector<Instruction *, 4> SrcInsts =
312       DC.getInstructionsForAccess(Src->PointerValue, Src->IsWritePtr);
313   SmallVector<Instruction *, 4> SinkInsts =
314       DC.getInstructionsForAccess(Sink->PointerValue, Sink->IsWritePtr);
315   Type *SrcTy = getLoadStoreType(SrcInsts[0]);
316   Type *DstTy = getLoadStoreType(SinkInsts[0]);
317   if (isa<ScalableVectorType>(SrcTy) || isa<ScalableVectorType>(DstTy))
318     return false;
319 
320   const DataLayout &DL =
321       SinkAR->getLoop()->getHeader()->getDataLayout();
322   unsigned AllocSize =
323       std::max(DL.getTypeAllocSize(SrcTy), DL.getTypeAllocSize(DstTy));
324 
325   // Only constant steps matching the AllocSize are supported at the
326   // moment. This simplifies the difference computation. Can be extended in the
327   // future.
328   auto *Step = dyn_cast<SCEVConstant>(SinkAR->getStepRecurrence(*SE));
329   if (!Step || Step != SrcAR->getStepRecurrence(*SE) ||
330       Step->getAPInt().abs() != AllocSize)
331     return false;
332 
333   IntegerType *IntTy =
334       IntegerType::get(Src->PointerValue->getContext(),
335                        DL.getPointerSizeInBits(CGI.AddressSpace));
336 
337   // When counting down, the dependence distance needs to be swapped.
338   if (Step->getValue()->isNegative())
339     std::swap(SinkAR, SrcAR);
340 
341   const SCEV *SinkStartInt = SE->getPtrToIntExpr(SinkAR->getStart(), IntTy);
342   const SCEV *SrcStartInt = SE->getPtrToIntExpr(SrcAR->getStart(), IntTy);
343   if (isa<SCEVCouldNotCompute>(SinkStartInt) ||
344       isa<SCEVCouldNotCompute>(SrcStartInt))
345     return false;
346 
347   const Loop *InnerLoop = SrcAR->getLoop();
348   // If the start values for both Src and Sink also vary according to an outer
349   // loop, then it's probably better to avoid creating diff checks because
350   // they may not be hoisted. We should instead let llvm::addRuntimeChecks
351   // do the expanded full range overlap checks, which can be hoisted.
352   if (HoistRuntimeChecks && InnerLoop->getParentLoop() &&
353       isa<SCEVAddRecExpr>(SinkStartInt) && isa<SCEVAddRecExpr>(SrcStartInt)) {
354     auto *SrcStartAR = cast<SCEVAddRecExpr>(SrcStartInt);
355     auto *SinkStartAR = cast<SCEVAddRecExpr>(SinkStartInt);
356     const Loop *StartARLoop = SrcStartAR->getLoop();
357     if (StartARLoop == SinkStartAR->getLoop() &&
358         StartARLoop == InnerLoop->getParentLoop() &&
359         // If the diff check would already be loop invariant (due to the
360         // recurrences being the same), then we prefer to keep the diff checks
361         // because they are cheaper.
362         SrcStartAR->getStepRecurrence(*SE) !=
363             SinkStartAR->getStepRecurrence(*SE)) {
364       LLVM_DEBUG(dbgs() << "LAA: Not creating diff runtime check, since these "
365                            "cannot be hoisted out of the outer loop\n");
366       return false;
367     }
368   }
369 
370   LLVM_DEBUG(dbgs() << "LAA: Creating diff runtime check for:\n"
371                     << "SrcStart: " << *SrcStartInt << '\n'
372                     << "SinkStartInt: " << *SinkStartInt << '\n');
373   DiffChecks.emplace_back(SrcStartInt, SinkStartInt, AllocSize,
374                           Src->NeedsFreeze || Sink->NeedsFreeze);
375   return true;
376 }
377 
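// Worked example for tryToCreateDiffCheck (illustrative): for a loop that
// stores to A[i] and loads from A[i + 4] with 4-byte elements, Src and Sink
// are add-recurrences with the same 4-byte step, so instead of materializing
// two full [Start, End) intervals it is enough to record the two start
// addresses plus AllocSize in DiffChecks; the later runtime-check expansion
// (see llvm::addRuntimeChecks and its callers) derives a single comparison on
// SinkStart - SrcStart from that. The exact predicate is emitted by that
// expansion, not here.
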
378 SmallVector<RuntimePointerCheck, 4> RuntimePointerChecking::generateChecks() {
379   SmallVector<RuntimePointerCheck, 4> Checks;
380 
381   for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
382     for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) {
383       const RuntimeCheckingPtrGroup &CGI = CheckingGroups[I];
384       const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J];
385 
386       if (needsChecking(CGI, CGJ)) {
387         CanUseDiffCheck = CanUseDiffCheck && tryToCreateDiffCheck(CGI, CGJ);
388         Checks.push_back(std::make_pair(&CGI, &CGJ));
389       }
390     }
391   }
392   return Checks;
393 }
394 
395 void RuntimePointerChecking::generateChecks(
396     MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) {
397   assert(Checks.empty() && "Checks is not empty");
398   groupChecks(DepCands, UseDependencies);
399   Checks = generateChecks();
400 }
401 
402 bool RuntimePointerChecking::needsChecking(
403     const RuntimeCheckingPtrGroup &M, const RuntimeCheckingPtrGroup &N) const {
404   for (const auto &I : M.Members)
405     for (const auto &J : N.Members)
406       if (needsChecking(I, J))
407         return true;
408   return false;
409 }
410 
411 /// Compare \p I and \p J and return the minimum.
412 /// Return nullptr in case we couldn't find an answer.
413 static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J,
414                                    ScalarEvolution *SE) {
415   const SCEV *Diff = SE->getMinusSCEV(J, I);
416   const SCEVConstant *C = dyn_cast<const SCEVConstant>(Diff);
417 
418   if (!C)
419     return nullptr;
420   return C->getValue()->isNegative() ? J : I;
421 }
422 
423 bool RuntimeCheckingPtrGroup::addPointer(unsigned Index,
424                                          RuntimePointerChecking &RtCheck) {
425   return addPointer(
426       Index, RtCheck.Pointers[Index].Start, RtCheck.Pointers[Index].End,
427       RtCheck.Pointers[Index].PointerValue->getType()->getPointerAddressSpace(),
428       RtCheck.Pointers[Index].NeedsFreeze, *RtCheck.SE);
429 }
430 
431 bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, const SCEV *Start,
432                                          const SCEV *End, unsigned AS,
433                                          bool NeedsFreeze,
434                                          ScalarEvolution &SE) {
435   assert(AddressSpace == AS &&
436          "all pointers in a checking group must be in the same address space");
437 
438   // Compare the starts and ends with the known minimum and maximum
439   // of this set. We need to know how we compare against the min/max
440   // of the set in order to be able to emit memchecks.
441   const SCEV *Min0 = getMinFromExprs(Start, Low, &SE);
442   if (!Min0)
443     return false;
444 
445   const SCEV *Min1 = getMinFromExprs(End, High, &SE);
446   if (!Min1)
447     return false;
448 
449   // Update the low bound expression if we've found a new min value.
450   if (Min0 == Start)
451     Low = Start;
452 
453   // Update the high bound expression if we've found a new max value.
454   if (Min1 != End)
455     High = End;
456 
457   Members.push_back(Index);
458   this->NeedsFreeze |= NeedsFreeze;
459   return true;
460 }
461 
462 void RuntimePointerChecking::groupChecks(
463     MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) {
464   // We build the groups from dependency candidates equivalence classes
465   // because:
466   //    - We know that pointers in the same equivalence class share
467   //      the same underlying object and therefore there is a chance
468   //      that we can compare pointers
469   //    - We wouldn't be able to merge two pointers for which we need
470   //      to emit a memcheck. The classes in DepCands are already
471   //      conveniently built such that no two pointers in the same
472   //      class need checking against each other.
473 
474   // We use the following (greedy) algorithm to construct the groups
475   // For every pointer in the equivalence class:
476   //   For each existing group:
477   //   - if the difference between this pointer and the min/max bounds
478   //     of the group is a constant, then make the pointer part of the
479   //     group and update the min/max bounds of that group as required.
480 
481   CheckingGroups.clear();
482 
483   // If we need to check two pointers to the same underlying object
484   // with a non-constant difference, we shouldn't perform any pointer
485   // grouping with those pointers. This is because we can easily get
486   // into cases where the resulting check would return false, even when
487   // the accesses are safe.
488   //
489   // The following example shows this:
490   // for (i = 0; i < 1000; ++i)
491   //   a[5000 + i * m] = a[i] + a[i + 9000]
492   //
493   // Here grouping gives a check of (5000, 5000 + 1000 * m) against
494   // (0, 10000) which is always false. However, if m is 1, there is no
495   // dependence. Not grouping the checks for a[i] and a[i + 9000] allows
496   // us to perform an accurate check in this case.
497   //
498   // The above case requires that we have an UnknownDependence between
499   // accesses to the same underlying object. This cannot happen unless
500   // FoundNonConstantDistanceDependence is set, and therefore UseDependencies
501   // is also false. In this case we will use the fallback path and create
502   // separate checking groups for all pointers.
503 
504   // If we don't have the dependency partitions, construct a new
505   // checking pointer group for each pointer. This is also required
506   // for correctness, because in this case we can have checking between
507   // pointers to the same underlying object.
508   if (!UseDependencies) {
509     for (unsigned I = 0; I < Pointers.size(); ++I)
510       CheckingGroups.push_back(RuntimeCheckingPtrGroup(I, *this));
511     return;
512   }
513 
514   unsigned TotalComparisons = 0;
515 
516   DenseMap<Value *, SmallVector<unsigned>> PositionMap;
517   for (unsigned Index = 0; Index < Pointers.size(); ++Index) {
518     auto [It, _] = PositionMap.insert({Pointers[Index].PointerValue, {}});
519     It->second.push_back(Index);
520   }
521 
522   // We need to keep track of what pointers we've already seen so we
523   // don't process them twice.
524   SmallSet<unsigned, 2> Seen;
525 
526   // Go through all equivalence classes, get the "pointer check groups"
527   // and add them to the overall solution. We use the order in which accesses
528   // appear in 'Pointers' to enforce determinism.
529   for (unsigned I = 0; I < Pointers.size(); ++I) {
530     // We've seen this pointer before, and therefore already processed
531     // its equivalence class.
532     if (Seen.count(I))
533       continue;
534 
535     MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue,
536                                            Pointers[I].IsWritePtr);
537 
538     SmallVector<RuntimeCheckingPtrGroup, 2> Groups;
539     auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access));
540 
541     // Because DepCands is constructed by visiting accesses in the order in
542     // which they appear in alias sets (which is deterministic) and the
543     // iteration order within an equivalence class member is only dependent on
544     // the order in which unions and insertions are performed on the
545     // equivalence class, the iteration order is deterministic.
546     for (auto MI = DepCands.member_begin(LeaderI), ME = DepCands.member_end();
547          MI != ME; ++MI) {
548       auto PointerI = PositionMap.find(MI->getPointer());
549       assert(PointerI != PositionMap.end() &&
550              "pointer in equivalence class not found in PositionMap");
551       for (unsigned Pointer : PointerI->second) {
552         bool Merged = false;
553         // Mark this pointer as seen.
554         Seen.insert(Pointer);
555 
556         // Go through all the existing sets and see if we can find one
557         // which can include this pointer.
558         for (RuntimeCheckingPtrGroup &Group : Groups) {
559           // Don't perform more than a certain amount of comparisons.
560           // This should limit the cost of grouping the pointers to something
561           // reasonable.  If we do end up hitting this threshold, the algorithm
562           // will create separate groups for all remaining pointers.
563           if (TotalComparisons > MemoryCheckMergeThreshold)
564             break;
565 
566           TotalComparisons++;
567 
568           if (Group.addPointer(Pointer, *this)) {
569             Merged = true;
570             break;
571           }
572         }
573 
574         if (!Merged)
575           // We couldn't add this pointer to any existing set or the threshold
576           // for the number of comparisons has been reached. Create a new group
577           // to hold the current pointer.
578           Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
579       }
580     }
581 
582     // We've computed the grouped checks for this partition.
583     // Save the results and continue with the next one.
584     llvm::copy(Groups, std::back_inserter(CheckingGroups));
585   }
586 }
587 
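// A minimal sketch (not part of LAA) of the greedy grouping in groupChecks,
// restated over plain integer byte intervals. Real groups hold SCEV bounds,
// and a merge fails when the distance to the group's bounds is not a
// compile-time constant; with plain integers the merge always succeeds, so
// this only illustrates how Low/High are widened and how the comparison
// budget is spent. All names below are hypothetical.
namespace laa_examples {
struct IntervalGroup {
  int64_t Low;
  int64_t High;
  SmallVector<unsigned, 2> Members;
};

inline SmallVector<IntervalGroup, 4>
groupIntervals(ArrayRef<std::pair<int64_t, int64_t>> Bounds,
               unsigned MergeThreshold) {
  SmallVector<IntervalGroup, 4> Groups;
  unsigned Comparisons = 0;
  for (unsigned I = 0; I < Bounds.size(); ++I) {
    auto [Start, End] = Bounds[I];
    bool Merged = false;
    for (IntervalGroup &G : Groups) {
      // Mirror MemoryCheckMergeThreshold: past the budget, stop merging and
      // give each remaining interval its own group.
      if (Comparisons++ > MergeThreshold)
        break;
      G.Low = std::min(G.Low, Start);  // cf. addPointer updating Low
      G.High = std::max(G.High, End);  // cf. addPointer updating High
      G.Members.push_back(I);
      Merged = true;
      break;
    }
    if (!Merged)
      Groups.push_back({Start, End, {I}});
  }
  return Groups;
}
} // namespace laa_examples
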
588 bool RuntimePointerChecking::arePointersInSamePartition(
589     const SmallVectorImpl<int> &PtrToPartition, unsigned PtrIdx1,
590     unsigned PtrIdx2) {
591   return (PtrToPartition[PtrIdx1] != -1 &&
592           PtrToPartition[PtrIdx1] == PtrToPartition[PtrIdx2]);
593 }
594 
595 bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const {
596   const PointerInfo &PointerI = Pointers[I];
597   const PointerInfo &PointerJ = Pointers[J];
598 
599   // No need to check if two readonly pointers intersect.
600   if (!PointerI.IsWritePtr && !PointerJ.IsWritePtr)
601     return false;
602 
603   // Only need to check pointers between two different dependency sets.
604   if (PointerI.DependencySetId == PointerJ.DependencySetId)
605     return false;
606 
607   // Only need to check pointers in the same alias set.
608   if (PointerI.AliasSetId != PointerJ.AliasSetId)
609     return false;
610 
611   return true;
612 }
613 
614 void RuntimePointerChecking::printChecks(
615     raw_ostream &OS, const SmallVectorImpl<RuntimePointerCheck> &Checks,
616     unsigned Depth) const {
617   unsigned N = 0;
618   for (const auto &[Check1, Check2] : Checks) {
619     const auto &First = Check1->Members, &Second = Check2->Members;
620 
621     OS.indent(Depth) << "Check " << N++ << ":\n";
622 
623     OS.indent(Depth + 2) << "Comparing group (" << Check1 << "):\n";
624     for (unsigned K : First)
625       OS.indent(Depth + 2) << *Pointers[K].PointerValue << "\n";
626 
627     OS.indent(Depth + 2) << "Against group (" << Check2 << "):\n";
628     for (unsigned K : Second)
629       OS.indent(Depth + 2) << *Pointers[K].PointerValue << "\n";
630   }
631 }
632 
633 void RuntimePointerChecking::print(raw_ostream &OS, unsigned Depth) const {
634 
635   OS.indent(Depth) << "Run-time memory checks:\n";
636   printChecks(OS, Checks, Depth);
637 
638   OS.indent(Depth) << "Grouped accesses:\n";
639   for (const auto &CG : CheckingGroups) {
640     OS.indent(Depth + 2) << "Group " << &CG << ":\n";
641     OS.indent(Depth + 4) << "(Low: " << *CG.Low << " High: " << *CG.High
642                          << ")\n";
643     for (unsigned Member : CG.Members) {
644       OS.indent(Depth + 6) << "Member: " << *Pointers[Member].Expr << "\n";
645     }
646   }
647 }
648 
649 namespace {
650 
651 /// Analyses memory accesses in a loop.
652 ///
653 /// Checks whether run time pointer checks are needed and builds sets for data
654 /// dependence checking.
655 class AccessAnalysis {
656 public:
657   /// Read or write access location.
658   typedef PointerIntPair<Value *, 1, bool> MemAccessInfo;
659   typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList;
660 
661   AccessAnalysis(Loop *TheLoop, AAResults *AA, LoopInfo *LI,
662                  MemoryDepChecker::DepCandidates &DA,
663                  PredicatedScalarEvolution &PSE,
664                  SmallPtrSetImpl<MDNode *> &LoopAliasScopes)
665       : TheLoop(TheLoop), BAA(*AA), AST(BAA), LI(LI), DepCands(DA), PSE(PSE),
666         LoopAliasScopes(LoopAliasScopes) {
667     // We're analyzing dependences across loop iterations.
668     BAA.enableCrossIterationMode();
669   }
670 
671   /// Register a load and whether it is only read from.
672   void addLoad(MemoryLocation &Loc, Type *AccessTy, bool IsReadOnly) {
673     Value *Ptr = const_cast<Value *>(Loc.Ptr);
674     AST.add(adjustLoc(Loc));
675     Accesses[MemAccessInfo(Ptr, false)].insert(AccessTy);
676     if (IsReadOnly)
677       ReadOnlyPtr.insert(Ptr);
678   }
679 
680   /// Register a store.
681   void addStore(MemoryLocation &Loc, Type *AccessTy) {
682     Value *Ptr = const_cast<Value *>(Loc.Ptr);
683     AST.add(adjustLoc(Loc));
684     Accesses[MemAccessInfo(Ptr, true)].insert(AccessTy);
685   }
686 
687   /// Check if we can emit a run-time no-alias check for \p Access.
688   ///
689   /// Returns true if we can emit a run-time no alias check for \p Access.
690   /// If we can check this access, this also adds it to a dependence set and
691   /// adds a run-time to check for it to \p RtCheck. If \p Assume is true,
692   /// we will attempt to use additional run-time checks in order to get
693   /// the bounds of the pointer.
694   bool createCheckForAccess(RuntimePointerChecking &RtCheck,
695                             MemAccessInfo Access, Type *AccessTy,
696                             const DenseMap<Value *, const SCEV *> &Strides,
697                             DenseMap<Value *, unsigned> &DepSetId,
698                             Loop *TheLoop, unsigned &RunningDepId,
699                             unsigned ASId, bool ShouldCheckStride, bool Assume);
700 
701   /// Check whether we can check the pointers at runtime for
702   /// non-intersection.
703   ///
704   /// Returns true if we need no check or if we do and we can generate them
705   /// (i.e. the pointers have computable bounds).
706   bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE,
707                        Loop *TheLoop, const DenseMap<Value *, const SCEV *> &Strides,
708                        Value *&UncomputablePtr, bool ShouldCheckWrap = false);
709 
710   /// Goes over all memory accesses, checks whether an RT check is needed
711   /// and builds sets of dependent accesses.
712   void buildDependenceSets() {
713     processMemAccesses();
714   }
715 
716   /// Initial processing of memory accesses determined that we need to
717   /// perform dependency checking.
718   ///
719   /// Note that this can later be cleared if we retry memcheck analysis without
720   /// dependency checking (i.e. FoundNonConstantDistanceDependence).
721   bool isDependencyCheckNeeded() { return !CheckDeps.empty(); }
722 
723   /// We decided that no dependence analysis would be used.  Reset the state.
724   void resetDepChecks(MemoryDepChecker &DepChecker) {
725     CheckDeps.clear();
726     DepChecker.clearDependences();
727   }
728 
729   MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; }
730 
731 private:
732   typedef MapVector<MemAccessInfo, SmallSetVector<Type *, 1>> PtrAccessMap;
733 
734   /// Adjust the MemoryLocation so that it represents accesses to this
735   /// location across all iterations, rather than a single one.
736   MemoryLocation adjustLoc(MemoryLocation Loc) const {
737     // The accessed location varies within the loop, but remains within the
738     // underlying object.
739     Loc.Size = LocationSize::beforeOrAfterPointer();
740     Loc.AATags.Scope = adjustAliasScopeList(Loc.AATags.Scope);
741     Loc.AATags.NoAlias = adjustAliasScopeList(Loc.AATags.NoAlias);
742     return Loc;
743   }
744 
745   /// Drop alias scopes that are only valid within a single loop iteration.
746   MDNode *adjustAliasScopeList(MDNode *ScopeList) const {
747     if (!ScopeList)
748       return nullptr;
749 
750     // For the sake of simplicity, drop the whole scope list if any scope is
751     // iteration-local.
752     if (any_of(ScopeList->operands(), [&](Metadata *Scope) {
753           return LoopAliasScopes.contains(cast<MDNode>(Scope));
754         }))
755       return nullptr;
756 
757     return ScopeList;
758   }
759 
760   /// Go over all memory accesses and check whether runtime pointer checks
761   /// are needed and build sets of dependency check candidates.
762   void processMemAccesses();
763 
764   /// Map of all accesses. Values are the types used to access memory pointed to
765   /// by the pointer.
766   PtrAccessMap Accesses;
767 
768   /// The loop being checked.
769   const Loop *TheLoop;
770 
771   /// List of accesses that need a further dependence check.
772   MemAccessInfoList CheckDeps;
773 
774   /// Set of pointers that are read only.
775   SmallPtrSet<Value*, 16> ReadOnlyPtr;
776 
777   /// Batched alias analysis results.
778   BatchAAResults BAA;
779 
780   /// An alias set tracker to partition the access set by underlying object and
781   /// intrinsic property (such as TBAA metadata).
782   AliasSetTracker AST;
783 
784   LoopInfo *LI;
785 
786   /// Sets of potentially dependent accesses - members of one set share an
787   /// underlying pointer. The set "CheckDeps" identifies which sets really need a
788   /// dependence check.
789   MemoryDepChecker::DepCandidates &DepCands;
790 
791   /// Initial processing of memory accesses determined that we may need
792   /// to add memchecks.  Perform the analysis to determine the necessary checks.
793   ///
794   /// Note that this is different from isDependencyCheckNeeded.  When we retry
795   /// memcheck analysis without dependency checking
796   /// (i.e. FoundNonConstantDistanceDependence), isDependencyCheckNeeded is
797   /// cleared while this remains set if we have potentially dependent accesses.
798   bool IsRTCheckAnalysisNeeded = false;
799 
800   /// The SCEV predicate containing all the SCEV-related assumptions.
801   PredicatedScalarEvolution &PSE;
802 
803   DenseMap<Value *, SmallVector<const Value *, 16>> UnderlyingObjects;
804 
805   /// Alias scopes that are declared inside the loop, and as such not valid
806   /// across iterations.
807   SmallPtrSetImpl<MDNode *> &LoopAliasScopes;
808 };
809 
810 } // end anonymous namespace
811 
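// A minimal usage sketch (assumed; the real driver is LoopAccessInfo's
// analyzeLoop later in this file) showing the intended calling sequence for
// AccessAnalysis. The function name and the surrounding setup are
// hypothetical.
namespace laa_examples {
[[maybe_unused]] static bool exampleDriver(Loop *L, AAResults *AA,
                                           LoopInfo *LI,
                                           PredicatedScalarEvolution &PSE,
                                           RuntimePointerChecking &RtCheck) {
  MemoryDepChecker::DepCandidates DepCands;
  SmallPtrSet<MDNode *, 8> LoopAliasScopes;
  AccessAnalysis Accesses(L, AA, LI, DepCands, PSE, LoopAliasScopes);

  // 1) Register every pointer the loop reads or writes, e.g. per instruction:
  //      Accesses.addLoad(Loc, AccessTy, IsReadOnly);
  //      Accesses.addStore(Loc, AccessTy);

  // 2) Partition the accesses into dependence-check candidates.
  Accesses.buildDependenceSets();

  // 3) Ask whether runtime bounds checks can disambiguate whatever is left.
  DenseMap<Value *, const SCEV *> NoSymbolicStrides;
  Value *UncomputablePtr = nullptr;
  return Accesses.canCheckPtrAtRT(RtCheck, PSE.getSE(), L, NoSymbolicStrides,
                                  UncomputablePtr);
}
} // namespace laa_examples
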
812 /// Check whether a pointer can participate in a runtime bounds check.
813 /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
814 /// by adding run-time checks (overflow checks) if necessary.
815 static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
816                                 const SCEV *PtrScev, Loop *L, bool Assume) {
817   // The bounds for loop-invariant pointer is trivial.
818   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
819     return true;
820 
821   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
822 
823   if (!AR && Assume)
824     AR = PSE.getAsAddRec(Ptr);
825 
826   if (!AR)
827     return false;
828 
829   return AR->isAffine();
830 }
831 
832 /// Check whether a pointer address cannot wrap.
833 static bool isNoWrap(PredicatedScalarEvolution &PSE,
834                      const DenseMap<Value *, const SCEV *> &Strides, Value *Ptr, Type *AccessTy,
835                      Loop *L) {
836   const SCEV *PtrScev = PSE.getSCEV(Ptr);
837   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
838     return true;
839 
840   int64_t Stride = getPtrStride(PSE, AccessTy, Ptr, L, Strides).value_or(0);
841   if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
842     return true;
843 
844   return false;
845 }
846 
847 static void visitPointers(Value *StartPtr, const Loop &InnermostLoop,
848                           function_ref<void(Value *)> AddPointer) {
849   SmallPtrSet<Value *, 8> Visited;
850   SmallVector<Value *> WorkList;
851   WorkList.push_back(StartPtr);
852 
853   while (!WorkList.empty()) {
854     Value *Ptr = WorkList.pop_back_val();
855     if (!Visited.insert(Ptr).second)
856       continue;
857     auto *PN = dyn_cast<PHINode>(Ptr);
858     // SCEV does not look through non-header PHIs inside the loop. Such phis
859     // can be analyzed by adding separate accesses for each incoming pointer
860     // value.
861     if (PN && InnermostLoop.contains(PN->getParent()) &&
862         PN->getParent() != InnermostLoop.getHeader()) {
863       for (const Use &Inc : PN->incoming_values())
864         WorkList.push_back(Inc);
865     } else
866       AddPointer(Ptr);
867   }
868 }
869 
870 // Walk back through the IR for a pointer, looking for a select like the
871 // following:
872 //
873 //  %offset = select i1 %cmp, i64 %a, i64 %b
874 //  %addr = getelementptr double, double* %base, i64 %offset
875 //  %ld = load double, double* %addr, align 8
876 //
877 // We won't be able to form a single SCEVAddRecExpr from this since the
878 // address for each loop iteration depends on %cmp. We could potentially
879 // produce multiple valid SCEVAddRecExprs, though, and check all of them for
880 // memory safety/aliasing if needed.
881 //
882 // If we encounter some IR we don't yet handle, or something obviously fine
883 // like a constant, then we just add the SCEV for that term to the list passed
884 // in by the caller. If we have a node that may potentially yield a valid
885 // SCEVAddRecExpr then we decompose it into parts and build the SCEV terms
886 // ourselves before adding to the list.
887 static void findForkedSCEVs(
888     ScalarEvolution *SE, const Loop *L, Value *Ptr,
889     SmallVectorImpl<PointerIntPair<const SCEV *, 1, bool>> &ScevList,
890     unsigned Depth) {
891   // If our Value is a SCEVAddRecExpr, loop invariant, not an instruction, or
892   // we've exceeded our limit on recursion, just return whatever we have
893   // regardless of whether it can be used for a forked pointer or not, along
894   // with an indication of whether it might be a poison or undef value.
895   const SCEV *Scev = SE->getSCEV(Ptr);
896   if (isa<SCEVAddRecExpr>(Scev) || L->isLoopInvariant(Ptr) ||
897       !isa<Instruction>(Ptr) || Depth == 0) {
898     ScevList.emplace_back(Scev, !isGuaranteedNotToBeUndefOrPoison(Ptr));
899     return;
900   }
901 
902   Depth--;
903 
904   auto UndefPoisonCheck = [](PointerIntPair<const SCEV *, 1, bool> S) {
905     return get<1>(S);
906   };
907 
908   auto GetBinOpExpr = [&SE](unsigned Opcode, const SCEV *L, const SCEV *R) {
909     switch (Opcode) {
910     case Instruction::Add:
911       return SE->getAddExpr(L, R);
912     case Instruction::Sub:
913       return SE->getMinusSCEV(L, R);
914     default:
915       llvm_unreachable("Unexpected binary operator when walking ForkedPtrs");
916     }
917   };
918 
919   Instruction *I = cast<Instruction>(Ptr);
920   unsigned Opcode = I->getOpcode();
921   switch (Opcode) {
922   case Instruction::GetElementPtr: {
923     GetElementPtrInst *GEP = cast<GetElementPtrInst>(I);
924     Type *SourceTy = GEP->getSourceElementType();
925     // We only handle base + single offset GEPs here for now.
926     // Not dealing with preexisting gathers yet, so no vectors.
927     if (I->getNumOperands() != 2 || SourceTy->isVectorTy()) {
928       ScevList.emplace_back(Scev, !isGuaranteedNotToBeUndefOrPoison(GEP));
929       break;
930     }
931     SmallVector<PointerIntPair<const SCEV *, 1, bool>, 2> BaseScevs;
932     SmallVector<PointerIntPair<const SCEV *, 1, bool>, 2> OffsetScevs;
933     findForkedSCEVs(SE, L, I->getOperand(0), BaseScevs, Depth);
934     findForkedSCEVs(SE, L, I->getOperand(1), OffsetScevs, Depth);
935 
936     // See if we need to freeze our fork...
937     bool NeedsFreeze = any_of(BaseScevs, UndefPoisonCheck) ||
938                        any_of(OffsetScevs, UndefPoisonCheck);
939 
940     // Check that we only have a single fork, on either the base or the offset.
941     // Copy the SCEV across for the one without a fork in order to generate
942     // the full SCEV for both sides of the GEP.
943     if (OffsetScevs.size() == 2 && BaseScevs.size() == 1)
944       BaseScevs.push_back(BaseScevs[0]);
945     else if (BaseScevs.size() == 2 && OffsetScevs.size() == 1)
946       OffsetScevs.push_back(OffsetScevs[0]);
947     else {
948       ScevList.emplace_back(Scev, NeedsFreeze);
949       break;
950     }
951 
952     // Find the pointer type we need to extend to.
953     Type *IntPtrTy = SE->getEffectiveSCEVType(
954         SE->getSCEV(GEP->getPointerOperand())->getType());
955 
956     // Find the size of the type being pointed to. We only have a single
957     // index term (guarded above) so we don't need to index into arrays or
958     // structures, just get the size of the scalar value.
959     const SCEV *Size = SE->getSizeOfExpr(IntPtrTy, SourceTy);
960 
961     // Scale up the offsets by the size of the type, then add to the bases.
962     const SCEV *Scaled1 = SE->getMulExpr(
963         Size, SE->getTruncateOrSignExtend(get<0>(OffsetScevs[0]), IntPtrTy));
964     const SCEV *Scaled2 = SE->getMulExpr(
965         Size, SE->getTruncateOrSignExtend(get<0>(OffsetScevs[1]), IntPtrTy));
966     ScevList.emplace_back(SE->getAddExpr(get<0>(BaseScevs[0]), Scaled1),
967                           NeedsFreeze);
968     ScevList.emplace_back(SE->getAddExpr(get<0>(BaseScevs[1]), Scaled2),
969                           NeedsFreeze);
970     break;
971   }
972   case Instruction::Select: {
973     SmallVector<PointerIntPair<const SCEV *, 1, bool>, 2> ChildScevs;
974     // A select means we've found a forked pointer, but we currently only
975     // support a single select per pointer so if there's another behind this
976     // then we just bail out and return the generic SCEV.
977     findForkedSCEVs(SE, L, I->getOperand(1), ChildScevs, Depth);
978     findForkedSCEVs(SE, L, I->getOperand(2), ChildScevs, Depth);
979     if (ChildScevs.size() == 2) {
980       ScevList.push_back(ChildScevs[0]);
981       ScevList.push_back(ChildScevs[1]);
982     } else
983       ScevList.emplace_back(Scev, !isGuaranteedNotToBeUndefOrPoison(Ptr));
984     break;
985   }
986   case Instruction::PHI: {
987     SmallVector<PointerIntPair<const SCEV *, 1, bool>, 2> ChildScevs;
988     // A phi means we've found a forked pointer, but we currently only
989     // support a single phi per pointer so if there's another behind this
990     // then we just bail out and return the generic SCEV.
991     if (I->getNumOperands() == 2) {
992       findForkedSCEVs(SE, L, I->getOperand(0), ChildScevs, Depth);
993       findForkedSCEVs(SE, L, I->getOperand(1), ChildScevs, Depth);
994     }
995     if (ChildScevs.size() == 2) {
996       ScevList.push_back(ChildScevs[0]);
997       ScevList.push_back(ChildScevs[1]);
998     } else
999       ScevList.emplace_back(Scev, !isGuaranteedNotToBeUndefOrPoison(Ptr));
1000     break;
1001   }
1002   case Instruction::Add:
1003   case Instruction::Sub: {
1004     SmallVector<PointerIntPair<const SCEV *, 1, bool>> LScevs;
1005     SmallVector<PointerIntPair<const SCEV *, 1, bool>> RScevs;
1006     findForkedSCEVs(SE, L, I->getOperand(0), LScevs, Depth);
1007     findForkedSCEVs(SE, L, I->getOperand(1), RScevs, Depth);
1008 
1009     // See if we need to freeze our fork...
1010     bool NeedsFreeze =
1011         any_of(LScevs, UndefPoisonCheck) || any_of(RScevs, UndefPoisonCheck);
1012 
1013     // Check that we only have a single fork, on either the left or right side.
1014     // Copy the SCEV across for the one without a fork in order to generate
1015     // the full SCEV for both sides of the BinOp.
1016     if (LScevs.size() == 2 && RScevs.size() == 1)
1017       RScevs.push_back(RScevs[0]);
1018     else if (RScevs.size() == 2 && LScevs.size() == 1)
1019       LScevs.push_back(LScevs[0]);
1020     else {
1021       ScevList.emplace_back(Scev, NeedsFreeze);
1022       break;
1023     }
1024 
1025     ScevList.emplace_back(
1026         GetBinOpExpr(Opcode, get<0>(LScevs[0]), get<0>(RScevs[0])),
1027         NeedsFreeze);
1028     ScevList.emplace_back(
1029         GetBinOpExpr(Opcode, get<0>(LScevs[1]), get<0>(RScevs[1])),
1030         NeedsFreeze);
1031     break;
1032   }
1033   default:
1034     // Just return the current SCEV if we haven't handled the instruction yet.
1035     LLVM_DEBUG(dbgs() << "ForkedPtr unhandled instruction: " << *I << "\n");
1036     ScevList.emplace_back(Scev, !isGuaranteedNotToBeUndefOrPoison(Ptr));
1037     break;
1038   }
1039 }
1040 
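// Worked example for findForkedSCEVs (illustrative): for the select/GEP/load
// pattern shown in the comment above it, the GetElementPtr case produces two
// candidate SCEVs for %addr, roughly %base + 8 * %a and %base + 8 * %b (8
// being sizeof(double), with the offsets truncated or extended to the pointer
// index type), each paired with a flag recording whether the forked value may
// need freezing because it could be poison or undef.
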
1041 static SmallVector<PointerIntPair<const SCEV *, 1, bool>>
1042 findForkedPointer(PredicatedScalarEvolution &PSE,
1043                   const DenseMap<Value *, const SCEV *> &StridesMap, Value *Ptr,
1044                   const Loop *L) {
1045   ScalarEvolution *SE = PSE.getSE();
1046   assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!");
1047   SmallVector<PointerIntPair<const SCEV *, 1, bool>> Scevs;
1048   findForkedSCEVs(SE, L, Ptr, Scevs, MaxForkedSCEVDepth);
1049 
1050   // For now, we will only accept a forked pointer with two possible SCEVs
1051   // that are either SCEVAddRecExprs or loop invariant.
1052   if (Scevs.size() == 2 &&
1053       (isa<SCEVAddRecExpr>(get<0>(Scevs[0])) ||
1054        SE->isLoopInvariant(get<0>(Scevs[0]), L)) &&
1055       (isa<SCEVAddRecExpr>(get<0>(Scevs[1])) ||
1056        SE->isLoopInvariant(get<0>(Scevs[1]), L))) {
1057     LLVM_DEBUG(dbgs() << "LAA: Found forked pointer: " << *Ptr << "\n");
1058     LLVM_DEBUG(dbgs() << "\t(1) " << *get<0>(Scevs[0]) << "\n");
1059     LLVM_DEBUG(dbgs() << "\t(2) " << *get<0>(Scevs[1]) << "\n");
1060     return Scevs;
1061   }
1062 
1063   return {{replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr), false}};
1064 }
1065 
1066 bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
1067                                           MemAccessInfo Access, Type *AccessTy,
1068                                           const DenseMap<Value *, const SCEV *> &StridesMap,
1069                                           DenseMap<Value *, unsigned> &DepSetId,
1070                                           Loop *TheLoop, unsigned &RunningDepId,
1071                                           unsigned ASId, bool ShouldCheckWrap,
1072                                           bool Assume) {
1073   Value *Ptr = Access.getPointer();
1074 
1075   SmallVector<PointerIntPair<const SCEV *, 1, bool>> TranslatedPtrs =
1076       findForkedPointer(PSE, StridesMap, Ptr, TheLoop);
1077 
1078   for (auto &P : TranslatedPtrs) {
1079     const SCEV *PtrExpr = get<0>(P);
1080     if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume))
1081       return false;
1082 
1083     // When we run after a failing dependency check we have to make sure
1084     // we don't have wrapping pointers.
1085     if (ShouldCheckWrap) {
1086       // Skip wrap checking when translating pointers.
1087       if (TranslatedPtrs.size() > 1)
1088         return false;
1089 
1090       if (!isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
1091         auto *Expr = PSE.getSCEV(Ptr);
1092         if (!Assume || !isa<SCEVAddRecExpr>(Expr))
1093           return false;
1094         PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
1095       }
1096     }
1097     // If there's only one option for Ptr, look it up after bounds and wrap
1098     // checking, because assumptions might have been added to PSE.
1099     if (TranslatedPtrs.size() == 1)
1100       TranslatedPtrs[0] = {replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr),
1101                            false};
1102   }
1103 
1104   for (auto [PtrExpr, NeedsFreeze] : TranslatedPtrs) {
1105     // The id of the dependence set.
1106     unsigned DepId;
1107 
1108     if (isDependencyCheckNeeded()) {
1109       Value *Leader = DepCands.getLeaderValue(Access).getPointer();
1110       unsigned &LeaderId = DepSetId[Leader];
1111       if (!LeaderId)
1112         LeaderId = RunningDepId++;
1113       DepId = LeaderId;
1114     } else
1115       // Each access has its own dependence set.
1116       DepId = RunningDepId++;
1117 
1118     bool IsWrite = Access.getInt();
1119     RtCheck.insert(TheLoop, Ptr, PtrExpr, AccessTy, IsWrite, DepId, ASId, PSE,
1120                    NeedsFreeze);
1121     LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
1122   }
1123 
1124   return true;
1125 }
1126 
1127 bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
1128                                      ScalarEvolution *SE, Loop *TheLoop,
1129                                      const DenseMap<Value *, const SCEV *> &StridesMap,
1130                                      Value *&UncomputablePtr, bool ShouldCheckWrap) {
1131   // Find pointers with computable bounds. We are going to use this information
1132   // to place a runtime bound check.
1133   bool CanDoRT = true;
1134 
1135   bool MayNeedRTCheck = false;
1136   if (!IsRTCheckAnalysisNeeded) return true;
1137 
1138   bool IsDepCheckNeeded = isDependencyCheckNeeded();
1139 
1140   // We assign consecutive ids to accesses from different alias sets.
1141   // Accesses between different groups don't need to be checked.
1142   unsigned ASId = 0;
1143   for (auto &AS : AST) {
1144     int NumReadPtrChecks = 0;
1145     int NumWritePtrChecks = 0;
1146     bool CanDoAliasSetRT = true;
1147     ++ASId;
1148     auto ASPointers = AS.getPointers();
1149 
1150     // We assign consecutive ids to accesses from different dependence sets.
1151     // Accesses within the same set don't need a runtime check.
1152     unsigned RunningDepId = 1;
1153     DenseMap<Value *, unsigned> DepSetId;
1154 
1155     SmallVector<std::pair<MemAccessInfo, Type *>, 4> Retries;
1156 
1157     // First, count how many write and read accesses are in the alias set. Also
1158     // collect MemAccessInfos for later.
1159     SmallVector<MemAccessInfo, 4> AccessInfos;
1160     for (const Value *ConstPtr : ASPointers) {
1161       Value *Ptr = const_cast<Value *>(ConstPtr);
1162       bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true));
1163       if (IsWrite)
1164         ++NumWritePtrChecks;
1165       else
1166         ++NumReadPtrChecks;
1167       AccessInfos.emplace_back(Ptr, IsWrite);
1168     }
1169 
1170     // We do not need runtime checks for this alias set if there are no writes,
1171     // or only a single write and no reads.
1172     if (NumWritePtrChecks == 0 ||
1173         (NumWritePtrChecks == 1 && NumReadPtrChecks == 0)) {
1174       assert((ASPointers.size() <= 1 ||
1175               all_of(ASPointers,
1176                      [this](const Value *Ptr) {
1177                        MemAccessInfo AccessWrite(const_cast<Value *>(Ptr),
1178                                                  true);
1179                        return DepCands.findValue(AccessWrite) == DepCands.end();
1180                      })) &&
1181              "Can only skip updating CanDoRT below, if all entries in AS "
1182              "are reads or there is at most 1 entry");
1183       continue;
1184     }
1185 
1186     for (auto &Access : AccessInfos) {
1187       for (const auto &AccessTy : Accesses[Access]) {
1188         if (!createCheckForAccess(RtCheck, Access, AccessTy, StridesMap,
1189                                   DepSetId, TheLoop, RunningDepId, ASId,
1190                                   ShouldCheckWrap, false)) {
1191           LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:"
1192                             << *Access.getPointer() << '\n');
1193           Retries.push_back({Access, AccessTy});
1194           CanDoAliasSetRT = false;
1195         }
1196       }
1197     }
1198 
1199     // Note that this function computes CanDoRT and MayNeedRTCheck
1200     // independently. For example, CanDoRT=false, MayNeedRTCheck=false means that
1201     // we have a pointer for which we couldn't find the bounds, but we don't
1202     // actually need to emit any checks, so it does not matter.
1203     //
1204     // We need runtime checks for this alias set, if there are at least 2
1205     // dependence sets (in which case RunningDepId > 2) or if we need to re-try
1206     // any bound checks (because in that case the number of dependence sets is
1207     // incomplete).
1208     bool NeedsAliasSetRTCheck = RunningDepId > 2 || !Retries.empty();
1209 
1210     // We need to perform run-time alias checks, but some pointers had bounds
1211     // that couldn't be checked.
1212     if (NeedsAliasSetRTCheck && !CanDoAliasSetRT) {
1213       // Reset the CanDoSetRt flag and retry all accesses that have failed.
1214       // We know that we need these checks, so we can now be more aggressive
1215       // and add further checks if required (overflow checks).
1216       CanDoAliasSetRT = true;
1217       for (const auto &[Access, AccessTy] : Retries) {
1218         if (!createCheckForAccess(RtCheck, Access, AccessTy, StridesMap,
1219                                   DepSetId, TheLoop, RunningDepId, ASId,
1220                                   ShouldCheckWrap, /*Assume=*/true)) {
1221           CanDoAliasSetRT = false;
1222           UncomputablePtr = Access.getPointer();
1223           break;
1224         }
1225       }
1226     }
1227 
1228     CanDoRT &= CanDoAliasSetRT;
1229     MayNeedRTCheck |= NeedsAliasSetRTCheck;
1230     ++ASId;
1231   }
1232 
1233   // If the pointers that we would use for the bounds comparison have different
1234   // address spaces, assume the values aren't directly comparable, so we can't
1235   // use them for the runtime check. We also have to assume they could
1236   // overlap. In the future there should be metadata for whether address spaces
1237   // are disjoint.
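  // For example, on a GPU-like target a pointer into global memory
  // (addrspace(1)) and a pointer into shared/local memory (addrspace(3))
  // cannot be meaningfully compared, so no bounds check is emitted for them.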
1238   unsigned NumPointers = RtCheck.Pointers.size();
1239   for (unsigned i = 0; i < NumPointers; ++i) {
1240     for (unsigned j = i + 1; j < NumPointers; ++j) {
1241       // Only need to check pointers between two different dependency sets.
1242       if (RtCheck.Pointers[i].DependencySetId ==
1243           RtCheck.Pointers[j].DependencySetId)
1244         continue;
1245       // Only need to check pointers in the same alias set.
1246       if (RtCheck.Pointers[i].AliasSetId != RtCheck.Pointers[j].AliasSetId)
1247         continue;
1248 
1249       Value *PtrI = RtCheck.Pointers[i].PointerValue;
1250       Value *PtrJ = RtCheck.Pointers[j].PointerValue;
1251 
1252       unsigned ASi = PtrI->getType()->getPointerAddressSpace();
1253       unsigned ASj = PtrJ->getType()->getPointerAddressSpace();
1254       if (ASi != ASj) {
1255         LLVM_DEBUG(
1256             dbgs() << "LAA: Runtime check would require comparison between"
1257                       " different address spaces\n");
1258         return false;
1259       }
1260     }
1261   }
1262 
1263   if (MayNeedRTCheck && CanDoRT)
1264     RtCheck.generateChecks(DepCands, IsDepCheckNeeded);
1265 
1266   LLVM_DEBUG(dbgs() << "LAA: We need to do " << RtCheck.getNumberOfChecks()
1267                     << " pointer comparisons.\n");
1268 
1269   // If we can do run-time checks, but there are no checks, no runtime checks
1270   // are needed. This can happen when all pointers point to the same underlying
1271   // object for example.
1272   RtCheck.Need = CanDoRT ? RtCheck.getNumberOfChecks() != 0 : MayNeedRTCheck;
1273 
1274   bool CanDoRTIfNeeded = !RtCheck.Need || CanDoRT;
1275   if (!CanDoRTIfNeeded)
1276     RtCheck.reset();
1277   return CanDoRTIfNeeded;
1278 }
1279 
1280 void AccessAnalysis::processMemAccesses() {
1281   // We process the set twice: first we process read-write pointers, then we
1282   // process read-only pointers. This allows us to skip dependence tests for
1283   // read-only pointers.
1284 
1285   LLVM_DEBUG(dbgs() << "LAA: Processing memory accesses...\n");
1286   LLVM_DEBUG(dbgs() << "  AST: "; AST.dump());
1287   LLVM_DEBUG(dbgs() << "LAA:   Accesses(" << Accesses.size() << "):\n");
1288   LLVM_DEBUG({
1289     for (const auto &[A, _] : Accesses)
1290       dbgs() << "\t" << *A.getPointer() << " ("
1291              << (A.getInt() ? "write"
1292                             : (ReadOnlyPtr.count(A.getPointer()) ? "read-only"
1293                                                                  : "read"))
1294              << ")\n";
1295   });
1296 
1297   // The AliasSetTracker has nicely partitioned our pointers by metadata
1298   // compatibility and potential for underlying-object overlap. As a result, we
1299   // only need to check for potential pointer dependencies within each alias
1300   // set.
1301   for (const auto &AS : AST) {
1302     // Note that both the alias-set tracker and the alias sets themselves use
1303     // ordered collections internally, so the iteration order here is
1304     // deterministic.
1305     auto ASPointers = AS.getPointers();
1306 
1307     bool SetHasWrite = false;
1308 
1309     // Map of pointers to last access encountered.
1310     typedef DenseMap<const Value*, MemAccessInfo> UnderlyingObjToAccessMap;
1311     UnderlyingObjToAccessMap ObjToLastAccess;
1312 
1313     // Set of accesses to check after all writes have been processed.
1314     PtrAccessMap DeferredAccesses;
1315 
1316     // Iterate over each alias set twice, once to process read/write pointers,
1317     // and then to process read-only pointers.
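    // For example, if the loop body is "A[i] = B[i]" and A and B may alias
    // (and therefore share an alias set), the read-only pointer B is deferred
    // to the second pass; by then SetHasWrite already reflects the store
    // through A, so B is added to CheckDeps only when a write is actually
    // present in the set.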
1318     for (int SetIteration = 0; SetIteration < 2; ++SetIteration) {
1319       bool UseDeferred = SetIteration > 0;
1320       PtrAccessMap &S = UseDeferred ? DeferredAccesses : Accesses;
1321 
1322       for (const Value *ConstPtr : ASPointers) {
1323         Value *Ptr = const_cast<Value *>(ConstPtr);
1324 
1325         // For a single memory access in AliasSetTracker, Accesses may contain
1326         // both read and write, and they both need to be handled for CheckDeps.
1327         for (const auto &[AC, _] : S) {
1328           if (AC.getPointer() != Ptr)
1329             continue;
1330 
1331           bool IsWrite = AC.getInt();
1332 
1333           // If we're using the deferred access set, then it contains only
1334           // reads.
1335           bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite;
1336           if (UseDeferred && !IsReadOnlyPtr)
1337             continue;
1338           // Otherwise, the pointer must be in the PtrAccessSet, either as a
1339           // read or a write.
1340           assert(((IsReadOnlyPtr && UseDeferred) || IsWrite ||
1341                   S.count(MemAccessInfo(Ptr, false))) &&
1342                  "Alias-set pointer not in the access set?");
1343 
1344           MemAccessInfo Access(Ptr, IsWrite);
1345           DepCands.insert(Access);
1346 
1347           // Memorize read-only pointers for later processing and skip them in
1348           // the first round (they need to be checked after we have seen all
1349           // write pointers). Note: we also mark pointers that are not
1350           // consecutive as "read-only" pointers (so that we check
1351           // "a[b[i]] +="). Hence, we need the second check for "!IsWrite".
1352           if (!UseDeferred && IsReadOnlyPtr) {
1353             // We only use the pointer keys, the types vector values don't
1354             // matter.
1355             DeferredAccesses.insert({Access, {}});
1356             continue;
1357           }
1358 
1359           // If this is a write, check other reads and writes for conflicts. If
1360           // this is a read, only check other writes for conflicts (but only if
1361           // there is no other write to the ptr - this is an optimization to
1362           // catch "a[i] = a[i] + ..." without having to do a dependence check).
1363           if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) {
1364             CheckDeps.push_back(Access);
1365             IsRTCheckAnalysisNeeded = true;
1366           }
1367 
1368           if (IsWrite)
1369             SetHasWrite = true;
1370 
1371           // Create sets of pointers connected by a shared alias set and
1372           // underlying object.
1373           typedef SmallVector<const Value *, 16> ValueVector;
1374           ValueVector TempObjects;
1375 
1376           UnderlyingObjects[Ptr] = {};
1377           SmallVector<const Value *, 16> &UOs = UnderlyingObjects[Ptr];
1378           ::getUnderlyingObjects(Ptr, UOs, LI);
1379           LLVM_DEBUG(dbgs()
1380                      << "Underlying objects for pointer " << *Ptr << "\n");
1381           for (const Value *UnderlyingObj : UOs) {
1382             // Null pointers never alias; don't join sets for pointers that have
1383             // "null" in their UnderlyingObjects list.
1384             if (isa<ConstantPointerNull>(UnderlyingObj) &&
1385                 !NullPointerIsDefined(
1386                     TheLoop->getHeader()->getParent(),
1387                     UnderlyingObj->getType()->getPointerAddressSpace()))
1388               continue;
1389 
1390             UnderlyingObjToAccessMap::iterator Prev =
1391                 ObjToLastAccess.find(UnderlyingObj);
1392             if (Prev != ObjToLastAccess.end())
1393               DepCands.unionSets(Access, Prev->second);
1394 
1395             ObjToLastAccess[UnderlyingObj] = Access;
1396             LLVM_DEBUG(dbgs() << "  " << *UnderlyingObj << "\n");
1397           }
1398         }
1399       }
1400     }
1401   }
1402 }
1403 
1404 /// Return true if an AddRec pointer \p Ptr is unsigned non-wrapping,
1405 /// i.e. monotonically increasing/decreasing.
1406 static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
1407                            PredicatedScalarEvolution &PSE, const Loop *L) {
1408 
1409   // FIXME: This should probably only return true for NUW.
1410   if (AR->getNoWrapFlags(SCEV::NoWrapMask))
1411     return true;
1412 
1413   if (PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
1414     return true;
1415 
1416   // Scalar evolution does not propagate the non-wrapping flags to values that
1417   // are derived from a non-wrapping induction variable because non-wrapping
1418   // could be flow-sensitive.
1419   //
1420   // Look through the potentially overflowing instruction to try to prove
1421   // non-wrapping for the *specific* value of Ptr.
1422 
1423   // The arithmetic implied by an inbounds GEP can't overflow.
1424   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
1425   if (!GEP || !GEP->isInBounds())
1426     return false;
1427 
1428   // Make sure there is only one non-const index and analyze that.
1429   Value *NonConstIndex = nullptr;
1430   for (Value *Index : GEP->indices())
1431     if (!isa<ConstantInt>(Index)) {
1432       if (NonConstIndex)
1433         return false;
1434       NonConstIndex = Index;
1435     }
1436   if (!NonConstIndex)
1437     // The recurrence is on the pointer, ignore for now.
1438     return false;
1439 
1440   // The index in GEP is signed.  It is non-wrapping if it's derived from a NSW
1441   // AddRec using a NSW operation.
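  // For example (a sketch of the IR shape this handles):
  //   %idx = add nsw i64 %iv, 7          ; %iv is an NSW AddRec in this loop
  //   %gep = getelementptr inbounds i32, ptr %A, i64 %idx
  // The non-constant index is derived from the induction variable by an nsw
  // add of a constant, which is the pattern matched below.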
1442   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(NonConstIndex))
1443     if (OBO->hasNoSignedWrap() &&
1444         // Assume the other operand is constant so that the AddRec can be
1445         // easily found.
1446         isa<ConstantInt>(OBO->getOperand(1))) {
1447       auto *OpScev = PSE.getSCEV(OBO->getOperand(0));
1448 
1449       if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev))
1450         return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW);
1451     }
1452 
1453   return false;
1454 }
1455 
1456 /// Check whether the access through \p Ptr has a constant stride.
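/// For example, an access to A[2*i] with i32 elements has a SCEV step of 8
/// bytes and yields a stride of 2; a loop-invariant pointer yields 0; a
/// non-AddRec pointer or a non-constant step yields std::nullopt (unless
/// \p Assume allows predicates to be added).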
1457 std::optional<int64_t>
1458 llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
1459                    const Loop *Lp,
1460                    const DenseMap<Value *, const SCEV *> &StridesMap,
1461                    bool Assume, bool ShouldCheckWrap) {
1462   const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
1463   if (PSE.getSE()->isLoopInvariant(PtrScev, Lp))
1464     return {0};
1465 
1466   Type *Ty = Ptr->getType();
1467   assert(Ty->isPointerTy() && "Unexpected non-ptr");
1468   if (isa<ScalableVectorType>(AccessTy)) {
1469     LLVM_DEBUG(dbgs() << "LAA: Bad stride - Scalable object: " << *AccessTy
1470                       << "\n");
1471     return std::nullopt;
1472   }
1473 
1474   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
1475   if (Assume && !AR)
1476     AR = PSE.getAsAddRec(Ptr);
1477 
1478   if (!AR) {
1479     LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
1480                       << " SCEV: " << *PtrScev << "\n");
1481     return std::nullopt;
1482   }
1483 
1484   // The access function must stride over the innermost loop.
1485   if (Lp != AR->getLoop()) {
1486     LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop "
1487                       << *Ptr << " SCEV: " << *AR << "\n");
1488     return std::nullopt;
1489   }
1490 
1491   // Check the step is constant.
1492   const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
1493 
1494   // Calculate the pointer stride and check if it is constant.
1495   const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
1496   if (!C) {
1497     LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr
1498                       << " SCEV: " << *AR << "\n");
1499     return std::nullopt;
1500   }
1501 
1502   auto &DL = Lp->getHeader()->getDataLayout();
1503   TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
1504   int64_t Size = AllocSize.getFixedValue();
1505   const APInt &APStepVal = C->getAPInt();
1506 
1507   // Huge step value - give up.
1508   if (APStepVal.getBitWidth() > 64)
1509     return std::nullopt;
1510 
1511   int64_t StepVal = APStepVal.getSExtValue();
1512 
1513   // Strided access.
1514   int64_t Stride = StepVal / Size;
1515   int64_t Rem = StepVal % Size;
1516   if (Rem)
1517     return std::nullopt;
1518 
1519   if (!ShouldCheckWrap)
1520     return Stride;
1521 
1522   // The address calculation must not wrap. Otherwise, a dependence could be
1523   // inverted.
1524   if (isNoWrapAddRec(Ptr, AR, PSE, Lp))
1525     return Stride;
1526 
1527   // An inbounds getelementptr that is an AddRec with a unit stride
1528   // cannot wrap per definition.  If it did, the result would be poison
1529   // and any memory access dependent on it would be immediate UB
1530   // when executed.
1531   if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
1532       GEP && GEP->isInBounds() && (Stride == 1 || Stride == -1))
1533     return Stride;
1534 
1535   // If the null pointer is undefined, then an access sequence which would
1536   // otherwise access it can be assumed not to wrap (unsigned).  Note that this
1537   // assumes the object in memory is aligned to the natural alignment.
1538   unsigned AddrSpace = Ty->getPointerAddressSpace();
1539   if (!NullPointerIsDefined(Lp->getHeader()->getParent(), AddrSpace) &&
1540       (Stride == 1 || Stride == -1))
1541     return Stride;
1542 
1543   if (Assume) {
1544     PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
1545     LLVM_DEBUG(dbgs() << "LAA: Pointer may wrap:\n"
1546                       << "LAA:   Pointer: " << *Ptr << "\n"
1547                       << "LAA:   SCEV: " << *AR << "\n"
1548                       << "LAA:   Added an overflow assumption\n");
1549     return Stride;
1550   }
1551   LLVM_DEBUG(
1552       dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
1553              << *Ptr << " SCEV: " << *AR << "\n");
1554   return std::nullopt;
1555 }
1556 
1557 std::optional<int> llvm::getPointersDiff(Type *ElemTyA, Value *PtrA,
1558                                          Type *ElemTyB, Value *PtrB,
1559                                          const DataLayout &DL,
1560                                          ScalarEvolution &SE, bool StrictCheck,
1561                                          bool CheckType) {
1562   assert(PtrA && PtrB && "Expected non-nullptr pointers.");
1563 
1564   // Make sure that A and B are different pointers.
1565   if (PtrA == PtrB)
1566     return 0;
1567 
1568   // Make sure that the element types are the same if required.
1569   if (CheckType && ElemTyA != ElemTyB)
1570     return std::nullopt;
1571 
1572   unsigned ASA = PtrA->getType()->getPointerAddressSpace();
1573   unsigned ASB = PtrB->getType()->getPointerAddressSpace();
1574 
1575   // Check that the address spaces match.
1576   if (ASA != ASB)
1577     return std::nullopt;
1578   unsigned IdxWidth = DL.getIndexSizeInBits(ASA);
1579 
1580   APInt OffsetA(IdxWidth, 0), OffsetB(IdxWidth, 0);
1581   Value *PtrA1 = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
1582   Value *PtrB1 = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
1583 
1584   int Val;
1585   if (PtrA1 == PtrB1) {
1586     // Retrieve the address space again as pointer stripping now tracks through
1587     // `addrspacecast`.
1588     ASA = cast<PointerType>(PtrA1->getType())->getAddressSpace();
1589     ASB = cast<PointerType>(PtrB1->getType())->getAddressSpace();
1590     // Check that the address spaces match and that the pointers are valid.
1591     if (ASA != ASB)
1592       return std::nullopt;
1593 
1594     IdxWidth = DL.getIndexSizeInBits(ASA);
1595     OffsetA = OffsetA.sextOrTrunc(IdxWidth);
1596     OffsetB = OffsetB.sextOrTrunc(IdxWidth);
1597 
1598     OffsetB -= OffsetA;
1599     Val = OffsetB.getSExtValue();
1600   } else {
1601     // Otherwise compute the distance with SCEV between the base pointers.
1602     const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
1603     const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
1604     const auto *Diff =
1605         dyn_cast<SCEVConstant>(SE.getMinusSCEV(PtrSCEVB, PtrSCEVA));
1606     if (!Diff)
1607       return std::nullopt;
1608     Val = Diff->getAPInt().getSExtValue();
1609   }
1610   int Size = DL.getTypeStoreSize(ElemTyA);
1611   int Dist = Val / Size;
1612 
1613   // Ensure that the calculated distance matches the type-based one after all
1614   // the bitcasts removal in the provided pointers.
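  // For example, with i32 elements a byte difference Val of 12 gives a
  // distance of 3 elements; a Val of 10 would give Dist * Size == 8 != 10,
  // so a StrictCheck query returns std::nullopt.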
1615   if (!StrictCheck || Dist * Size == Val)
1616     return Dist;
1617   return std::nullopt;
1618 }
1619 
1620 bool llvm::sortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
1621                            const DataLayout &DL, ScalarEvolution &SE,
1622                            SmallVectorImpl<unsigned> &SortedIndices) {
1623   assert(llvm::all_of(
1624              VL, [](const Value *V) { return V->getType()->isPointerTy(); }) &&
1625          "Expected list of pointer operands.");
1626   // Walk over the pointers, and map each of them to an offset relative to
1627   // the first pointer in the array.
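  // For example, VL = {A+2, A, A+1} with i32 elements maps to offsets
  // {0, -2, -1} relative to A+2; not every insertion lands at the end of the
  // ordered set, so the input is non-consecutive and SortedIndices becomes
  // {1, 2, 0}.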
1628   Value *Ptr0 = VL[0];
1629 
1630   using DistOrdPair = std::pair<int64_t, int>;
1631   auto Compare = llvm::less_first();
1632   std::set<DistOrdPair, decltype(Compare)> Offsets(Compare);
1633   Offsets.emplace(0, 0);
1634   bool IsConsecutive = true;
1635   for (auto [Idx, Ptr] : drop_begin(enumerate(VL))) {
1636     std::optional<int> Diff = getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, DL, SE,
1637                                               /*StrictCheck=*/true);
1638     if (!Diff)
1639       return false;
1640 
1641     // Check if a pointer with the same offset has already been seen.
1642     int64_t Offset = *Diff;
1643     auto [It, IsInserted] = Offsets.emplace(Offset, Idx);
1644     if (!IsInserted)
1645       return false;
1646     // Consecutive order if the inserted element is the last one.
1647     IsConsecutive &= std::next(It) == Offsets.end();
1648   }
1649   SortedIndices.clear();
1650   if (!IsConsecutive) {
1651     // Fill SortedIndices array only if it is non-consecutive.
1652     SortedIndices.resize(VL.size());
1653     for (auto [Idx, Off] : enumerate(Offsets))
1654       SortedIndices[Idx] = Off.second;
1655   }
1656   return true;
1657 }
1658 
1659 /// Returns true if the memory operations \p A and \p B are consecutive.
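/// For example, i32 loads from A[i] and A[i+1] are consecutive (their pointer
/// difference is exactly one element), while loads from A[i] and A[i+2] are
/// not.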
1660 bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL,
1661                                ScalarEvolution &SE, bool CheckType) {
1662   Value *PtrA = getLoadStorePointerOperand(A);
1663   Value *PtrB = getLoadStorePointerOperand(B);
1664   if (!PtrA || !PtrB)
1665     return false;
1666   Type *ElemTyA = getLoadStoreType(A);
1667   Type *ElemTyB = getLoadStoreType(B);
1668   std::optional<int> Diff =
1669       getPointersDiff(ElemTyA, PtrA, ElemTyB, PtrB, DL, SE,
1670                       /*StrictCheck=*/true, CheckType);
1671   return Diff && *Diff == 1;
1672 }
1673 
1674 void MemoryDepChecker::addAccess(StoreInst *SI) {
1675   visitPointers(SI->getPointerOperand(), *InnermostLoop,
1676                 [this, SI](Value *Ptr) {
1677                   Accesses[MemAccessInfo(Ptr, true)].push_back(AccessIdx);
1678                   InstMap.push_back(SI);
1679                   ++AccessIdx;
1680                 });
1681 }
1682 
1683 void MemoryDepChecker::addAccess(LoadInst *LI) {
1684   visitPointers(LI->getPointerOperand(), *InnermostLoop,
1685                 [this, LI](Value *Ptr) {
1686                   Accesses[MemAccessInfo(Ptr, false)].push_back(AccessIdx);
1687                   InstMap.push_back(LI);
1688                   ++AccessIdx;
1689                 });
1690 }
1691 
1692 MemoryDepChecker::VectorizationSafetyStatus
1693 MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) {
1694   switch (Type) {
1695   case NoDep:
1696   case Forward:
1697   case BackwardVectorizable:
1698     return VectorizationSafetyStatus::Safe;
1699 
1700   case Unknown:
1701     return VectorizationSafetyStatus::PossiblySafeWithRtChecks;
1702   case ForwardButPreventsForwarding:
1703   case Backward:
1704   case BackwardVectorizableButPreventsForwarding:
1705   case IndirectUnsafe:
1706     return VectorizationSafetyStatus::Unsafe;
1707   }
1708   llvm_unreachable("unexpected DepType!");
1709 }
1710 
1711 bool MemoryDepChecker::Dependence::isBackward() const {
1712   switch (Type) {
1713   case NoDep:
1714   case Forward:
1715   case ForwardButPreventsForwarding:
1716   case Unknown:
1717   case IndirectUnsafe:
1718     return false;
1719 
1720   case BackwardVectorizable:
1721   case Backward:
1722   case BackwardVectorizableButPreventsForwarding:
1723     return true;
1724   }
1725   llvm_unreachable("unexpected DepType!");
1726 }
1727 
1728 bool MemoryDepChecker::Dependence::isPossiblyBackward() const {
1729   return isBackward() || Type == Unknown || Type == IndirectUnsafe;
1730 }
1731 
1732 bool MemoryDepChecker::Dependence::isForward() const {
1733   switch (Type) {
1734   case Forward:
1735   case ForwardButPreventsForwarding:
1736     return true;
1737 
1738   case NoDep:
1739   case Unknown:
1740   case BackwardVectorizable:
1741   case Backward:
1742   case BackwardVectorizableButPreventsForwarding:
1743   case IndirectUnsafe:
1744     return false;
1745   }
1746   llvm_unreachable("unexpected DepType!");
1747 }
1748 
1749 bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance,
1750                                                     uint64_t TypeByteSize) {
1751   // If loads occur at a distance that is not a multiple of a feasible vector
1752   // factor, store-load forwarding does not take place.
1753   // Positive dependences might cause trouble because vectorizing them might
1754   // prevent store-load forwarding, making vectorized code run a lot slower.
1755   //   a[i] = a[i-3] ^ a[i-8];
1756   //   The stores to a[i:i+1] don't align with the stores to a[i-3:i-2] and
1757   //   hence on your typical architecture store-load forwarding does not take
1758   //   place. Vectorizing in such cases does not make sense.
1759   // Store-load forwarding distance.
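  // Worked example (a sketch, assuming the initial maximum below is at least
  // 8): with TypeByteSize == 4 and Distance == 12, the loop tries VF == 8
  // first; 12 % 8 != 0 and 12 / 8 == 1 < 32 (8 * TypeByteSize), so the
  // maximum VF is clamped to 4, which is below 2 * TypeByteSize, and a
  // forwarding conflict is reported.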
1760 
1761   // After this many iterations store-to-load forwarding conflicts should not
1762   // cause any slowdowns.
1763   const uint64_t NumItersForStoreLoadThroughMemory = 8 * TypeByteSize;
1764   // Maximum vector factor.
1765   uint64_t MaxVFWithoutSLForwardIssues = std::min(
1766       VectorizerParams::MaxVectorWidth * TypeByteSize, MinDepDistBytes);
1767 
1768   // Compute the smallest VF at which the store and load would be misaligned.
1769   for (uint64_t VF = 2 * TypeByteSize; VF <= MaxVFWithoutSLForwardIssues;
1770        VF *= 2) {
1771     // If the number of vector iterations between the store and the load is
1772     // small, we could incur conflicts.
1773     if (Distance % VF && Distance / VF < NumItersForStoreLoadThroughMemory) {
1774       MaxVFWithoutSLForwardIssues = (VF >> 1);
1775       break;
1776     }
1777   }
1778 
1779   if (MaxVFWithoutSLForwardIssues < 2 * TypeByteSize) {
1780     LLVM_DEBUG(
1781         dbgs() << "LAA: Distance " << Distance
1782                << " that could cause a store-load forwarding conflict\n");
1783     return true;
1784   }
1785 
1786   if (MaxVFWithoutSLForwardIssues < MinDepDistBytes &&
1787       MaxVFWithoutSLForwardIssues !=
1788           VectorizerParams::MaxVectorWidth * TypeByteSize)
1789     MinDepDistBytes = MaxVFWithoutSLForwardIssues;
1790   return false;
1791 }
1792 
1793 void MemoryDepChecker::mergeInStatus(VectorizationSafetyStatus S) {
1794   if (Status < S)
1795     Status = S;
1796 }
1797 
1798 /// Given a dependence-distance \p Dist between two
1799 /// memory accesses that have strides in the same direction, whose absolute
1800 /// maximum stride value is given in \p MaxStride, and that have the same
1801 /// type size \p TypeByteSize, in a loop whose maximum backedge taken count is
1802 /// \p MaxBTC, check if it is possible to prove statically that the dependence
1803 /// distance is larger than the range that the accesses will travel through the
1804 /// execution of the loop. If so, return true; false otherwise. This is useful
1805 /// for example in loops such as the following (PR31098):
1806 ///     for (i = 0; i < D; ++i) {
1807 ///                = out[i];
1808 ///       out[i+D] =
1809 ///     }
1810 static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE,
1811                                      const SCEV &MaxBTC, const SCEV &Dist,
1812                                      uint64_t MaxStride,
1813                                      uint64_t TypeByteSize) {
1814 
1815   // If we can prove that
1816   //      (**) |Dist| > MaxBTC * Step
1817   // where Step is the absolute stride of the memory accesses in bytes,
1818   // then there is no dependence.
1819   //
1820   // Rationale:
1821   // We basically want to check if the absolute distance (|Dist/Step|)
1822   // is >= the loop iteration count (or > MaxBTC).
1823   // This is equivalent to the Strong SIV Test (Practical Dependence Testing,
1824   // Section 4.2.1); Note, that for vectorization it is sufficient to prove
1825   // that the dependence distance is >= VF; This is checked elsewhere.
1826   // But in some cases we can prune dependence distances early, and
1827   // even before selecting the VF, and without a runtime test, by comparing
1828   // the distance against the loop iteration count. Since the vectorized code
1829   // will be executed only if LoopCount >= VF, proving distance >= LoopCount
1830   // also guarantees that distance >= VF.
1831   //
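  // Worked example (a sketch): with MaxStride == 1, TypeByteSize == 4 and
  // MaxBTC == 99, Product is 396 bytes; if Dist is known to be at least 400
  // (or at most -400), Minus is known positive and the accesses are
  // independent.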
1832   const uint64_t ByteStride = MaxStride * TypeByteSize;
1833   const SCEV *Step = SE.getConstant(MaxBTC.getType(), ByteStride);
1834   const SCEV *Product = SE.getMulExpr(&MaxBTC, Step);
1835 
1836   const SCEV *CastedDist = &Dist;
1837   const SCEV *CastedProduct = Product;
1838   uint64_t DistTypeSizeBits = DL.getTypeSizeInBits(Dist.getType());
1839   uint64_t ProductTypeSizeBits = DL.getTypeSizeInBits(Product->getType());
1840 
1841   // The dependence distance can be positive/negative, so we sign extend Dist;
1842   // The multiplication of the absolute stride in bytes and the
1843   // backedgeTakenCount is non-negative, so we zero extend Product.
1844   if (DistTypeSizeBits > ProductTypeSizeBits)
1845     CastedProduct = SE.getZeroExtendExpr(Product, Dist.getType());
1846   else
1847     CastedDist = SE.getNoopOrSignExtend(&Dist, Product->getType());
1848 
1849   // Is  Dist - (MaxBTC * Step) > 0 ?
1850   // (If so, then we have proven (**) because |Dist| >= Dist)
1851   const SCEV *Minus = SE.getMinusSCEV(CastedDist, CastedProduct);
1852   if (SE.isKnownPositive(Minus))
1853     return true;
1854 
1855   // Second try: Is  -Dist - (MaxBTC * Step) > 0 ?
1856   // (If so, then we have proven (**) because |Dist| >= -1*Dist)
1857   const SCEV *NegDist = SE.getNegativeSCEV(CastedDist);
1858   Minus = SE.getMinusSCEV(NegDist, CastedProduct);
1859   return SE.isKnownPositive(Minus);
1860 }
1861 
1862 /// Check the dependence for two accesses with the same stride \p Stride.
1863 /// \p Distance is the positive distance and \p TypeByteSize is the type size in
1864 /// bytes.
1865 ///
1866 /// \returns true if they are independent.
1867 static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride,
1868                                           uint64_t TypeByteSize) {
1869   assert(Stride > 1 && "The stride must be greater than 1");
1870   assert(TypeByteSize > 0 && "The type size in byte must be non-zero");
1871   assert(Distance > 0 && "The distance must be non-zero");
1872 
1873   // Skip if the distance is not multiple of type byte size.
1874   if (Distance % TypeByteSize)
1875     return false;
1876 
1877   uint64_t ScaledDist = Distance / TypeByteSize;
1878 
1879   // No dependence if the scaled distance is not multiple of the stride.
1880   // E.g.
1881   //      for (i = 0; i < 1024 ; i += 4)
1882   //        A[i+2] = A[i] + 1;
1883   //
1884   // Two accesses in memory (scaled distance is 2, stride is 4):
1885   //     | A[0] |      |      |      | A[4] |      |      |      |
1886   //     |      |      | A[2] |      |      |      | A[6] |      |
1887   //
1888   // E.g.
1889   //      for (i = 0; i < 1024 ; i += 3)
1890   //        A[i+4] = A[i] + 1;
1891   //
1892   // Two accesses in memory (scaled distance is 4, stride is 3):
1893   //     | A[0] |      |      | A[3] |      |      | A[6] |      |      |
1894   //     |      |      |      |      | A[4] |      |      | A[7] |      |
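  //
  // A non-zero remainder means the two access streams fall into different
  // residue classes modulo the stride, so they can never touch the same
  // element; returning the remainder as a boolean therefore reports
  // independence.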
1895   return ScaledDist % Stride;
1896 }
1897 
1898 std::variant<MemoryDepChecker::Dependence::DepType,
1899              MemoryDepChecker::DepDistanceStrideAndSizeInfo>
1900 MemoryDepChecker::getDependenceDistanceStrideAndSize(
1901     const AccessAnalysis::MemAccessInfo &A, Instruction *AInst,
1902     const AccessAnalysis::MemAccessInfo &B, Instruction *BInst) {
1903   const auto &DL = InnermostLoop->getHeader()->getDataLayout();
1904   auto &SE = *PSE.getSE();
1905   auto [APtr, AIsWrite] = A;
1906   auto [BPtr, BIsWrite] = B;
1907 
1908   // Two reads are independent.
1909   if (!AIsWrite && !BIsWrite)
1910     return MemoryDepChecker::Dependence::NoDep;
1911 
1912   Type *ATy = getLoadStoreType(AInst);
1913   Type *BTy = getLoadStoreType(BInst);
1914 
1915   // We cannot check pointers in different address spaces.
1916   if (APtr->getType()->getPointerAddressSpace() !=
1917       BPtr->getType()->getPointerAddressSpace())
1918     return MemoryDepChecker::Dependence::Unknown;
1919 
1920   std::optional<int64_t> StrideAPtr =
1921       getPtrStride(PSE, ATy, APtr, InnermostLoop, SymbolicStrides, true, true);
1922   std::optional<int64_t> StrideBPtr =
1923       getPtrStride(PSE, BTy, BPtr, InnermostLoop, SymbolicStrides, true, true);
1924 
1925   const SCEV *Src = PSE.getSCEV(APtr);
1926   const SCEV *Sink = PSE.getSCEV(BPtr);
1927 
1928   // If the induction step is negative, we have to invert source and sink of the
1929   // dependence when measuring the distance between them. We should not swap
1930   // AIsWrite with BIsWrite, as their uses expect them in program order.
1931   if (StrideAPtr && *StrideAPtr < 0) {
1932     std::swap(Src, Sink);
1933     std::swap(AInst, BInst);
1934     std::swap(StrideAPtr, StrideBPtr);
1935   }
1936 
1937   const SCEV *Dist = SE.getMinusSCEV(Sink, Src);
1938 
1939   LLVM_DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink
1940                     << "\n");
1941   LLVM_DEBUG(dbgs() << "LAA: Distance for " << *AInst << " to " << *BInst
1942                     << ": " << *Dist << "\n");
1943 
1944   // Check if we can prove that Sink only accesses memory after Src's end or
1945   // vice versa. At the moment this is limited to cases where either source or
1946   // sink are loop invariant to avoid compile-time increases. This is not
1947   // required for correctness.
1948   if (SE.isLoopInvariant(Src, InnermostLoop) ||
1949       SE.isLoopInvariant(Sink, InnermostLoop)) {
1950     const auto &[SrcStart, SrcEnd] =
1951         getStartAndEndForAccess(InnermostLoop, Src, ATy, PSE, PointerBounds);
1952     const auto &[SinkStart, SinkEnd] =
1953         getStartAndEndForAccess(InnermostLoop, Sink, BTy, PSE, PointerBounds);
1954     if (!isa<SCEVCouldNotCompute>(SrcStart) &&
1955         !isa<SCEVCouldNotCompute>(SrcEnd) &&
1956         !isa<SCEVCouldNotCompute>(SinkStart) &&
1957         !isa<SCEVCouldNotCompute>(SinkEnd)) {
1958       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
1959         return MemoryDepChecker::Dependence::NoDep;
1960       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
1961         return MemoryDepChecker::Dependence::NoDep;
1962     }
1963   }
1964 
1965   // Need accesses with constant strides and the same direction for further
1966   // dependence analysis. We don't want to vectorize "A[B[i]] += ..." and
1967   // similar code or pointer arithmetic that could wrap in the address space.
1968 
1969   // If either Src or Sink are not strided (i.e. not a non-wrapping AddRec) and
1970   // not loop-invariant (stride will be 0 in that case), we cannot analyze the
1971   // dependence further and also cannot generate runtime checks.
1972   if (!StrideAPtr || !StrideBPtr) {
1973     LLVM_DEBUG(dbgs() << "Pointer access with non-constant stride\n");
1974     return MemoryDepChecker::Dependence::IndirectUnsafe;
1975   }
1976 
1977   int64_t StrideAPtrInt = *StrideAPtr;
1978   int64_t StrideBPtrInt = *StrideBPtr;
1979   LLVM_DEBUG(dbgs() << "LAA:  Src induction step: " << StrideAPtrInt
1980                     << " Sink induction step: " << StrideBPtrInt << "\n");
1981   // At least one of Src or Sink is loop invariant, and the other is strided or
1982   // invariant. We can generate a runtime check to disambiguate the accesses.
1983   if (StrideAPtrInt == 0 || StrideBPtrInt == 0)
1984     return MemoryDepChecker::Dependence::Unknown;
1985 
1986   // Both Src and Sink have a constant stride, check if they are in the same
1987   // direction.
1988   if ((StrideAPtrInt > 0 && StrideBPtrInt < 0) ||
1989       (StrideAPtrInt < 0 && StrideBPtrInt > 0)) {
1990     LLVM_DEBUG(
1991         dbgs() << "Pointer access with strides in different directions\n");
1992     return MemoryDepChecker::Dependence::Unknown;
1993   }
1994 
1995   uint64_t TypeByteSize = DL.getTypeAllocSize(ATy);
1996   bool HasSameSize =
1997       DL.getTypeStoreSizeInBits(ATy) == DL.getTypeStoreSizeInBits(BTy);
1998   if (!HasSameSize)
1999     TypeByteSize = 0;
2000   return DepDistanceStrideAndSizeInfo(Dist, std::abs(StrideAPtrInt),
2001                                       std::abs(StrideBPtrInt), TypeByteSize,
2002                                       AIsWrite, BIsWrite);
2003 }
2004 
2005 MemoryDepChecker::Dependence::DepType
2006 MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
2007                               const MemAccessInfo &B, unsigned BIdx) {
2008   assert(AIdx < BIdx && "Must pass arguments in program order");
2009 
2010   // Get the dependence distance, stride, type size and what access writes for
2011   // the dependence between A and B.
2012   auto Res =
2013       getDependenceDistanceStrideAndSize(A, InstMap[AIdx], B, InstMap[BIdx]);
2014   if (std::holds_alternative<Dependence::DepType>(Res))
2015     return std::get<Dependence::DepType>(Res);
2016 
2017   auto &[Dist, StrideA, StrideB, TypeByteSize, AIsWrite, BIsWrite] =
2018       std::get<DepDistanceStrideAndSizeInfo>(Res);
2019   bool HasSameSize = TypeByteSize > 0;
2020 
2021   std::optional<uint64_t> CommonStride =
2022       StrideA == StrideB ? std::make_optional(StrideA) : std::nullopt;
2023   if (isa<SCEVCouldNotCompute>(Dist)) {
2024     // TODO: Relax requirement that there is a common stride to retry with
2025     // non-constant distance dependencies.
2026     FoundNonConstantDistanceDependence |= CommonStride.has_value();
2027     LLVM_DEBUG(dbgs() << "LAA: Dependence because of uncomputable distance.\n");
2028     return Dependence::Unknown;
2029   }
2030 
2031   ScalarEvolution &SE = *PSE.getSE();
2032   auto &DL = InnermostLoop->getHeader()->getDataLayout();
2033   uint64_t MaxStride = std::max(StrideA, StrideB);
2034 
2035   // If the distance between the accesses is larger than their maximum absolute
2036   // stride multiplied by the symbolic maximum backedge taken count (which is an
2037   // upper bound of the number of iterations), the accesses are independent, i.e.
2038   // they are far enough apart that they won't access the same location
2039   // across all loop iterations.
2040   if (HasSameSize && isSafeDependenceDistance(
2041                          DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()),
2042                          *Dist, MaxStride, TypeByteSize))
2043     return Dependence::NoDep;
2044 
2045   const SCEVConstant *C = dyn_cast<SCEVConstant>(Dist);
2046 
2047   // Attempt to prove strided accesses independent.
2048   if (C) {
2049     const APInt &Val = C->getAPInt();
2050     int64_t Distance = Val.getSExtValue();
2051 
2052     // If the distance between accesses and their strides are known constants,
2053     // check whether the accesses interlace each other.
2054     if (std::abs(Distance) > 0 && CommonStride && *CommonStride > 1 &&
2055         HasSameSize &&
2056         areStridedAccessesIndependent(std::abs(Distance), *CommonStride,
2057                                       TypeByteSize)) {
2058       LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n");
2059       return Dependence::NoDep;
2060     }
2061   } else
2062     Dist = SE.applyLoopGuards(Dist, InnermostLoop);
2063 
2064   // Negative distances are not plausible dependencies.
2065   if (SE.isKnownNonPositive(Dist)) {
2066     if (SE.isKnownNonNegative(Dist)) {
2067       if (HasSameSize) {
2068         // Write to the same location with the same size.
2069         return Dependence::Forward;
2070       }
2071       LLVM_DEBUG(dbgs() << "LAA: possibly zero dependence difference but "
2072                            "different type sizes\n");
2073       return Dependence::Unknown;
2074     }
2075 
2076     bool IsTrueDataDependence = (AIsWrite && !BIsWrite);
2077     // Check if the first access writes to a location that is read in a later
2078     // iteration, where the distance between them is not a multiple of a vector
2079     // factor and relatively small.
2080     //
2081     // NOTE: There is no need to update MaxSafeVectorWidthInBits after call to
2082     // couldPreventStoreLoadForward, even if it changed MinDepDistBytes, since a
2083     // forward dependency will allow vectorization using any width.
2084 
2085     if (IsTrueDataDependence && EnableForwardingConflictDetection) {
2086       if (!C) {
2087         // TODO: FoundNonConstantDistanceDependence is used as a necessary
2088         // condition to consider retrying with runtime checks. Historically, we
2089         // did not set it when strides were different but there is no inherent
2090         // reason to.
2091         FoundNonConstantDistanceDependence |= CommonStride.has_value();
2092         return Dependence::Unknown;
2093       }
2094       if (!HasSameSize ||
2095           couldPreventStoreLoadForward(C->getAPInt().abs().getZExtValue(),
2096                                        TypeByteSize)) {
2097         LLVM_DEBUG(
2098             dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
2099         return Dependence::ForwardButPreventsForwarding;
2100       }
2101     }
2102 
2103     LLVM_DEBUG(dbgs() << "LAA: Dependence is negative\n");
2104     return Dependence::Forward;
2105   }
2106 
2107   int64_t MinDistance = SE.getSignedRangeMin(Dist).getSExtValue();
2108   // Below we only handle strictly positive distances.
2109   if (MinDistance <= 0) {
2110     FoundNonConstantDistanceDependence |= CommonStride.has_value();
2111     return Dependence::Unknown;
2112   }
2113 
2114   if (!isa<SCEVConstant>(Dist)) {
2115     // Previously this case would be treated as Unknown, possibly setting
2116     // FoundNonConstantDistanceDependence to force re-trying with runtime
2117     // checks. Until the TODO below is addressed, set it here to preserve
2118     // original behavior w.r.t. re-trying with runtime checks.
2119     // TODO: FoundNonConstantDistanceDependence is used as a necessary
2120     // condition to consider retrying with runtime checks. Historically, we
2121     // did not set it when strides were different but there is no inherent
2122     // reason to.
2123     FoundNonConstantDistanceDependence |= CommonStride.has_value();
2124   }
2125 
2126   if (!HasSameSize) {
2127     LLVM_DEBUG(dbgs() << "LAA: ReadWrite-Write positive dependency with "
2128                          "different type sizes\n");
2129     return Dependence::Unknown;
2130   }
2131 
2132   if (!CommonStride)
2133     return Dependence::Unknown;
2134 
2135   // Bail out early if passed-in parameters make vectorization not feasible.
2136   unsigned ForcedFactor = (VectorizerParams::VectorizationFactor ?
2137                            VectorizerParams::VectorizationFactor : 1);
2138   unsigned ForcedUnroll = (VectorizerParams::VectorizationInterleave ?
2139                            VectorizerParams::VectorizationInterleave : 1);
2140   // The minimum number of iterations for a vectorized/unrolled version.
2141   unsigned MinNumIter = std::max(ForcedFactor * ForcedUnroll, 2U);
2142 
2143   // It's not vectorizable if the distance is smaller than the minimum distance
2144   // needed for a vectorized/unrolled version. Vectorizing one iteration in
2145   // front needs TypeByteSize * Stride. Vectorizing the last iteration needs
2146   // TypeByteSize (no need to add the last gap distance).
2147   //
2148   // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
2149   //      foo(int *A) {
2150   //        int *B = (int *)((char *)A + 14);
2151   //        for (i = 0 ; i < 1024 ; i += 2)
2152   //          B[i] = A[i] + 1;
2153   //      }
2154   //
2155   // Two accesses in memory (stride is 2):
2156   //     | A[0] |      | A[2] |      | A[4] |      | A[6] |      |
2157   //                              | B[0] |      | B[2] |      | B[4] |
2158   //
2159   // MinDistance needed for vectorizing iterations except the last iteration:
2160   // 4 * 2 * (MinNumIter - 1). MinDistance needed for the last iteration: 4.
2161   // So the minimum distance needed is: 4 * 2 * (MinNumIter - 1) + 4.
2162   //
2163   // If MinNumIter is 2, it is vectorizable as the minimum distance needed is
2164   // 12, which is less than distance.
2165   //
2166   // If MinNumIter is 4 (Say if a user forces the vectorization factor to be 4),
2167   // the minimum distance needed is 28, which is greater than distance. It is
2168   // not safe to do vectorization.
2169 
2170   // We know that Dist is positive, but it may not be constant. Use the signed
2171   // minimum for computations below, as this ensures we compute the closest
2172   // possible dependence distance.
2173   uint64_t MinDistanceNeeded =
2174       TypeByteSize * *CommonStride * (MinNumIter - 1) + TypeByteSize;
2175   if (MinDistanceNeeded > static_cast<uint64_t>(MinDistance)) {
2176     if (!isa<SCEVConstant>(Dist)) {
2177       // For non-constant distances, we checked the lower bound of the
2178       // dependence distance and the distance may be larger at runtime (and safe
2179       // for vectorization). Classify it as Unknown, so we re-try with runtime
2180       // checks.
2181       return Dependence::Unknown;
2182     }
2183     LLVM_DEBUG(dbgs() << "LAA: Failure because of positive minimum distance "
2184                       << MinDistance << '\n');
2185     return Dependence::Backward;
2186   }
2187 
2188   // Unsafe if the minimum distance needed is greater than the smallest
2189   // dependence distance.
2190   if (MinDistanceNeeded > MinDepDistBytes) {
2191     LLVM_DEBUG(dbgs() << "LAA: Failure because it needs at least "
2192                       << MinDistanceNeeded << " size in bytes\n");
2193     return Dependence::Backward;
2194   }
2195 
2196   // Positive distance bigger than max vectorization factor.
2197   // FIXME: Should use max factor instead of max distance in bytes, which could
2198   // not handle different types.
2199   // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
2200   //      void foo (int *A, char *B) {
2201   //        for (unsigned i = 0; i < 1024; i++) {
2202   //          A[i+2] = A[i] + 1;
2203   //          B[i+2] = B[i] + 1;
2204   //        }
2205   //      }
2206   //
2207   // This case is currently unsafe according to the max safe distance. If we
2208   // analyze the two accesses on array B, the max safe dependence distance
2209   // is 2. Then we analyze the accesses on array A, where the minimum distance
2210   // needed is 8, which is greater than 2, so vectorization is forbidden. But
2211   // actually both A and B could be vectorized by 2 iterations.
2212   MinDepDistBytes =
2213       std::min(static_cast<uint64_t>(MinDistance), MinDepDistBytes);
2214 
2215   bool IsTrueDataDependence = (!AIsWrite && BIsWrite);
2216   uint64_t MinDepDistBytesOld = MinDepDistBytes;
2217   if (IsTrueDataDependence && EnableForwardingConflictDetection &&
2218       isa<SCEVConstant>(Dist) &&
2219       couldPreventStoreLoadForward(MinDistance, TypeByteSize)) {
2220     // Sanity check that we didn't update MinDepDistBytes when calling
2221     // couldPreventStoreLoadForward
2222     assert(MinDepDistBytes == MinDepDistBytesOld &&
2223            "An update to MinDepDistBytes requires an update to "
2224            "MaxSafeVectorWidthInBits");
2225     (void)MinDepDistBytesOld;
2226     return Dependence::BackwardVectorizableButPreventsForwarding;
2227   }
2228 
2229   // An update to MinDepDistBytes requires an update to MaxSafeVectorWidthInBits
2230   // since there is a backwards dependency.
2231   uint64_t MaxVF = MinDepDistBytes / (TypeByteSize * *CommonStride);
2232   LLVM_DEBUG(dbgs() << "LAA: Positive min distance " << MinDistance
2233                     << " with max VF = " << MaxVF << '\n');
2234 
2235   uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8;
2236   if (!isa<SCEVConstant>(Dist) && MaxVFInBits < MaxTargetVectorWidthInBits) {
2237     // For non-constant distances, we checked the lower bound of the dependence
2238     // distance and the distance may be larger at runtime (and safe for
2239     // vectorization). Classify it as Unknown, so we re-try with runtime checks.
2240     return Dependence::Unknown;
2241   }
2242 
2243   MaxSafeVectorWidthInBits = std::min(MaxSafeVectorWidthInBits, MaxVFInBits);
2244   return Dependence::BackwardVectorizable;
2245 }
2246 
2247 bool MemoryDepChecker::areDepsSafe(const DepCandidates &AccessSets,
2248                                    const MemAccessInfoList &CheckDeps) {
2249 
2250   MinDepDistBytes = -1;
2251   SmallPtrSet<MemAccessInfo, 8> Visited;
2252   for (MemAccessInfo CurAccess : CheckDeps) {
2253     if (Visited.count(CurAccess))
2254       continue;
2255 
2256     // Get the relevant memory access set.
2257     EquivalenceClasses<MemAccessInfo>::iterator I =
2258       AccessSets.findValue(AccessSets.getLeaderValue(CurAccess));
2259 
2260     // Check accesses within this set.
2261     EquivalenceClasses<MemAccessInfo>::member_iterator AI =
2262         AccessSets.member_begin(I);
2263     EquivalenceClasses<MemAccessInfo>::member_iterator AE =
2264         AccessSets.member_end();
2265 
2266     // Check every access pair.
2267     while (AI != AE) {
2268       Visited.insert(*AI);
2269       bool AIIsWrite = AI->getInt();
2270       // Check loads only against later members of the equivalence class, but
2271       // check stores also against other stores to the same address.
2272       EquivalenceClasses<MemAccessInfo>::member_iterator OI =
2273           (AIIsWrite ? AI : std::next(AI));
2274       while (OI != AE) {
2275         // Check every accessing instruction pair in program order.
2276         for (std::vector<unsigned>::iterator I1 = Accesses[*AI].begin(),
2277              I1E = Accesses[*AI].end(); I1 != I1E; ++I1)
2278           // Scan all accesses of another equivalence class, but only the later
2279           // accesses of the same equivalence class.
2280           for (std::vector<unsigned>::iterator
2281                    I2 = (OI == AI ? std::next(I1) : Accesses[*OI].begin()),
2282                    I2E = (OI == AI ? I1E : Accesses[*OI].end());
2283                I2 != I2E; ++I2) {
2284             auto A = std::make_pair(&*AI, *I1);
2285             auto B = std::make_pair(&*OI, *I2);
2286 
2287             assert(*I1 != *I2);
2288             if (*I1 > *I2)
2289               std::swap(A, B);
2290 
2291             Dependence::DepType Type =
2292                 isDependent(*A.first, A.second, *B.first, B.second);
2293             mergeInStatus(Dependence::isSafeForVectorization(Type));
2294 
2295             // Gather dependences unless we accumulated MaxDependences
2296             // dependences.  In that case return as soon as we find the first
2297             // unsafe dependence.  This puts a limit on this quadratic
2298             // algorithm.
2299             if (RecordDependences) {
2300               if (Type != Dependence::NoDep)
2301                 Dependences.push_back(Dependence(A.second, B.second, Type));
2302 
2303               if (Dependences.size() >= MaxDependences) {
2304                 RecordDependences = false;
2305                 Dependences.clear();
2306                 LLVM_DEBUG(dbgs()
2307                            << "Too many dependences, stopped recording\n");
2308               }
2309             }
2310             if (!RecordDependences && !isSafeForVectorization())
2311               return false;
2312           }
2313         ++OI;
2314       }
2315       ++AI;
2316     }
2317   }
2318 
2319   LLVM_DEBUG(dbgs() << "Total Dependences: " << Dependences.size() << "\n");
2320   return isSafeForVectorization();
2321 }
2322 
2323 SmallVector<Instruction *, 4>
2324 MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool IsWrite) const {
2325   MemAccessInfo Access(Ptr, IsWrite);
2326   auto &IndexVector = Accesses.find(Access)->second;
2327 
2328   SmallVector<Instruction *, 4> Insts;
2329   transform(IndexVector,
2330                  std::back_inserter(Insts),
2331                  [&](unsigned Idx) { return this->InstMap[Idx]; });
2332   return Insts;
2333 }
2334 
2335 const char *MemoryDepChecker::Dependence::DepName[] = {
2336     "NoDep",
2337     "Unknown",
2338     "IndirectUnsafe",
2339     "Forward",
2340     "ForwardButPreventsForwarding",
2341     "Backward",
2342     "BackwardVectorizable",
2343     "BackwardVectorizableButPreventsForwarding"};
2344 
2345 void MemoryDepChecker::Dependence::print(
2346     raw_ostream &OS, unsigned Depth,
2347     const SmallVectorImpl<Instruction *> &Instrs) const {
2348   OS.indent(Depth) << DepName[Type] << ":\n";
2349   OS.indent(Depth + 2) << *Instrs[Source] << " -> \n";
2350   OS.indent(Depth + 2) << *Instrs[Destination] << "\n";
2351 }
2352 
2353 bool LoopAccessInfo::canAnalyzeLoop() {
2354   // We need to have a loop header.
2355   LLVM_DEBUG(dbgs() << "\nLAA: Checking a loop in '"
2356                     << TheLoop->getHeader()->getParent()->getName() << "' from "
2357                     << TheLoop->getLocStr() << "\n");
2358 
2359   // We can only analyze innermost loops.
2360   if (!TheLoop->isInnermost()) {
2361     LLVM_DEBUG(dbgs() << "LAA: loop is not the innermost loop\n");
2362     recordAnalysis("NotInnerMostLoop") << "loop is not the innermost loop";
2363     return false;
2364   }
2365 
2366   // We must have a single backedge.
2367   if (TheLoop->getNumBackEdges() != 1) {
2368     LLVM_DEBUG(
2369         dbgs() << "LAA: loop control flow is not understood by analyzer\n");
2370     recordAnalysis("CFGNotUnderstood")
2371         << "loop control flow is not understood by analyzer";
2372     return false;
2373   }
2374 
2375   // ScalarEvolution needs to be able to find the symbolic max backedge taken
2376   // count, which is an upper bound on the number of loop iterations. The loop
2377   // may execute fewer iterations if it exits via an uncountable exit.
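  // For example, a loop with a countable latch condition (i < N) but an
  // additional data-dependent early exit (say, "if (A[i] == Key) break;") may
  // leave through the early exit before reaching the symbolic maximum.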
2378   const SCEV *ExitCount = PSE->getSymbolicMaxBackedgeTakenCount();
2379   if (isa<SCEVCouldNotCompute>(ExitCount)) {
2380     recordAnalysis("CantComputeNumberOfIterations")
2381         << "could not determine number of loop iterations";
2382     LLVM_DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n");
2383     return false;
2384   }
2385 
2386   LLVM_DEBUG(dbgs() << "LAA: Found an analyzable loop: "
2387                     << TheLoop->getHeader()->getName() << "\n");
2388   return true;
2389 }
2390 
2391 bool LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
2392                                  const TargetLibraryInfo *TLI,
2393                                  DominatorTree *DT) {
2394   // Holds the Load and Store instructions.
2395   SmallVector<LoadInst *, 16> Loads;
2396   SmallVector<StoreInst *, 16> Stores;
2397   SmallPtrSet<MDNode *, 8> LoopAliasScopes;
2398 
2399   // Holds all the different accesses in the loop.
2400   unsigned NumReads = 0;
2401   unsigned NumReadWrites = 0;
2402 
2403   bool HasComplexMemInst = false;
2404 
2405   // A runtime check is only legal to insert if there are no convergent calls.
2406   HasConvergentOp = false;
2407 
2408   PtrRtChecking->Pointers.clear();
2409   PtrRtChecking->Need = false;
2410 
2411   const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
2412 
2413   const bool EnableMemAccessVersioningOfLoop =
2414       EnableMemAccessVersioning &&
2415       !TheLoop->getHeader()->getParent()->hasOptSize();
2416 
2417   // Traverse blocks in fixed RPOT order, regardless of their storage in the
2418   // loop info, as it may be arbitrary.
2419   LoopBlocksRPO RPOT(TheLoop);
2420   RPOT.perform(LI);
2421   for (BasicBlock *BB : RPOT) {
2422     // Scan the BB and collect legal loads and stores. Also detect any
2423     // convergent instructions.
2424     for (Instruction &I : *BB) {
2425       if (auto *Call = dyn_cast<CallBase>(&I)) {
2426         if (Call->isConvergent())
2427           HasConvergentOp = true;
2428       }
2429 
2430       // With both a non-vectorizable memory instruction and a convergent
2431       // operation found in this loop, there is no reason to continue the search.
2432       if (HasComplexMemInst && HasConvergentOp)
2433         return false;
2434 
2435       // Avoid hitting recordAnalysis multiple times.
2436       if (HasComplexMemInst)
2437         continue;
2438 
2439       // Record alias scopes defined inside the loop.
2440       if (auto *Decl = dyn_cast<NoAliasScopeDeclInst>(&I))
2441         for (Metadata *Op : Decl->getScopeList()->operands())
2442           LoopAliasScopes.insert(cast<MDNode>(Op));
2443 
2444       // Many math library functions read the rounding mode. We will only
2445       // vectorize a loop if it contains known function calls that don't set
2446       // the flag. Therefore, it is safe to ignore this read from memory.
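      // For illustration (an assumed case, not from this code): a call to the
      // C library function 'sin' may be conservatively marked as reading
      // memory because of the FP environment; if TLI maps it to the
      // vectorizable 'llvm.sin' intrinsic, the call is simply skipped here.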
2447       auto *Call = dyn_cast<CallInst>(&I);
2448       if (Call && getVectorIntrinsicIDForCall(Call, TLI))
2449         continue;
2450 
2451       // If this is a load, save it. If this instruction can read from memory
2452       // but is not a load, then we quit. Notice that we don't handle function
2453       // calls that read or write.
2454       if (I.mayReadFromMemory()) {
2455         // If the function has an explicit vectorized counterpart, we can safely
2456         // assume that it can be vectorized.
2457         if (Call && !Call->isNoBuiltin() && Call->getCalledFunction() &&
2458             !VFDatabase::getMappings(*Call).empty())
2459           continue;
2460 
2461         auto *Ld = dyn_cast<LoadInst>(&I);
2462         if (!Ld) {
2463           recordAnalysis("CantVectorizeInstruction", Ld)
2464             << "instruction cannot be vectorized";
2465           HasComplexMemInst = true;
2466           continue;
2467         }
2468         if (!Ld->isSimple() && !IsAnnotatedParallel) {
2469           recordAnalysis("NonSimpleLoad", Ld)
2470               << "read with atomic ordering or volatile read";
2471           LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load.\n");
2472           HasComplexMemInst = true;
2473           continue;
2474         }
2475         NumLoads++;
2476         Loads.push_back(Ld);
2477         DepChecker->addAccess(Ld);
2478         if (EnableMemAccessVersioningOfLoop)
2479           collectStridedAccess(Ld);
2480         continue;
2481       }
2482 
2483       // Save 'store' instructions. Abort if other instructions write to memory.
2484       if (I.mayWriteToMemory()) {
2485         auto *St = dyn_cast<StoreInst>(&I);
2486         if (!St) {
2487           recordAnalysis("CantVectorizeInstruction", St)
2488               << "instruction cannot be vectorized";
2489           HasComplexMemInst = true;
2490           continue;
2491         }
2492         if (!St->isSimple() && !IsAnnotatedParallel) {
2493           recordAnalysis("NonSimpleStore", St)
2494               << "write with atomic ordering or volatile write";
2495           LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store.\n");
2496           HasComplexMemInst = true;
2497           continue;
2498         }
2499         NumStores++;
2500         Stores.push_back(St);
2501         DepChecker->addAccess(St);
2502         if (EnableMemAccessVersioningOfLoop)
2503           collectStridedAccess(St);
2504       }
2505     } // Next instr.
2506   } // Next block.
2507 
2508   if (HasComplexMemInst)
2509     return false;
2510 
2511   // Now we have two lists that hold the loads and the stores.
2512   // Next, we find the pointers that they use.
2513 
2514   // Check if we see any stores. If there are no stores, then we don't
2515   // care if the pointers are *restrict*.
2516   if (Stores.empty()) {
2517     LLVM_DEBUG(dbgs() << "LAA: Found a read-only loop!\n");
2518     return true;
2519   }
2520 
2521   MemoryDepChecker::DepCandidates DependentAccesses;
2522   AccessAnalysis Accesses(TheLoop, AA, LI, DependentAccesses, *PSE,
2523                           LoopAliasScopes);
2524 
2525   // Holds the analyzed pointers. We don't want to call getUnderlyingObjects
2526   // multiple times on the same object. If the ptr is accessed twice, once
2527   // for read and once for write, it will only appear once (on the write
2528   // list). This is okay, since we are going to check for conflicts between
2529   // writes and between reads and writes, but not between reads and reads.
2530   SmallSet<std::pair<Value *, Type *>, 16> Seen;
2531 
2532   // Record uniform store addresses to identify if we have multiple stores
2533   // to the same address.
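  // Illustrative case (assumed source, not from this code): a loop that stores
  // twice to the same loop-invariant location, e.g. '*q = a[i]; ... *q = b[i];'
  // with 'q' invariant, is detected below as a store-store dependence
  // involving a loop-invariant address.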
2534   SmallPtrSet<Value *, 16> UniformStores;
2535 
2536   for (StoreInst *ST : Stores) {
2537     Value *Ptr = ST->getPointerOperand();
2538 
2539     if (isInvariant(Ptr)) {
2540       // Record store instructions to loop invariant addresses
2541       StoresToInvariantAddresses.push_back(ST);
2542       HasStoreStoreDependenceInvolvingLoopInvariantAddress |=
2543           !UniformStores.insert(Ptr).second;
2544     }
2545 
2546     // If we did *not* see this pointer before, insert it into the read-write
2547     // list. At this phase it is only a 'write' list.
2548     Type *AccessTy = getLoadStoreType(ST);
2549     if (Seen.insert({Ptr, AccessTy}).second) {
2550       ++NumReadWrites;
2551 
2552       MemoryLocation Loc = MemoryLocation::get(ST);
2553       // The TBAA metadata could have a control dependency on the predication
2554       // condition, so we cannot rely on it when determining whether or not we
2555       // need runtime pointer checks.
2556       if (blockNeedsPredication(ST->getParent(), TheLoop, DT))
2557         Loc.AATags.TBAA = nullptr;
2558 
2559       visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop,
2560                     [&Accesses, AccessTy, Loc](Value *Ptr) {
2561                       MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr);
2562                       Accesses.addStore(NewLoc, AccessTy);
2563                     });
2564     }
2565   }
2566 
2567   if (IsAnnotatedParallel) {
2568     LLVM_DEBUG(
2569         dbgs() << "LAA: A loop annotated parallel, ignore memory dependency "
2570                << "checks.\n");
2571     return true;
2572   }
2573 
2574   for (LoadInst *LD : Loads) {
2575     Value *Ptr = LD->getPointerOperand();
2576     // If we did *not* see this pointer before, insert it into the
2577     // read list. If we *did* see it before, then it is already in
2578     // the read-write list. This allows us to vectorize expressions
2579     // such as A[i] += x, because the address of A[i] is a read-write
2580     // pointer. This only works if the index of A[i] is consecutive.
2581     // If the index is unknown (for example A[B[i]]), then we may
2582     // read a few words, modify, and write a few words, and some of the
2583     // words may be written to the same address.
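    // Illustrative contrast (assumed C-like loops, not from this code):
    //   for (i = 0; i < n; i++) A[i] += x;    // consecutive index: the read
    //                                         // and write hit the same slot.
    //   for (i = 0; i < n; i++) A[B[i]] += x; // indirect index: one lane may
    //                                         // read a slot another writes.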
2584     bool IsReadOnlyPtr = false;
2585     Type *AccessTy = getLoadStoreType(LD);
2586     if (Seen.insert({Ptr, AccessTy}).second ||
2587         !getPtrStride(*PSE, LD->getType(), Ptr, TheLoop, SymbolicStrides).value_or(0)) {
2588       ++NumReads;
2589       IsReadOnlyPtr = true;
2590     }
2591 
2592     // See if there is an unsafe dependency between a load from a uniform
2593     // address and a store to the same uniform address.
2594     if (UniformStores.count(Ptr)) {
2595       LLVM_DEBUG(dbgs() << "LAA: Found an unsafe dependency between a uniform "
2596                            "load and uniform store to the same address!\n");
2597       HasLoadStoreDependenceInvolvingLoopInvariantAddress = true;
2598     }
2599 
2600     MemoryLocation Loc = MemoryLocation::get(LD);
2601     // The TBAA metadata could have a control dependency on the predication
2602     // condition, so we cannot rely on it when determining whether or not we
2603     // need runtime pointer checks.
2604     if (blockNeedsPredication(LD->getParent(), TheLoop, DT))
2605       Loc.AATags.TBAA = nullptr;
2606 
2607     visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop,
2608                   [&Accesses, AccessTy, Loc, IsReadOnlyPtr](Value *Ptr) {
2609                     MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr);
2610                     Accesses.addLoad(NewLoc, AccessTy, IsReadOnlyPtr);
2611                   });
2612   }
2613 
2614   // If we write (or read-write) to a single destination and there are no
2615   // other reads in this loop, then it is safe to vectorize.
2616   if (NumReadWrites == 1 && NumReads == 0) {
2617     LLVM_DEBUG(dbgs() << "LAA: Found a write-only loop!\n");
2618     return true;
2619   }
2620 
2621   // Build dependence sets and check whether we need a runtime pointer bounds
2622   // check.
2623   Accesses.buildDependenceSets();
2624 
2625   // Find pointers with computable bounds. We are going to use this information
2626   // to place a runtime bound check.
2627   Value *UncomputablePtr = nullptr;
2628   bool CanDoRTIfNeeded =
2629       Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(), TheLoop,
2630                                SymbolicStrides, UncomputablePtr, false);
2631   if (!CanDoRTIfNeeded) {
2632     auto *I = dyn_cast_or_null<Instruction>(UncomputablePtr);
2633     recordAnalysis("CantIdentifyArrayBounds", I)
2634         << "cannot identify array bounds";
2635     LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find "
2636                       << "the array bounds.\n");
2637     return false;
2638   }
2639 
2640   LLVM_DEBUG(
2641     dbgs() << "LAA: May be able to perform a memory runtime check if needed.\n");
2642 
2643   bool DepsAreSafe = true;
2644   if (Accesses.isDependencyCheckNeeded()) {
2645     LLVM_DEBUG(dbgs() << "LAA: Checking memory dependencies\n");
2646     DepsAreSafe = DepChecker->areDepsSafe(DependentAccesses,
2647                                           Accesses.getDependenciesToCheck());
2648 
2649     if (!DepsAreSafe && DepChecker->shouldRetryWithRuntimeCheck()) {
2650       LLVM_DEBUG(dbgs() << "LAA: Retrying with memory checks\n");
2651 
2652       // Clear the dependency checks. We assume they are not needed.
2653       Accesses.resetDepChecks(*DepChecker);
2654 
2655       PtrRtChecking->reset();
2656       PtrRtChecking->Need = true;
2657 
2658       auto *SE = PSE->getSE();
2659       UncomputablePtr = nullptr;
2660       CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(
2661           *PtrRtChecking, SE, TheLoop, SymbolicStrides, UncomputablePtr, true);
2662 
2663       // Check that we found the bounds for the pointer.
2664       if (!CanDoRTIfNeeded) {
2665         auto *I = dyn_cast_or_null<Instruction>(UncomputablePtr);
2666         recordAnalysis("CantCheckMemDepsAtRunTime", I)
2667             << "cannot check memory dependencies at runtime";
2668         LLVM_DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n");
2669         return false;
2670       }
2671       DepsAreSafe = true;
2672     }
2673   }
2674 
2675   if (HasConvergentOp) {
2676     recordAnalysis("CantInsertRuntimeCheckWithConvergent")
2677         << "cannot add control dependency to convergent operation";
2678     LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because a runtime check "
2679                          "would be needed with a convergent operation\n");
2680     return false;
2681   }
2682 
2683   if (DepsAreSafe) {
2684     LLVM_DEBUG(
2685         dbgs() << "LAA: No unsafe dependent memory operations in loop.  We"
2686                << (PtrRtChecking->Need ? "" : " don't")
2687                << " need runtime memory checks.\n");
2688     return true;
2689   }
2690 
2691   emitUnsafeDependenceRemark();
2692   return false;
2693 }
2694 
2695 void LoopAccessInfo::emitUnsafeDependenceRemark() {
2696   const auto *Deps = getDepChecker().getDependences();
2697   if (!Deps)
2698     return;
2699   const auto *Found =
2700       llvm::find_if(*Deps, [](const MemoryDepChecker::Dependence &D) {
2701         return MemoryDepChecker::Dependence::isSafeForVectorization(D.Type) !=
2702                MemoryDepChecker::VectorizationSafetyStatus::Safe;
2703       });
2704   if (Found == Deps->end())
2705     return;
2706   MemoryDepChecker::Dependence Dep = *Found;
2707 
2708   LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n");
2709 
2710   // Emit remark for first unsafe dependence
2711   bool HasForcedDistribution = false;
2712   std::optional<const MDOperand *> Value =
2713       findStringMetadataForLoop(TheLoop, "llvm.loop.distribute.enable");
2714   if (Value) {
2715     const MDOperand *Op = *Value;
2716     assert(Op && mdconst::hasa<ConstantInt>(*Op) && "invalid metadata");
2717     HasForcedDistribution = mdconst::extract<ConstantInt>(*Op)->getZExtValue();
2718   }
2719 
2720   const std::string Info =
2721       HasForcedDistribution
2722           ? "unsafe dependent memory operations in loop."
2723           : "unsafe dependent memory operations in loop. Use "
2724             "#pragma clang loop distribute(enable) to allow loop distribution "
2725             "to attempt to isolate the offending operations into a separate "
2726             "loop";
2727   OptimizationRemarkAnalysis &R =
2728       recordAnalysis("UnsafeDep", Dep.getDestination(getDepChecker())) << Info;
2729 
2730   switch (Dep.Type) {
2731   case MemoryDepChecker::Dependence::NoDep:
2732   case MemoryDepChecker::Dependence::Forward:
2733   case MemoryDepChecker::Dependence::BackwardVectorizable:
2734     llvm_unreachable("Unexpected dependence");
2735   case MemoryDepChecker::Dependence::Backward:
2736     R << "\nBackward loop carried data dependence.";
2737     break;
2738   case MemoryDepChecker::Dependence::ForwardButPreventsForwarding:
2739     R << "\nForward loop carried data dependence that prevents "
2740          "store-to-load forwarding.";
2741     break;
2742   case MemoryDepChecker::Dependence::BackwardVectorizableButPreventsForwarding:
2743     R << "\nBackward loop carried data dependence that prevents "
2744          "store-to-load forwarding.";
2745     break;
2746   case MemoryDepChecker::Dependence::IndirectUnsafe:
2747     R << "\nUnsafe indirect dependence.";
2748     break;
2749   case MemoryDepChecker::Dependence::Unknown:
2750     R << "\nUnknown data dependence.";
2751     break;
2752   }
2753 
2754   if (Instruction *I = Dep.getSource(getDepChecker())) {
2755     DebugLoc SourceLoc = I->getDebugLoc();
2756     if (auto *DD = dyn_cast_or_null<Instruction>(getPointerOperand(I)))
2757       SourceLoc = DD->getDebugLoc();
2758     if (SourceLoc)
2759       R << " Memory location is the same as accessed at "
2760         << ore::NV("Location", SourceLoc);
2761   }
2762 }
2763 
2764 bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop,
2765                                            DominatorTree *DT)  {
2766   assert(TheLoop->contains(BB) && "Unknown block used");
2767 
2768   // Blocks that do not dominate the latch need predication.
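  // For example (illustrative): in a body like 'if (c) A[i] = 0;', the
  // conditional block does not dominate the latch, so its accesses must be
  // predicated when the loop is vectorized.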
2769   BasicBlock *Latch = TheLoop->getLoopLatch();
2770   return !DT->dominates(BB, Latch);
2771 }
2772 
2773 OptimizationRemarkAnalysis &LoopAccessInfo::recordAnalysis(StringRef RemarkName,
2774                                                            Instruction *I) {
2775   assert(!Report && "Multiple reports generated");
2776 
2777   Value *CodeRegion = TheLoop->getHeader();
2778   DebugLoc DL = TheLoop->getStartLoc();
2779 
2780   if (I) {
2781     CodeRegion = I->getParent();
2782     // If there is no debug location attached to the instruction, fall back to
2783     // using the loop's start location.
2784     if (I->getDebugLoc())
2785       DL = I->getDebugLoc();
2786   }
2787 
2788   Report = std::make_unique<OptimizationRemarkAnalysis>(DEBUG_TYPE, RemarkName, DL,
2789                                                    CodeRegion);
2790   return *Report;
2791 }
2792 
2793 bool LoopAccessInfo::isInvariant(Value *V) const {
2794   auto *SE = PSE->getSE();
2795   // TODO: Is this really what we want? Even without FP SCEV, we may want some
2796   // trivially loop-invariant FP values to be considered invariant.
2797   if (!SE->isSCEVable(V->getType()))
2798     return false;
2799   const SCEV *S = SE->getSCEV(V);
2800   return SE->isLoopInvariant(S, TheLoop);
2801 }
2802 
2803 /// Find the operand of the GEP that should be checked for consecutive
2804 /// stores. This ignores trailing indices that have no effect on the final
2805 /// pointer.
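/// For illustration (an assumed example, not from this code): in a GEP such as
///   getelementptr [1 x float], ptr %A, i64 %i, i64 0
/// the trailing zero index does not affect the final address and its element
/// size matches the GEP's result allocation size, so it is peeled off and the
/// position of %i is returned.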
2806 static unsigned getGEPInductionOperand(const GetElementPtrInst *Gep) {
2807   const DataLayout &DL = Gep->getDataLayout();
2808   unsigned LastOperand = Gep->getNumOperands() - 1;
2809   TypeSize GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());
2810 
2811   // Walk backwards and try to peel off zeros.
2812   while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
2813     // Find the type we're currently indexing into.
2814     gep_type_iterator GEPTI = gep_type_begin(Gep);
2815     std::advance(GEPTI, LastOperand - 2);
2816 
2817     // If it's a type with the same allocation size as the result of the GEP, we
2818     // can peel off the zero index.
2819     TypeSize ElemSize = GEPTI.isStruct()
2820                             ? DL.getTypeAllocSize(GEPTI.getIndexedType())
2821                             : GEPTI.getSequentialElementStride(DL);
2822     if (ElemSize != GEPAllocSize)
2823       break;
2824     --LastOperand;
2825   }
2826 
2827   return LastOperand;
2828 }
2829 
2830 /// If the argument is a GEP, then returns the operand identified by
2831 /// getGEPInductionOperand. However, if there is some other non-loop-invariant
2832 /// operand, it returns that instead.
2833 static Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
2834   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
2835   if (!GEP)
2836     return Ptr;
2837 
2838   unsigned InductionOperand = getGEPInductionOperand(GEP);
2839 
2840   // Check that all of the gep indices are uniform except for our induction
2841   // operand.
2842   for (unsigned I = 0, E = GEP->getNumOperands(); I != E; ++I)
2843     if (I != InductionOperand &&
2844         !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(I)), Lp))
2845       return Ptr;
2846   return GEP->getOperand(InductionOperand);
2847 }
2848 
2849 /// Get the stride of a pointer access in a loop. Looks for symbolic
2850 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
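/// For illustration (an assumed C-like pattern, not from this code):
///   void scale(float *a, long stride, long n) {
///     for (long i = 0; i < n; i++)
///       a[i * stride] *= 2.0f;   // 'stride' is loop-invariant but unknown
///   }
/// Here the returned SCEV is the loop-invariant symbolic value 'stride',
/// which callers may later speculate to equal 1 via loop versioning.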
2851 static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
2852   auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
2853   if (!PtrTy || PtrTy->isAggregateType())
2854     return nullptr;
2855 
2856   // Try to remove a gep instruction to make the pointer (actually the index at
2857   // this point) easier to analyze. If OrigPtr is equal to Ptr, we are analyzing
2858   // the pointer; otherwise, we are analyzing the index.
2859   Value *OrigPtr = Ptr;
2860 
2861   // The size of the pointer access.
2862   int64_t PtrAccessSize = 1;
2863 
2864   Ptr = stripGetElementPtr(Ptr, SE, Lp);
2865   const SCEV *V = SE->getSCEV(Ptr);
2866 
2867   if (Ptr != OrigPtr)
2868     // Strip off casts.
2869     while (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V))
2870       V = C->getOperand();
2871 
2872   const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
2873   if (!S)
2874     return nullptr;
2875 
2876   // If the pointer is invariant then there is no stride and it makes no
2877   // sense to add it here.
2878   if (Lp != S->getLoop())
2879     return nullptr;
2880 
2881   V = S->getStepRecurrence(*SE);
2882   if (!V)
2883     return nullptr;
2884 
2885   // Strip off the size of access multiplication if we are still analyzing the
2886   // pointer.
2887   if (OrigPtr == Ptr) {
2888     if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
2889       if (M->getOperand(0)->getSCEVType() != scConstant)
2890         return nullptr;
2891 
2892       const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt();
2893 
2894       // Huge step value - give up.
2895       if (APStepVal.getBitWidth() > 64)
2896         return nullptr;
2897 
2898       int64_t StepVal = APStepVal.getSExtValue();
2899       if (PtrAccessSize != StepVal)
2900         return nullptr;
2901       V = M->getOperand(1);
2902     }
2903   }
2904 
2905   // Note that the restrictions after this loop-invariant check are only
2906   // profitability restrictions.
2907   if (!SE->isLoopInvariant(V, Lp))
2908     return nullptr;
2909 
2910   // Look for the loop invariant symbolic value.
2911   if (isa<SCEVUnknown>(V))
2912     return V;
2913 
2914   if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(V))
2915     if (isa<SCEVUnknown>(C->getOperand()))
2916       return V;
2917 
2918   return nullptr;
2919 }
2920 
2921 void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
2922   Value *Ptr = getLoadStorePointerOperand(MemAccess);
2923   if (!Ptr)
2924     return;
2925 
2926   // Note: getStrideFromPointer is a *profitability* heuristic.  We
2927   // could broaden the scope of values returned here - to anything
2928   // which happens to be loop invariant and contributes to the
2929   // computation of an interesting IV - but we chose not to as we
2930   // don't have a cost model here, and broadening the scope exposes
2931   // far too many unprofitable cases.
2932   const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
2933   if (!StrideExpr)
2934     return;
2935 
2936   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
2937                        "versioning:");
2938   LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n");
2939 
2940   if (!SpeculateUnitStride) {
2941     LLVM_DEBUG(dbgs() << "  Chose not to due to -laa-speculate-unit-stride\n");
2942     return;
2943   }
2944 
2945   // Avoid adding the "Stride == 1" predicate when we know that
2946   // Stride >= Trip-Count. Such a predicate will effectively optimize a single
2947   // or zero iteration loop, as Trip-Count <= Stride == 1.
2948   //
2949   // TODO: We are currently not making a very informed decision on when it is
2950   // beneficial to apply stride versioning. It might make more sense that the
2951   // users of this analysis (such as the vectorizer) will trigger it, based on
2952   // their specific cost considerations; For example, in cases where stride
2953   // versioning does not help resolve memory accesses/dependences, the
2954   // vectorizer should evaluate the cost of the runtime test, and the benefit
2955   // of various possible stride specializations, considering the alternatives
2956   // of using gather/scatters (if available).
2957 
2958   const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
2959 
2960   // Match the types so we can compare the stride and the MaxBTC.
2961   // The Stride can be positive/negative, so we sign extend Stride;
2962   // the backedge-taken count is non-negative, so we zero extend MaxBTC.
2963   const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
2964   uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
2965   uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
2966   const SCEV *CastedStride = StrideExpr;
2967   const SCEV *CastedBECount = MaxBTC;
2968   ScalarEvolution *SE = PSE->getSE();
2969   if (BETypeSizeBits >= StrideTypeSizeBits)
2970     CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
2971   else
2972     CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
2973   const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
2974   // Since TripCount == BackEdgeTakenCount + 1, checking:
2975   // "Stride >= TripCount" is equivalent to checking:
2976   // Stride - MaxBTC > 0
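  // Worked instance (illustrative numbers): if MaxBTC is 7 (trip count 8) and
  // the stride is known to be at least 8, then Stride - MaxBTC >= 1 > 0, so
  // a version specialized for Stride == 1 could run at most one iteration.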
2977   if (SE->isKnownPositive(StrideMinusBETaken)) {
2978     LLVM_DEBUG(
2979         dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
2980                   "Stride==1 predicate will imply that the loop executes "
2981                   "at most once.\n");
2982     return;
2983   }
2984   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
2985 
2986   // Strip back off the integer cast, and check that our result is a
2987   // SCEVUnknown as we expect.
2988   const SCEV *StrideBase = StrideExpr;
2989   if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
2990     StrideBase = C->getOperand();
2991   SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
2992 }
2993 
2994 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
2995                                const TargetTransformInfo *TTI,
2996                                const TargetLibraryInfo *TLI, AAResults *AA,
2997                                DominatorTree *DT, LoopInfo *LI)
2998     : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
2999       PtrRtChecking(nullptr), TheLoop(L) {
3000   unsigned MaxTargetVectorWidthInBits = std::numeric_limits<unsigned>::max();
3001   if (TTI) {
3002     TypeSize FixedWidth =
3003         TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
3004     if (FixedWidth.isNonZero()) {
3005       // Scale the vector width by 2 as a rough estimate to also account for
3006       // interleaving.
3007       MaxTargetVectorWidthInBits = FixedWidth.getFixedValue() * 2;
3008     }
3009 
3010     TypeSize ScalableWidth =
3011         TTI->getRegisterBitWidth(TargetTransformInfo::RGK_ScalableVector);
3012     if (ScalableWidth.isNonZero())
3013       MaxTargetVectorWidthInBits = std::numeric_limits<unsigned>::max();
3014   }
3015   DepChecker = std::make_unique<MemoryDepChecker>(*PSE, L, SymbolicStrides,
3016                                                   MaxTargetVectorWidthInBits);
3017   PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE);
3018   if (canAnalyzeLoop())
3019     CanVecMem = analyzeLoop(AA, LI, TLI, DT);
3020 }
3021 
3022 void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
3023   if (CanVecMem) {
3024     OS.indent(Depth) << "Memory dependences are safe";
3025     const MemoryDepChecker &DC = getDepChecker();
3026     if (!DC.isSafeForAnyVectorWidth())
3027       OS << " with a maximum safe vector width of "
3028          << DC.getMaxSafeVectorWidthInBits() << " bits";
3029     if (PtrRtChecking->Need)
3030       OS << " with run-time checks";
3031     OS << "\n";
3032   }
3033 
3034   if (HasConvergentOp)
3035     OS.indent(Depth) << "Has convergent operation in loop\n";
3036 
3037   if (Report)
3038     OS.indent(Depth) << "Report: " << Report->getMsg() << "\n";
3039 
3040   if (auto *Dependences = DepChecker->getDependences()) {
3041     OS.indent(Depth) << "Dependences:\n";
3042     for (const auto &Dep : *Dependences) {
3043       Dep.print(OS, Depth + 2, DepChecker->getMemoryInstructions());
3044       OS << "\n";
3045     }
3046   } else
3047     OS.indent(Depth) << "Too many dependences, not recorded\n";
3048 
3049   // List the pairs of accesses needing run-time checks to prove independence.
3050   PtrRtChecking->print(OS, Depth);
3051   OS << "\n";
3052 
3053   OS.indent(Depth)
3054       << "Non vectorizable stores to invariant address were "
3055       << (HasStoreStoreDependenceInvolvingLoopInvariantAddress ||
3056                   HasLoadStoreDependenceInvolvingLoopInvariantAddress
3057               ? ""
3058               : "not ")
3059       << "found in loop.\n";
3060 
3061   OS.indent(Depth) << "SCEV assumptions:\n";
3062   PSE->getPredicate().print(OS, Depth);
3063 
3064   OS << "\n";
3065 
3066   OS.indent(Depth) << "Expressions re-written:\n";
3067   PSE->print(OS, Depth);
3068 }
3069 
3070 const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L) {
3071   auto [It, Inserted] = LoopAccessInfoMap.insert({&L, nullptr});
3072 
3073   if (Inserted)
3074     It->second =
3075         std::make_unique<LoopAccessInfo>(&L, &SE, TTI, TLI, &AA, &DT, &LI);
3076 
3077   return *It->second;
3078 }
3079 void LoopAccessInfoManager::clear() {
3080   SmallVector<Loop *> ToRemove;
3081   // Collect LoopAccessInfo entries that may keep references to IR outside the
3082   // analyzed loop or SCEVs that may have been modified or invalidated. At the
3083   // moment, that is loops requiring memory or SCEV runtime checks, as those cache
3084   // SCEVs, e.g. for pointer expressions.
3085   for (const auto &[L, LAI] : LoopAccessInfoMap) {
3086     if (LAI->getRuntimePointerChecking()->getChecks().empty() &&
3087         LAI->getPSE().getPredicate().isAlwaysTrue())
3088       continue;
3089     ToRemove.push_back(L);
3090   }
3091 
3092   for (Loop *L : ToRemove)
3093     LoopAccessInfoMap.erase(L);
3094 }
3095 
3096 bool LoopAccessInfoManager::invalidate(
3097     Function &F, const PreservedAnalyses &PA,
3098     FunctionAnalysisManager::Invalidator &Inv) {
3099   // Check whether our analysis is preserved.
3100   auto PAC = PA.getChecker<LoopAccessAnalysis>();
3101   if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>())
3102     // If not, give up now.
3103     return true;
3104 
3105   // Check whether the analyses we depend on became invalid for any reason.
3106   // Skip checking TargetLibraryAnalysis as it is immutable and can't become
3107   // invalid.
3108   return Inv.invalidate<AAManager>(F, PA) ||
3109          Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) ||
3110          Inv.invalidate<LoopAnalysis>(F, PA) ||
3111          Inv.invalidate<DominatorTreeAnalysis>(F, PA);
3112 }
3113 
3114 LoopAccessInfoManager LoopAccessAnalysis::run(Function &F,
3115                                               FunctionAnalysisManager &FAM) {
3116   auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
3117   auto &AA = FAM.getResult<AAManager>(F);
3118   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
3119   auto &LI = FAM.getResult<LoopAnalysis>(F);
3120   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
3121   auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
3122   return LoopAccessInfoManager(SE, AA, DT, LI, &TTI, &TLI);
3123 }
3124 
3125 AnalysisKey LoopAccessAnalysis::Key;
3126