xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp (revision 5b56413d04e608379c9a306373554a8e4d321bc0)
1 //===- LoopDistribute.cpp - Loop Distribution Pass ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Loop Distribution Pass.  Its main focus is to
10 // distribute loops that cannot be vectorized due to dependence cycles.  It
11 // tries to isolate the offending dependences into a new loop allowing
12 // vectorization of the remaining parts.
13 //
14 // For dependence analysis, the pass uses the LoopVectorizer's
15 // LoopAccessAnalysis.  Because this analysis presumes no change in the order of
16 // memory operations, special care is taken to preserve the lexical order of
17 // these operations.
18 //
19 // Similarly to the Vectorizer, the pass also supports loop versioning to
20 // run-time disambiguate potentially overlapping arrays.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "llvm/Transforms/Scalar/LoopDistribute.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/DepthFirstIterator.h"
27 #include "llvm/ADT/EquivalenceClasses.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallPtrSet.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/Statistic.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/Twine.h"
34 #include "llvm/ADT/iterator_range.h"
35 #include "llvm/Analysis/AssumptionCache.h"
36 #include "llvm/Analysis/GlobalsModRef.h"
37 #include "llvm/Analysis/LoopAccessAnalysis.h"
38 #include "llvm/Analysis/LoopAnalysisManager.h"
39 #include "llvm/Analysis/LoopInfo.h"
40 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
41 #include "llvm/Analysis/ScalarEvolution.h"
42 #include "llvm/Analysis/TargetLibraryInfo.h"
43 #include "llvm/Analysis/TargetTransformInfo.h"
44 #include "llvm/IR/BasicBlock.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DiagnosticInfo.h"
47 #include "llvm/IR/Dominators.h"
48 #include "llvm/IR/Function.h"
49 #include "llvm/IR/Instruction.h"
50 #include "llvm/IR/Instructions.h"
51 #include "llvm/IR/LLVMContext.h"
52 #include "llvm/IR/Metadata.h"
53 #include "llvm/IR/PassManager.h"
54 #include "llvm/IR/Value.h"
55 #include "llvm/Support/Casting.h"
56 #include "llvm/Support/CommandLine.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/raw_ostream.h"
59 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
60 #include "llvm/Transforms/Utils/Cloning.h"
61 #include "llvm/Transforms/Utils/LoopUtils.h"
62 #include "llvm/Transforms/Utils/LoopVersioning.h"
63 #include "llvm/Transforms/Utils/ValueMapper.h"
64 #include <cassert>
65 #include <functional>
66 #include <list>
67 #include <tuple>
68 #include <utility>
69 
70 using namespace llvm;
71 
72 #define LDIST_NAME "loop-distribute"
73 #define DEBUG_TYPE LDIST_NAME
74 
75 /// @{
76 /// Metadata attribute names
77 static const char *const LLVMLoopDistributeFollowupAll =
78     "llvm.loop.distribute.followup_all";
79 static const char *const LLVMLoopDistributeFollowupCoincident =
80     "llvm.loop.distribute.followup_coincident";
81 static const char *const LLVMLoopDistributeFollowupSequential =
82     "llvm.loop.distribute.followup_sequential";
83 static const char *const LLVMLoopDistributeFollowupFallback =
84     "llvm.loop.distribute.followup_fallback";
85 /// @}
86 
87 static cl::opt<bool>
88     LDistVerify("loop-distribute-verify", cl::Hidden,
89                 cl::desc("Turn on DominatorTree and LoopInfo verification "
90                          "after Loop Distribution"),
91                 cl::init(false));
92 
93 static cl::opt<bool> DistributeNonIfConvertible(
94     "loop-distribute-non-if-convertible", cl::Hidden,
95     cl::desc("Whether to distribute into a loop that may not be "
96              "if-convertible by the loop vectorizer"),
97     cl::init(false));
98 
99 static cl::opt<unsigned> DistributeSCEVCheckThreshold(
100     "loop-distribute-scev-check-threshold", cl::init(8), cl::Hidden,
101     cl::desc("The maximum number of SCEV checks allowed for Loop "
102              "Distribution"));
103 
104 static cl::opt<unsigned> PragmaDistributeSCEVCheckThreshold(
105     "loop-distribute-scev-check-threshold-with-pragma", cl::init(128),
106     cl::Hidden,
107     cl::desc("The maximum number of SCEV checks allowed for Loop "
108              "Distribution for loop marked with #pragma clang loop "
109              "distribute(enable)"));
110 
111 static cl::opt<bool> EnableLoopDistribute(
112     "enable-loop-distribute", cl::Hidden,
113     cl::desc("Enable the new, experimental LoopDistribution Pass"),
114     cl::init(false));
115 
116 STATISTIC(NumLoopsDistributed, "Number of loops distributed");
117 
118 namespace {
119 
120 /// Maintains the set of instructions of the loop for a partition before
121 /// cloning.  After cloning, it hosts the new loop.
122 class InstPartition {
123   using InstructionSet = SmallPtrSet<Instruction *, 8>;
124 
125 public:
126   InstPartition(Instruction *I, Loop *L, bool DepCycle = false)
127       : DepCycle(DepCycle), OrigLoop(L) {
128     Set.insert(I);
129   }
130 
131   /// Returns whether this partition contains a dependence cycle.
132   bool hasDepCycle() const { return DepCycle; }
133 
134   /// Adds an instruction to this partition.
135   void add(Instruction *I) { Set.insert(I); }
136 
137   /// Collection accessors.
138   InstructionSet::iterator begin() { return Set.begin(); }
139   InstructionSet::iterator end() { return Set.end(); }
140   InstructionSet::const_iterator begin() const { return Set.begin(); }
141   InstructionSet::const_iterator end() const { return Set.end(); }
142   bool empty() const { return Set.empty(); }
143 
144   /// Moves this partition into \p Other.  This partition becomes empty
145   /// after this.
146   void moveTo(InstPartition &Other) {
147     Other.Set.insert(Set.begin(), Set.end());
148     Set.clear();
149     Other.DepCycle |= DepCycle;
150   }
151 
152   /// Populates the partition with a transitive closure of all the
153   /// instructions that the seeded instructions dependent on.
154   void populateUsedSet() {
155     // FIXME: We currently don't use control-dependence but simply include all
156     // blocks (possibly empty at the end) and let simplifycfg mostly clean this
157     // up.
158     for (auto *B : OrigLoop->getBlocks())
159       Set.insert(B->getTerminator());
160 
161     // Follow the use-def chains to form a transitive closure of all the
162     // instructions that the originally seeded instructions depend on.
163     SmallVector<Instruction *, 8> Worklist(Set.begin(), Set.end());
164     while (!Worklist.empty()) {
165       Instruction *I = Worklist.pop_back_val();
166       // Insert instructions from the loop that we depend on.
167       for (Value *V : I->operand_values()) {
168         auto *I = dyn_cast<Instruction>(V);
169         if (I && OrigLoop->contains(I->getParent()) && Set.insert(I).second)
170           Worklist.push_back(I);
171       }
172     }
173   }
174 
175   /// Clones the original loop.
176   ///
177   /// Updates LoopInfo and DominatorTree using the information that block \p
178   /// LoopDomBB dominates the loop.
179   Loop *cloneLoopWithPreheader(BasicBlock *InsertBefore, BasicBlock *LoopDomBB,
180                                unsigned Index, LoopInfo *LI,
181                                DominatorTree *DT) {
182     ClonedLoop = ::cloneLoopWithPreheader(InsertBefore, LoopDomBB, OrigLoop,
183                                           VMap, Twine(".ldist") + Twine(Index),
184                                           LI, DT, ClonedLoopBlocks);
185     return ClonedLoop;
186   }
187 
188   /// The cloned loop.  If this partition is mapped to the original loop,
189   /// this is null.
190   const Loop *getClonedLoop() const { return ClonedLoop; }
191 
192   /// Returns the loop where this partition ends up after distribution.
193   /// If this partition is mapped to the original loop then use the block from
194   /// the loop.
195   Loop *getDistributedLoop() const {
196     return ClonedLoop ? ClonedLoop : OrigLoop;
197   }
198 
199   /// The VMap that is populated by cloning and then used in
200   /// remapinstruction to remap the cloned instructions.
201   ValueToValueMapTy &getVMap() { return VMap; }
202 
203   /// Remaps the cloned instructions using VMap.
204   void remapInstructions() {
205     remapInstructionsInBlocks(ClonedLoopBlocks, VMap);
206   }
207 
208   /// Based on the set of instructions selected for this partition,
209   /// removes the unnecessary ones.
210   void removeUnusedInsts() {
211     SmallVector<Instruction *, 8> Unused;
212 
213     for (auto *Block : OrigLoop->getBlocks())
214       for (auto &Inst : *Block)
215         if (!Set.count(&Inst)) {
216           Instruction *NewInst = &Inst;
217           if (!VMap.empty())
218             NewInst = cast<Instruction>(VMap[NewInst]);
219 
220           assert(!isa<BranchInst>(NewInst) &&
221                  "Branches are marked used early on");
222           Unused.push_back(NewInst);
223         }
224 
225     // Delete the instructions backwards, as it has a reduced likelihood of
226     // having to update as many def-use and use-def chains.
227     for (auto *Inst : reverse(Unused)) {
228       if (!Inst->use_empty())
229         Inst->replaceAllUsesWith(PoisonValue::get(Inst->getType()));
230       Inst->eraseFromParent();
231     }
232   }
233 
234   void print() const {
235     if (DepCycle)
236       dbgs() << "  (cycle)\n";
237     for (auto *I : Set)
238       // Prefix with the block name.
239       dbgs() << "  " << I->getParent()->getName() << ":" << *I << "\n";
240   }
241 
242   void printBlocks() const {
243     for (auto *BB : getDistributedLoop()->getBlocks())
244       dbgs() << *BB;
245   }
246 
247 private:
248   /// Instructions from OrigLoop selected for this partition.
249   InstructionSet Set;
250 
251   /// Whether this partition contains a dependence cycle.
252   bool DepCycle;
253 
254   /// The original loop.
255   Loop *OrigLoop;
256 
257   /// The cloned loop.  If this partition is mapped to the original loop,
258   /// this is null.
259   Loop *ClonedLoop = nullptr;
260 
261   /// The blocks of ClonedLoop including the preheader.  If this
262   /// partition is mapped to the original loop, this is empty.
263   SmallVector<BasicBlock *, 8> ClonedLoopBlocks;
264 
265   /// These gets populated once the set of instructions have been
266   /// finalized. If this partition is mapped to the original loop, these are not
267   /// set.
268   ValueToValueMapTy VMap;
269 };
270 
271 /// Holds the set of Partitions.  It populates them, merges them and then
272 /// clones the loops.
273 class InstPartitionContainer {
274   using InstToPartitionIdT = DenseMap<Instruction *, int>;
275 
276 public:
277   InstPartitionContainer(Loop *L, LoopInfo *LI, DominatorTree *DT)
278       : L(L), LI(LI), DT(DT) {}
279 
280   /// Returns the number of partitions.
281   unsigned getSize() const { return PartitionContainer.size(); }
282 
283   /// Adds \p Inst into the current partition if that is marked to
284   /// contain cycles.  Otherwise start a new partition for it.
285   void addToCyclicPartition(Instruction *Inst) {
286     // If the current partition is non-cyclic.  Start a new one.
287     if (PartitionContainer.empty() || !PartitionContainer.back().hasDepCycle())
288       PartitionContainer.emplace_back(Inst, L, /*DepCycle=*/true);
289     else
290       PartitionContainer.back().add(Inst);
291   }
292 
293   /// Adds \p Inst into a partition that is not marked to contain
294   /// dependence cycles.
295   ///
296   //  Initially we isolate memory instructions into as many partitions as
297   //  possible, then later we may merge them back together.
298   void addToNewNonCyclicPartition(Instruction *Inst) {
299     PartitionContainer.emplace_back(Inst, L);
300   }
301 
302   /// Merges adjacent non-cyclic partitions.
303   ///
304   /// The idea is that we currently only want to isolate the non-vectorizable
305   /// partition.  We could later allow more distribution among these partition
306   /// too.
307   void mergeAdjacentNonCyclic() {
308     mergeAdjacentPartitionsIf(
309         [](const InstPartition *P) { return !P->hasDepCycle(); });
310   }
311 
312   /// If a partition contains only conditional stores, we won't vectorize
313   /// it.  Try to merge it with a previous cyclic partition.
314   void mergeNonIfConvertible() {
315     mergeAdjacentPartitionsIf([&](const InstPartition *Partition) {
316       if (Partition->hasDepCycle())
317         return true;
318 
319       // Now, check if all stores are conditional in this partition.
320       bool seenStore = false;
321 
322       for (auto *Inst : *Partition)
323         if (isa<StoreInst>(Inst)) {
324           seenStore = true;
325           if (!LoopAccessInfo::blockNeedsPredication(Inst->getParent(), L, DT))
326             return false;
327         }
328       return seenStore;
329     });
330   }
331 
332   /// Merges the partitions according to various heuristics.
333   void mergeBeforePopulating() {
334     mergeAdjacentNonCyclic();
335     if (!DistributeNonIfConvertible)
336       mergeNonIfConvertible();
337   }
338 
339   /// Merges partitions in order to ensure that no loads are duplicated.
340   ///
341   /// We can't duplicate loads because that could potentially reorder them.
342   /// LoopAccessAnalysis provides dependency information with the context that
343   /// the order of memory operation is preserved.
344   ///
345   /// Return if any partitions were merged.
346   bool mergeToAvoidDuplicatedLoads() {
347     using LoadToPartitionT = DenseMap<Instruction *, InstPartition *>;
348     using ToBeMergedT = EquivalenceClasses<InstPartition *>;
349 
350     LoadToPartitionT LoadToPartition;
351     ToBeMergedT ToBeMerged;
352 
353     // Step through the partitions and create equivalence between partitions
354     // that contain the same load.  Also put partitions in between them in the
355     // same equivalence class to avoid reordering of memory operations.
356     for (PartitionContainerT::iterator I = PartitionContainer.begin(),
357                                        E = PartitionContainer.end();
358          I != E; ++I) {
359       auto *PartI = &*I;
360 
361       // If a load occurs in two partitions PartI and PartJ, merge all
362       // partitions (PartI, PartJ] into PartI.
363       for (Instruction *Inst : *PartI)
364         if (isa<LoadInst>(Inst)) {
365           bool NewElt;
366           LoadToPartitionT::iterator LoadToPart;
367 
368           std::tie(LoadToPart, NewElt) =
369               LoadToPartition.insert(std::make_pair(Inst, PartI));
370           if (!NewElt) {
371             LLVM_DEBUG(dbgs()
372                        << "Merging partitions due to this load in multiple "
373                        << "partitions: " << PartI << ", " << LoadToPart->second
374                        << "\n"
375                        << *Inst << "\n");
376 
377             auto PartJ = I;
378             do {
379               --PartJ;
380               ToBeMerged.unionSets(PartI, &*PartJ);
381             } while (&*PartJ != LoadToPart->second);
382           }
383         }
384     }
385     if (ToBeMerged.empty())
386       return false;
387 
388     // Merge the member of an equivalence class into its class leader.  This
389     // makes the members empty.
390     for (ToBeMergedT::iterator I = ToBeMerged.begin(), E = ToBeMerged.end();
391          I != E; ++I) {
392       if (!I->isLeader())
393         continue;
394 
395       auto PartI = I->getData();
396       for (auto *PartJ : make_range(std::next(ToBeMerged.member_begin(I)),
397                                    ToBeMerged.member_end())) {
398         PartJ->moveTo(*PartI);
399       }
400     }
401 
402     // Remove the empty partitions.
403     PartitionContainer.remove_if(
404         [](const InstPartition &P) { return P.empty(); });
405 
406     return true;
407   }
408 
409   /// Sets up the mapping between instructions to partitions.  If the
410   /// instruction is duplicated across multiple partitions, set the entry to -1.
411   void setupPartitionIdOnInstructions() {
412     int PartitionID = 0;
413     for (const auto &Partition : PartitionContainer) {
414       for (Instruction *Inst : Partition) {
415         bool NewElt;
416         InstToPartitionIdT::iterator Iter;
417 
418         std::tie(Iter, NewElt) =
419             InstToPartitionId.insert(std::make_pair(Inst, PartitionID));
420         if (!NewElt)
421           Iter->second = -1;
422       }
423       ++PartitionID;
424     }
425   }
426 
427   /// Populates the partition with everything that the seeding
428   /// instructions require.
429   void populateUsedSet() {
430     for (auto &P : PartitionContainer)
431       P.populateUsedSet();
432   }
433 
434   /// This performs the main chunk of the work of cloning the loops for
435   /// the partitions.
436   void cloneLoops() {
437     BasicBlock *OrigPH = L->getLoopPreheader();
438     // At this point the predecessor of the preheader is either the memcheck
439     // block or the top part of the original preheader.
440     BasicBlock *Pred = OrigPH->getSinglePredecessor();
441     assert(Pred && "Preheader does not have a single predecessor");
442     BasicBlock *ExitBlock = L->getExitBlock();
443     assert(ExitBlock && "No single exit block");
444     Loop *NewLoop;
445 
446     assert(!PartitionContainer.empty() && "at least two partitions expected");
447     // We're cloning the preheader along with the loop so we already made sure
448     // it was empty.
449     assert(&*OrigPH->begin() == OrigPH->getTerminator() &&
450            "preheader not empty");
451 
452     // Preserve the original loop ID for use after the transformation.
453     MDNode *OrigLoopID = L->getLoopID();
454 
455     // Create a loop for each partition except the last.  Clone the original
456     // loop before PH along with adding a preheader for the cloned loop.  Then
457     // update PH to point to the newly added preheader.
458     BasicBlock *TopPH = OrigPH;
459     unsigned Index = getSize() - 1;
460     for (auto &Part : llvm::drop_begin(llvm::reverse(PartitionContainer))) {
461       NewLoop = Part.cloneLoopWithPreheader(TopPH, Pred, Index, LI, DT);
462 
463       Part.getVMap()[ExitBlock] = TopPH;
464       Part.remapInstructions();
465       setNewLoopID(OrigLoopID, &Part);
466       --Index;
467       TopPH = NewLoop->getLoopPreheader();
468     }
469     Pred->getTerminator()->replaceUsesOfWith(OrigPH, TopPH);
470 
471     // Also set a new loop ID for the last loop.
472     setNewLoopID(OrigLoopID, &PartitionContainer.back());
473 
474     // Now go in forward order and update the immediate dominator for the
475     // preheaders with the exiting block of the previous loop.  Dominance
476     // within the loop is updated in cloneLoopWithPreheader.
477     for (auto Curr = PartitionContainer.cbegin(),
478               Next = std::next(PartitionContainer.cbegin()),
479               E = PartitionContainer.cend();
480          Next != E; ++Curr, ++Next)
481       DT->changeImmediateDominator(
482           Next->getDistributedLoop()->getLoopPreheader(),
483           Curr->getDistributedLoop()->getExitingBlock());
484   }
485 
486   /// Removes the dead instructions from the cloned loops.
487   void removeUnusedInsts() {
488     for (auto &Partition : PartitionContainer)
489       Partition.removeUnusedInsts();
490   }
491 
492   /// For each memory pointer, it computes the partitionId the pointer is
493   /// used in.
494   ///
495   /// This returns an array of int where the I-th entry corresponds to I-th
496   /// entry in LAI.getRuntimePointerCheck().  If the pointer is used in multiple
497   /// partitions its entry is set to -1.
498   SmallVector<int, 8>
499   computePartitionSetForPointers(const LoopAccessInfo &LAI) {
500     const RuntimePointerChecking *RtPtrCheck = LAI.getRuntimePointerChecking();
501 
502     unsigned N = RtPtrCheck->Pointers.size();
503     SmallVector<int, 8> PtrToPartitions(N);
504     for (unsigned I = 0; I < N; ++I) {
505       Value *Ptr = RtPtrCheck->Pointers[I].PointerValue;
506       auto Instructions =
507           LAI.getInstructionsForAccess(Ptr, RtPtrCheck->Pointers[I].IsWritePtr);
508 
509       int &Partition = PtrToPartitions[I];
510       // First set it to uninitialized.
511       Partition = -2;
512       for (Instruction *Inst : Instructions) {
513         // Note that this could be -1 if Inst is duplicated across multiple
514         // partitions.
515         int ThisPartition = this->InstToPartitionId[Inst];
516         if (Partition == -2)
517           Partition = ThisPartition;
518         // -1 means belonging to multiple partitions.
519         else if (Partition == -1)
520           break;
521         else if (Partition != (int)ThisPartition)
522           Partition = -1;
523       }
524       assert(Partition != -2 && "Pointer not belonging to any partition");
525     }
526 
527     return PtrToPartitions;
528   }
529 
530   void print(raw_ostream &OS) const {
531     unsigned Index = 0;
532     for (const auto &P : PartitionContainer) {
533       OS << "Partition " << Index++ << " (" << &P << "):\n";
534       P.print();
535     }
536   }
537 
538   void dump() const { print(dbgs()); }
539 
540 #ifndef NDEBUG
541   friend raw_ostream &operator<<(raw_ostream &OS,
542                                  const InstPartitionContainer &Partitions) {
543     Partitions.print(OS);
544     return OS;
545   }
546 #endif
547 
548   void printBlocks() const {
549     unsigned Index = 0;
550     for (const auto &P : PartitionContainer) {
551       dbgs() << "\nPartition " << Index++ << " (" << &P << "):\n";
552       P.printBlocks();
553     }
554   }
555 
556 private:
557   using PartitionContainerT = std::list<InstPartition>;
558 
559   /// List of partitions.
560   PartitionContainerT PartitionContainer;
561 
562   /// Mapping from Instruction to partition Id.  If the instruction
563   /// belongs to multiple partitions the entry contains -1.
564   InstToPartitionIdT InstToPartitionId;
565 
566   Loop *L;
567   LoopInfo *LI;
568   DominatorTree *DT;
569 
570   /// The control structure to merge adjacent partitions if both satisfy
571   /// the \p Predicate.
572   template <class UnaryPredicate>
573   void mergeAdjacentPartitionsIf(UnaryPredicate Predicate) {
574     InstPartition *PrevMatch = nullptr;
575     for (auto I = PartitionContainer.begin(); I != PartitionContainer.end();) {
576       auto DoesMatch = Predicate(&*I);
577       if (PrevMatch == nullptr && DoesMatch) {
578         PrevMatch = &*I;
579         ++I;
580       } else if (PrevMatch != nullptr && DoesMatch) {
581         I->moveTo(*PrevMatch);
582         I = PartitionContainer.erase(I);
583       } else {
584         PrevMatch = nullptr;
585         ++I;
586       }
587     }
588   }
589 
590   /// Assign new LoopIDs for the partition's cloned loop.
591   void setNewLoopID(MDNode *OrigLoopID, InstPartition *Part) {
592     std::optional<MDNode *> PartitionID = makeFollowupLoopID(
593         OrigLoopID,
594         {LLVMLoopDistributeFollowupAll,
595          Part->hasDepCycle() ? LLVMLoopDistributeFollowupSequential
596                              : LLVMLoopDistributeFollowupCoincident});
597     if (PartitionID) {
598       Loop *NewLoop = Part->getDistributedLoop();
599       NewLoop->setLoopID(*PartitionID);
600     }
601   }
602 };
603 
604 /// For each memory instruction, this class maintains difference of the
605 /// number of unsafe dependences that start out from this instruction minus
606 /// those that end here.
607 ///
608 /// By traversing the memory instructions in program order and accumulating this
609 /// number, we know whether any unsafe dependence crosses over a program point.
610 class MemoryInstructionDependences {
611   using Dependence = MemoryDepChecker::Dependence;
612 
613 public:
614   struct Entry {
615     Instruction *Inst;
616     unsigned NumUnsafeDependencesStartOrEnd = 0;
617 
618     Entry(Instruction *Inst) : Inst(Inst) {}
619   };
620 
621   using AccessesType = SmallVector<Entry, 8>;
622 
623   AccessesType::const_iterator begin() const { return Accesses.begin(); }
624   AccessesType::const_iterator end() const { return Accesses.end(); }
625 
626   MemoryInstructionDependences(
627       const SmallVectorImpl<Instruction *> &Instructions,
628       const SmallVectorImpl<Dependence> &Dependences) {
629     Accesses.append(Instructions.begin(), Instructions.end());
630 
631     LLVM_DEBUG(dbgs() << "Backward dependences:\n");
632     for (const auto &Dep : Dependences)
633       if (Dep.isPossiblyBackward()) {
634         // Note that the designations source and destination follow the program
635         // order, i.e. source is always first.  (The direction is given by the
636         // DepType.)
637         ++Accesses[Dep.Source].NumUnsafeDependencesStartOrEnd;
638         --Accesses[Dep.Destination].NumUnsafeDependencesStartOrEnd;
639 
640         LLVM_DEBUG(Dep.print(dbgs(), 2, Instructions));
641       }
642   }
643 
644 private:
645   AccessesType Accesses;
646 };
647 
648 /// The actual class performing the per-loop work.
649 class LoopDistributeForLoop {
650 public:
651   LoopDistributeForLoop(Loop *L, Function *F, LoopInfo *LI, DominatorTree *DT,
652                         ScalarEvolution *SE, LoopAccessInfoManager &LAIs,
653                         OptimizationRemarkEmitter *ORE)
654       : L(L), F(F), LI(LI), DT(DT), SE(SE), LAIs(LAIs), ORE(ORE) {
655     setForced();
656   }
657 
658   /// Try to distribute an inner-most loop.
659   bool processLoop() {
660     assert(L->isInnermost() && "Only process inner loops.");
661 
662     LLVM_DEBUG(dbgs() << "\nLDist: In \""
663                       << L->getHeader()->getParent()->getName()
664                       << "\" checking " << *L << "\n");
665 
666     // Having a single exit block implies there's also one exiting block.
667     if (!L->getExitBlock())
668       return fail("MultipleExitBlocks", "multiple exit blocks");
669     if (!L->isLoopSimplifyForm())
670       return fail("NotLoopSimplifyForm",
671                   "loop is not in loop-simplify form");
672     if (!L->isRotatedForm())
673       return fail("NotBottomTested", "loop is not bottom tested");
674 
675     BasicBlock *PH = L->getLoopPreheader();
676 
677     LAI = &LAIs.getInfo(*L);
678 
679     // Currently, we only distribute to isolate the part of the loop with
680     // dependence cycles to enable partial vectorization.
681     if (LAI->canVectorizeMemory())
682       return fail("MemOpsCanBeVectorized",
683                   "memory operations are safe for vectorization");
684 
685     auto *Dependences = LAI->getDepChecker().getDependences();
686     if (!Dependences || Dependences->empty())
687       return fail("NoUnsafeDeps", "no unsafe dependences to isolate");
688 
689     InstPartitionContainer Partitions(L, LI, DT);
690 
691     // First, go through each memory operation and assign them to consecutive
692     // partitions (the order of partitions follows program order).  Put those
693     // with unsafe dependences into "cyclic" partition otherwise put each store
694     // in its own "non-cyclic" partition (we'll merge these later).
695     //
696     // Note that a memory operation (e.g. Load2 below) at a program point that
697     // has an unsafe dependence (Store3->Load1) spanning over it must be
698     // included in the same cyclic partition as the dependent operations.  This
699     // is to preserve the original program order after distribution.  E.g.:
700     //
701     //                NumUnsafeDependencesStartOrEnd  NumUnsafeDependencesActive
702     //  Load1   -.                     1                       0->1
703     //  Load2    | /Unsafe/            0                       1
704     //  Store3  -'                    -1                       1->0
705     //  Load4                          0                       0
706     //
707     // NumUnsafeDependencesActive > 0 indicates this situation and in this case
708     // we just keep assigning to the same cyclic partition until
709     // NumUnsafeDependencesActive reaches 0.
710     const MemoryDepChecker &DepChecker = LAI->getDepChecker();
711     MemoryInstructionDependences MID(DepChecker.getMemoryInstructions(),
712                                      *Dependences);
713 
714     int NumUnsafeDependencesActive = 0;
715     for (const auto &InstDep : MID) {
716       Instruction *I = InstDep.Inst;
717       // We update NumUnsafeDependencesActive post-instruction, catch the
718       // start of a dependence directly via NumUnsafeDependencesStartOrEnd.
719       if (NumUnsafeDependencesActive ||
720           InstDep.NumUnsafeDependencesStartOrEnd > 0)
721         Partitions.addToCyclicPartition(I);
722       else
723         Partitions.addToNewNonCyclicPartition(I);
724       NumUnsafeDependencesActive += InstDep.NumUnsafeDependencesStartOrEnd;
725       assert(NumUnsafeDependencesActive >= 0 &&
726              "Negative number of dependences active");
727     }
728 
729     // Add partitions for values used outside.  These partitions can be out of
730     // order from the original program order.  This is OK because if the
731     // partition uses a load we will merge this partition with the original
732     // partition of the load that we set up in the previous loop (see
733     // mergeToAvoidDuplicatedLoads).
734     auto DefsUsedOutside = findDefsUsedOutsideOfLoop(L);
735     for (auto *Inst : DefsUsedOutside)
736       Partitions.addToNewNonCyclicPartition(Inst);
737 
738     LLVM_DEBUG(dbgs() << "Seeded partitions:\n" << Partitions);
739     if (Partitions.getSize() < 2)
740       return fail("CantIsolateUnsafeDeps",
741                   "cannot isolate unsafe dependencies");
742 
743     // Run the merge heuristics: Merge non-cyclic adjacent partitions since we
744     // should be able to vectorize these together.
745     Partitions.mergeBeforePopulating();
746     LLVM_DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions);
747     if (Partitions.getSize() < 2)
748       return fail("CantIsolateUnsafeDeps",
749                   "cannot isolate unsafe dependencies");
750 
751     // Now, populate the partitions with non-memory operations.
752     Partitions.populateUsedSet();
753     LLVM_DEBUG(dbgs() << "\nPopulated partitions:\n" << Partitions);
754 
755     // In order to preserve original lexical order for loads, keep them in the
756     // partition that we set up in the MemoryInstructionDependences loop.
757     if (Partitions.mergeToAvoidDuplicatedLoads()) {
758       LLVM_DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n"
759                         << Partitions);
760       if (Partitions.getSize() < 2)
761         return fail("CantIsolateUnsafeDeps",
762                     "cannot isolate unsafe dependencies");
763     }
764 
765     // Don't distribute the loop if we need too many SCEV run-time checks, or
766     // any if it's illegal.
767     const SCEVPredicate &Pred = LAI->getPSE().getPredicate();
768     if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) {
769       return fail("RuntimeCheckWithConvergent",
770                   "may not insert runtime check with convergent operation");
771     }
772 
773     if (Pred.getComplexity() > (IsForced.value_or(false)
774                                     ? PragmaDistributeSCEVCheckThreshold
775                                     : DistributeSCEVCheckThreshold))
776       return fail("TooManySCEVRuntimeChecks",
777                   "too many SCEV run-time checks needed.\n");
778 
779     if (!IsForced.value_or(false) && hasDisableAllTransformsHint(L))
780       return fail("HeuristicDisabled", "distribution heuristic disabled");
781 
782     LLVM_DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n");
783     // We're done forming the partitions set up the reverse mapping from
784     // instructions to partitions.
785     Partitions.setupPartitionIdOnInstructions();
786 
787     // If we need run-time checks, version the loop now.
788     auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI);
789     const auto *RtPtrChecking = LAI->getRuntimePointerChecking();
790     const auto &AllChecks = RtPtrChecking->getChecks();
791     auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
792                                                   RtPtrChecking);
793 
794     if (LAI->hasConvergentOp() && !Checks.empty()) {
795       return fail("RuntimeCheckWithConvergent",
796                   "may not insert runtime check with convergent operation");
797     }
798 
799     // To keep things simple have an empty preheader before we version or clone
800     // the loop.  (Also split if this has no predecessor, i.e. entry, because we
801     // rely on PH having a predecessor.)
802     if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator())
803       SplitBlock(PH, PH->getTerminator(), DT, LI);
804 
805     if (!Pred.isAlwaysTrue() || !Checks.empty()) {
806       assert(!LAI->hasConvergentOp() && "inserting illegal loop versioning");
807 
808       MDNode *OrigLoopID = L->getLoopID();
809 
810       LLVM_DEBUG(dbgs() << "\nPointers:\n");
811       LLVM_DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks));
812       LoopVersioning LVer(*LAI, Checks, L, LI, DT, SE);
813       LVer.versionLoop(DefsUsedOutside);
814       LVer.annotateLoopWithNoAlias();
815 
816       // The unversioned loop will not be changed, so we inherit all attributes
817       // from the original loop, but remove the loop distribution metadata to
818       // avoid to distribute it again.
819       MDNode *UnversionedLoopID = *makeFollowupLoopID(
820           OrigLoopID,
821           {LLVMLoopDistributeFollowupAll, LLVMLoopDistributeFollowupFallback},
822           "llvm.loop.distribute.", true);
823       LVer.getNonVersionedLoop()->setLoopID(UnversionedLoopID);
824     }
825 
826     // Create identical copies of the original loop for each partition and hook
827     // them up sequentially.
828     Partitions.cloneLoops();
829 
830     // Now, we remove the instruction from each loop that don't belong to that
831     // partition.
832     Partitions.removeUnusedInsts();
833     LLVM_DEBUG(dbgs() << "\nAfter removing unused Instrs:\n");
834     LLVM_DEBUG(Partitions.printBlocks());
835 
836     if (LDistVerify) {
837       LI->verify(*DT);
838       assert(DT->verify(DominatorTree::VerificationLevel::Fast));
839     }
840 
841     ++NumLoopsDistributed;
842     // Report the success.
843     ORE->emit([&]() {
844       return OptimizationRemark(LDIST_NAME, "Distribute", L->getStartLoc(),
845                                 L->getHeader())
846              << "distributed loop";
847     });
848     return true;
849   }
850 
851   /// Provide diagnostics then \return with false.
852   bool fail(StringRef RemarkName, StringRef Message) {
853     LLVMContext &Ctx = F->getContext();
854     bool Forced = isForced().value_or(false);
855 
856     LLVM_DEBUG(dbgs() << "Skipping; " << Message << "\n");
857 
858     // With Rpass-missed report that distribution failed.
859     ORE->emit([&]() {
860       return OptimizationRemarkMissed(LDIST_NAME, "NotDistributed",
861                                       L->getStartLoc(), L->getHeader())
862              << "loop not distributed: use -Rpass-analysis=loop-distribute for "
863                 "more "
864                 "info";
865     });
866 
867     // With Rpass-analysis report why.  This is on by default if distribution
868     // was requested explicitly.
869     ORE->emit(OptimizationRemarkAnalysis(
870                   Forced ? OptimizationRemarkAnalysis::AlwaysPrint : LDIST_NAME,
871                   RemarkName, L->getStartLoc(), L->getHeader())
872               << "loop not distributed: " << Message);
873 
874     // Also issue a warning if distribution was requested explicitly but it
875     // failed.
876     if (Forced)
877       Ctx.diagnose(DiagnosticInfoOptimizationFailure(
878           *F, L->getStartLoc(), "loop not distributed: failed "
879                                 "explicitly specified loop distribution"));
880 
881     return false;
882   }
883 
884   /// Return if distribution forced to be enabled/disabled for the loop.
885   ///
886   /// If the optional has a value, it indicates whether distribution was forced
887   /// to be enabled (true) or disabled (false).  If the optional has no value
888   /// distribution was not forced either way.
889   const std::optional<bool> &isForced() const { return IsForced; }
890 
891 private:
892   /// Filter out checks between pointers from the same partition.
893   ///
894   /// \p PtrToPartition contains the partition number for pointers.  Partition
895   /// number -1 means that the pointer is used in multiple partitions.  In this
896   /// case we can't safely omit the check.
897   SmallVector<RuntimePointerCheck, 4> includeOnlyCrossPartitionChecks(
898       const SmallVectorImpl<RuntimePointerCheck> &AllChecks,
899       const SmallVectorImpl<int> &PtrToPartition,
900       const RuntimePointerChecking *RtPtrChecking) {
901     SmallVector<RuntimePointerCheck, 4> Checks;
902 
903     copy_if(AllChecks, std::back_inserter(Checks),
904             [&](const RuntimePointerCheck &Check) {
905               for (unsigned PtrIdx1 : Check.first->Members)
906                 for (unsigned PtrIdx2 : Check.second->Members)
907                   // Only include this check if there is a pair of pointers
908                   // that require checking and the pointers fall into
909                   // separate partitions.
910                   //
911                   // (Note that we already know at this point that the two
912                   // pointer groups need checking but it doesn't follow
913                   // that each pair of pointers within the two groups need
914                   // checking as well.
915                   //
916                   // In other words we don't want to include a check just
917                   // because there is a pair of pointers between the two
918                   // pointer groups that require checks and a different
919                   // pair whose pointers fall into different partitions.)
920                   if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) &&
921                       !RuntimePointerChecking::arePointersInSamePartition(
922                           PtrToPartition, PtrIdx1, PtrIdx2))
923                     return true;
924               return false;
925             });
926 
927     return Checks;
928   }
929 
930   /// Check whether the loop metadata is forcing distribution to be
931   /// enabled/disabled.
932   void setForced() {
933     std::optional<const MDOperand *> Value =
934         findStringMetadataForLoop(L, "llvm.loop.distribute.enable");
935     if (!Value)
936       return;
937 
938     const MDOperand *Op = *Value;
939     assert(Op && mdconst::hasa<ConstantInt>(*Op) && "invalid metadata");
940     IsForced = mdconst::extract<ConstantInt>(*Op)->getZExtValue();
941   }
942 
  /// The loop this object is trying to distribute.
  Loop *L;
  /// Function containing \c L; used for the context and location of the
  /// failure diagnostic emitted by fail().
  Function *F;

  // Analyses used.
  /// Loop info.  It is checked with LI->verify(*DT) when LDistVerify is set,
  /// so it is expected to be kept valid by the transformation.
  LoopInfo *LI;
  /// Access/dependence info for \c L.  Starts out null.
  /// NOTE(review): the assignment is not visible in this chunk; presumably it
  /// is set at the start of processLoop() before any use.
  const LoopAccessInfo *LAI = nullptr;
  /// Dominator tree; updated by SplitBlock and verified under LDistVerify.
  DominatorTree *DT;
  /// Scalar evolution; passed to LoopVersioning when run-time checks are
  /// needed.
  ScalarEvolution *SE;
  /// Held by reference, so it must outlive this object.  Presumably the source
  /// of \c LAI; TODO confirm against processLoop().
  LoopAccessInfoManager &LAIs;
  /// Emitter for the success and failure optimization remarks.
  OptimizationRemarkEmitter *ORE;

  /// Indicates whether distribution is forced to be enabled/disabled for
  /// the loop.
  ///
  /// If the optional has a value, it indicates whether distribution was forced
  /// to be enabled (true) or disabled (false).  If the optional has no value
  /// distribution was not forced either way.
  std::optional<bool> IsForced;
961 };
962 
963 } // end anonymous namespace
964 
965 /// Shared implementation between new and old PMs.
966 static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT,
967                     ScalarEvolution *SE, OptimizationRemarkEmitter *ORE,
968                     LoopAccessInfoManager &LAIs) {
969   // Build up a worklist of inner-loops to vectorize. This is necessary as the
970   // act of distributing a loop creates new loops and can invalidate iterators
971   // across the loops.
972   SmallVector<Loop *, 8> Worklist;
973 
974   for (Loop *TopLevelLoop : *LI)
975     for (Loop *L : depth_first(TopLevelLoop))
976       // We only handle inner-most loops.
977       if (L->isInnermost())
978         Worklist.push_back(L);
979 
980   // Now walk the identified inner loops.
981   bool Changed = false;
982   for (Loop *L : Worklist) {
983     LoopDistributeForLoop LDL(L, &F, LI, DT, SE, LAIs, ORE);
984 
985     // If distribution was forced for the specific loop to be
986     // enabled/disabled, follow that.  Otherwise use the global flag.
987     if (LDL.isForced().value_or(EnableLoopDistribute))
988       Changed |= LDL.processLoop();
989   }
990 
991   // Process each loop nest in the function.
992   return Changed;
993 }
994 
995 PreservedAnalyses LoopDistributePass::run(Function &F,
996                                           FunctionAnalysisManager &AM) {
997   auto &LI = AM.getResult<LoopAnalysis>(F);
998   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
999   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
1000   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
1001 
1002   LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F);
1003   bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, LAIs);
1004   if (!Changed)
1005     return PreservedAnalyses::all();
1006   PreservedAnalyses PA;
1007   PA.preserve<LoopAnalysis>();
1008   PA.preserve<DominatorTreeAnalysis>();
1009   return PA;
1010 }
1011