xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision a85404906bc8f402318524b4ccd196712fc09fbd)
1  //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file implements loop unroll and jam as a routine, much like
10  // LoopUnroll.cpp implements loop unroll.
11  //
12  //===----------------------------------------------------------------------===//
13  
14  #include "llvm/ADT/ArrayRef.h"
15  #include "llvm/ADT/DenseMap.h"
16  #include "llvm/ADT/Optional.h"
17  #include "llvm/ADT/STLExtras.h"
18  #include "llvm/ADT/Sequence.h"
19  #include "llvm/ADT/SmallPtrSet.h"
20  #include "llvm/ADT/SmallVector.h"
21  #include "llvm/ADT/Statistic.h"
22  #include "llvm/ADT/StringRef.h"
23  #include "llvm/ADT/Twine.h"
24  #include "llvm/ADT/iterator_range.h"
25  #include "llvm/Analysis/AssumptionCache.h"
26  #include "llvm/Analysis/DependenceAnalysis.h"
27  #include "llvm/Analysis/DomTreeUpdater.h"
28  #include "llvm/Analysis/LoopInfo.h"
29  #include "llvm/Analysis/LoopIterator.h"
30  #include "llvm/Analysis/MustExecute.h"
31  #include "llvm/Analysis/OptimizationRemarkEmitter.h"
32  #include "llvm/Analysis/ScalarEvolution.h"
33  #include "llvm/IR/BasicBlock.h"
34  #include "llvm/IR/DebugInfoMetadata.h"
35  #include "llvm/IR/DebugLoc.h"
36  #include "llvm/IR/DiagnosticInfo.h"
37  #include "llvm/IR/Dominators.h"
38  #include "llvm/IR/Function.h"
39  #include "llvm/IR/Instruction.h"
40  #include "llvm/IR/Instructions.h"
41  #include "llvm/IR/IntrinsicInst.h"
42  #include "llvm/IR/Use.h"
43  #include "llvm/IR/User.h"
44  #include "llvm/IR/Value.h"
45  #include "llvm/IR/ValueHandle.h"
46  #include "llvm/IR/ValueMap.h"
47  #include "llvm/Support/Casting.h"
48  #include "llvm/Support/Debug.h"
49  #include "llvm/Support/ErrorHandling.h"
50  #include "llvm/Support/GenericDomTree.h"
51  #include "llvm/Support/raw_ostream.h"
52  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53  #include "llvm/Transforms/Utils/Cloning.h"
54  #include "llvm/Transforms/Utils/LoopUtils.h"
55  #include "llvm/Transforms/Utils/UnrollLoop.h"
56  #include "llvm/Transforms/Utils/ValueMapper.h"
57  #include <assert.h>
58  #include <memory>
59  #include <type_traits>
60  #include <vector>
61  
62  using namespace llvm;
63  
64  #define DEBUG_TYPE "loop-unroll-and-jam"
65  
66  STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
67  STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
68  
69  typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
70  
71  // Partition blocks in an outer/inner loop pair into blocks before and after
72  // the loop
73  static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
74                                  BasicBlockSet &AftBlocks, DominatorTree &DT) {
75    Loop *SubLoop = L.getSubLoops()[0];
76    BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
77  
78    for (BasicBlock *BB : L.blocks()) {
79      if (!SubLoop->contains(BB)) {
80        if (DT.dominates(SubLoopLatch, BB))
81          AftBlocks.insert(BB);
82        else
83          ForeBlocks.insert(BB);
84      }
85    }
86  
87    // Check that all blocks in ForeBlocks together dominate the subloop
88    // TODO: This might ideally be done better with a dominator/postdominators.
89    BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
90    for (BasicBlock *BB : ForeBlocks) {
91      if (BB == SubLoopPreHeader)
92        continue;
93      Instruction *TI = BB->getTerminator();
94      for (BasicBlock *Succ : successors(TI))
95        if (!ForeBlocks.count(Succ))
96          return false;
97    }
98  
99    return true;
100  }
101  
102  /// Partition blocks in a loop nest into blocks before and after each inner
103  /// loop.
104  static bool partitionOuterLoopBlocks(
105      Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
106      DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
107      DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
108    JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
109  
110    for (Loop *L : Root.getLoopsInPreorder()) {
111      if (L == &JamLoop)
112        break;
113  
114      if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
115        return false;
116    }
117  
118    return true;
119  }
120  
121  // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
122  // than 2 levels loop.
123  static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
124                                       BasicBlockSet &ForeBlocks,
125                                       BasicBlockSet &SubLoopBlocks,
126                                       BasicBlockSet &AftBlocks,
127                                       DominatorTree *DT) {
128    SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
129    return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
130  }
131  
132  // Looks at the phi nodes in Header for values coming from Latch. For these
133  // instructions and all their operands calls Visit on them, keeping going for
134  // all the operands in AftBlocks. Returns false if Visit returns false,
135  // otherwise returns true. This is used to process the instructions in the
136  // Aft blocks that need to be moved before the subloop. It is used in two
137  // places. One to check that the required set of instructions can be moved
138  // before the loop. Then to collect the instructions to actually move in
139  // moveHeaderPhiOperandsToForeBlocks.
140  template <typename T>
141  static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
142                                       BasicBlockSet &AftBlocks, T Visit) {
143    SmallVector<Instruction *, 8> Worklist;
144    for (auto &Phi : Header->phis()) {
145      Value *V = Phi.getIncomingValueForBlock(Latch);
146      if (Instruction *I = dyn_cast<Instruction>(V))
147        Worklist.push_back(I);
148    }
149  
150    while (!Worklist.empty()) {
151      Instruction *I = Worklist.pop_back_val();
152      if (!Visit(I))
153        return false;
154  
155      if (AftBlocks.count(I->getParent()))
156        for (auto &U : I->operands())
157          if (Instruction *II = dyn_cast<Instruction>(U))
158            Worklist.push_back(II);
159    }
160  
161    return true;
162  }
163  
164  // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
165  static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
166                                                BasicBlock *Latch,
167                                                Instruction *InsertLoc,
168                                                BasicBlockSet &AftBlocks) {
169    // We need to ensure we move the instructions in the correct order,
170    // starting with the earliest required instruction and moving forward.
171    std::vector<Instruction *> Visited;
172    processHeaderPhiOperands(Header, Latch, AftBlocks,
173                             [&Visited, &AftBlocks](Instruction *I) {
174                               if (AftBlocks.count(I->getParent()))
175                                 Visited.push_back(I);
176                               return true;
177                             });
178  
179    // Move all instructions in program order to before the InsertLoc
180    BasicBlock *InsertLocBB = InsertLoc->getParent();
181    for (Instruction *I : reverse(Visited)) {
182      if (I->getParent() != InsertLocBB)
183        I->moveBefore(InsertLoc);
184    }
185  }
186  
187  /*
188    This method performs Unroll and Jam. For a simple loop like:
189    for (i = ..)
190      Fore(i)
191      for (j = ..)
192        SubLoop(i, j)
193      Aft(i)
194  
195    Instead of doing normal inner or outer unrolling, we do:
196    for (i = .., i+=2)
197      Fore(i)
198      Fore(i+1)
199      for (j = ..)
200        SubLoop(i, j)
201        SubLoop(i+1, j)
202      Aft(i)
203      Aft(i+1)
204  
205    So the outer loop is essetially unrolled and then the inner loops are fused
206    ("jammed") together into a single loop. This can increase speed when there
207    are loads in SubLoop that are invariant to i, as they become shared between
208    the now jammed inner loops.
209  
210    We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
211    Fore blocks are those before the inner loop, Aft are those after. Normal
212    Unroll code is used to copy each of these sets of blocks and the results are
213    combined together into the final form above.
214  
215    isSafeToUnrollAndJam should be used prior to calling this to make sure the
216    unrolling will be valid. Checking profitablility is also advisable.
217  
218    If EpilogueLoop is non-null, it receives the epilogue loop (if it was
219    necessary to create one and not fully unrolled).
220  */
LoopUnrollResult
llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
                       unsigned TripMultiple, bool UnrollRemainder,
                       LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
                       AssumptionCache *AC, const TargetTransformInfo *TTI,
                       OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  // Collect the key CFG anchors: the outer preheader/latch, the latch's
  // conditional exit branch, and the branch polarity of both loops.
  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // Scale debug-location discriminators by the unroll factor so sample
  // profiles attribute counts to the duplicated code correctly.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    SmallVector<BasicBlock *, 8> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;
    SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
    NewLoops[L] = L;
    NewLoops[SubLoop] = SubLoop;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Tell LI about New.
      addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);

      // Record the clone in the first/last vector of whichever of the three
      // sets (Fore/Sub/Aft) the original block belongs to.
      if (ForeBlocks.count(*BB)) {
        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    remapInstructionsInBlocks(NewBlocks, LastValueMap);
    // Cloned llvm.assume calls must be re-registered with the assumption
    // cache so later analyses can see them.
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
  ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);

  if (CompletelyUnroll) {
    // The loop control is gone, so the header phis collapse to their
    // preheader-incoming values.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
    ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
                                            ForeBlocksLast.back());
  SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                            SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
                                               ForeBlocksLast.back());
    SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                               SubLoopBlocksLast.back());
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, AftTerm);
    AftTerm->eraseFromParent();
  } else {
    AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
    assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
           "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
  }
  AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                        SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                           SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());

  MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);

  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
                          TTI);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  Loop *OutestLoop = SubLoop->getParentLoop()
                         ? SubLoop->getParentLoop()->getParentLoop()
                               ? SubLoop->getParentLoop()->getParentLoop()
                               : SubLoop->getParentLoop()
                         : SubLoop;
  assert(DT->verify());
  LI->verify(*DT);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  SE->verify();
#endif

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
626  
627  static bool getLoadsAndStores(BasicBlockSet &Blocks,
628                                SmallVector<Instruction *, 4> &MemInstr) {
629    // Scan the BBs and collect legal loads and stores.
630    // Returns false if non-simple loads/stores are found.
631    for (BasicBlock *BB : Blocks) {
632      for (Instruction &I : *BB) {
633        if (auto *Ld = dyn_cast<LoadInst>(&I)) {
634          if (!Ld->isSimple())
635            return false;
636          MemInstr.push_back(&I);
637        } else if (auto *St = dyn_cast<StoreInst>(&I)) {
638          if (!St->isSimple())
639            return false;
640          MemInstr.push_back(&I);
641        } else if (I.mayReadOrWriteMemory()) {
642          return false;
643        }
644      }
645    }
646    return true;
647  }
648  
649  static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
650                                         unsigned UnrollLevel, unsigned JamLevel,
651                                         bool Sequentialized, Dependence *D) {
652    // UnrollLevel might carry the dependency Src --> Dst
653    // Does a different loop after unrolling?
654    for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
655         ++CurLoopDepth) {
656      auto JammedDir = D->getDirection(CurLoopDepth);
657      if (JammedDir == Dependence::DVEntry::LT)
658        return true;
659  
660      if (JammedDir & Dependence::DVEntry::GT)
661        return false;
662    }
663  
664    return true;
665  }
666  
667  static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
668                                          unsigned UnrollLevel, unsigned JamLevel,
669                                          bool Sequentialized, Dependence *D) {
670    // UnrollLevel might carry the dependency Dst --> Src
671    for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
672         ++CurLoopDepth) {
673      auto JammedDir = D->getDirection(CurLoopDepth);
674      if (JammedDir == Dependence::DVEntry::GT)
675        return true;
676  
677      if (JammedDir & Dependence::DVEntry::LT)
678        return false;
679    }
680  
681    // Backward dependencies are only preserved if not interleaved.
682    return Sequentialized;
683  }
684  
685  // Check whether it is semantically safe Src and Dst considering any potential
686  // dependency between them.
687  //
688  // @param UnrollLevel The level of the loop being unrolled
689  // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
690  // different levels, the outermost common loop counts as jammed level
691  //
692  // @return true if is safe and false if there is a dependency violation.
693  static bool checkDependency(Instruction *Src, Instruction *Dst,
694                              unsigned UnrollLevel, unsigned JamLevel,
695                              bool Sequentialized, DependenceInfo &DI) {
696    assert(UnrollLevel <= JamLevel &&
697           "Expecting JamLevel to be at least UnrollLevel");
698  
699    if (Src == Dst)
700      return true;
701    // Ignore Input dependencies.
702    if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
703      return true;
704  
705    // Check whether unroll-and-jam may violate a dependency.
706    // By construction, every dependency will be lexicographically non-negative
707    // (if it was, it would violate the current execution order), such as
708    //   (0,0,>,*,*)
709    // Unroll-and-jam changes the GT execution of two executions to the same
710    // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
711    // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
712    //   (0,0,>=,*,*)
713    // Now, the dependency is not necessarily non-negative anymore, i.e.
714    // unroll-and-jam may violate correctness.
715    std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
716    if (!D)
717      return true;
718    assert(D->isOrdered() && "Expected an output, flow or anti dep.");
719  
720    if (D->isConfused()) {
721      LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
722                        << "  " << *Src << "\n"
723                        << "  " << *Dst << "\n");
724      return false;
725    }
726  
727    // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
728    // non-equal direction, then the locations accessed in the inner levels cannot
729    // overlap in memory. We assumes the indexes never overlap into neighboring
730    // dimensions.
731    for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
732      if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
733        return true;
734  
735    auto UnrollDirection = D->getDirection(UnrollLevel);
736  
737    // If the distance carried by the unrolled loop is 0, then after unrolling
738    // that distance will become non-zero resulting in non-overlapping accesses in
739    // the inner loops.
740    if (UnrollDirection == Dependence::DVEntry::EQ)
741      return true;
742  
743    if (UnrollDirection & Dependence::DVEntry::LT &&
744        !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
745                                    Sequentialized, D.get()))
746      return false;
747  
748    if (UnrollDirection & Dependence::DVEntry::GT &&
749        !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
750                                     Sequentialized, D.get()))
751      return false;
752  
753    return true;
754  }
755  
756  static bool
757  checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
758                    const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
759                    const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
760                    DependenceInfo &DI, LoopInfo &LI) {
761    SmallVector<BasicBlockSet, 8> AllBlocks;
762    for (Loop *L : Root.getLoopsInPreorder())
763      if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
764        AllBlocks.push_back(ForeBlocksMap.lookup(L));
765    AllBlocks.push_back(SubLoopBlocks);
766    for (Loop *L : Root.getLoopsInPreorder())
767      if (AftBlocksMap.find(L) != AftBlocksMap.end())
768        AllBlocks.push_back(AftBlocksMap.lookup(L));
769  
770    unsigned LoopDepth = Root.getLoopDepth();
771    SmallVector<Instruction *, 4> EarlierLoadsAndStores;
772    SmallVector<Instruction *, 4> CurrentLoadsAndStores;
773    for (BasicBlockSet &Blocks : AllBlocks) {
774      CurrentLoadsAndStores.clear();
775      if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
776        return false;
777  
778      Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
779      unsigned CurLoopDepth = CurLoop->getLoopDepth();
780  
781      for (auto *Earlier : EarlierLoadsAndStores) {
782        Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
783        unsigned EarlierDepth = EarlierLoop->getLoopDepth();
784        unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
785        for (auto *Later : CurrentLoadsAndStores) {
786          if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
787                               DI))
788            return false;
789        }
790      }
791  
792      size_t NumInsts = CurrentLoadsAndStores.size();
793      for (size_t I = 0; I < NumInsts; ++I) {
794        for (size_t J = I; J < NumInsts; ++J) {
795          if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
796                               LoopDepth, CurLoopDepth, true, DI))
797            return false;
798        }
799      }
800  
801      EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
802                                   CurrentLoadsAndStores.end());
803    }
804    return true;
805  }
806  
807  static bool isEligibleLoopForm(const Loop &Root) {
808    // Root must have a child.
809    if (Root.getSubLoops().size() != 1)
810      return false;
811  
812    const Loop *L = &Root;
813    do {
814      // All loops in Root need to be in simplify and rotated form.
815      if (!L->isLoopSimplifyForm())
816        return false;
817  
818      if (!L->isRotatedForm())
819        return false;
820  
821      if (L->getHeader()->hasAddressTaken()) {
822        LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
823        return false;
824      }
825  
826      unsigned SubLoopsSize = L->getSubLoops().size();
827      if (SubLoopsSize == 0)
828        return true;
829  
830      // Only one child is allowed.
831      if (SubLoopsSize != 1)
832        return false;
833  
834      L = L->getSubLoops()[0];
835    } while (L);
836  
837    return true;
838  }
839  
840  static Loop *getInnerMostLoop(Loop *L) {
841    while (!L->getSubLoops().empty())
842      L = L->getSubLoops()[0];
843    return L;
844  }
845  
846  bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
847                                  DependenceInfo &DI, LoopInfo &LI) {
848    if (!isEligibleLoopForm(*L)) {
849      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
850      return false;
851    }
852  
853    /* We currently handle outer loops like this:
854          |
855      ForeFirst    <------\   }
856       Blocks             |   } ForeBlocks of L
857      ForeLast            |   }
858          |               |
859         ...              |
860          |               |
861      ForeFirst    <----\ |   }
862       Blocks           | |   } ForeBlocks of a inner loop of L
863      ForeLast          | |   }
864          |             | |
865      JamLoopFirst  <\  | |   }
866       Blocks        |  | |   } JamLoopBlocks of the innermost loop
867      JamLoopLast   -/  | |   }
868          |             | |
869      AftFirst          | |   }
870       Blocks           | |   } AftBlocks of a inner loop of L
871      AftLast     ------/ |   }
872          |               |
873         ...              |
874          |               |
875      AftFirst            |   }
876       Blocks             |   } AftBlocks of L
877      AftLast     --------/   }
878          |
879  
880      There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
881      and AftBlocks, providing that there is one edge from Fores to SubLoops,
882      one edge from SubLoops to Afts and a single outer loop exit (from Afts).
883      In practice we currently limit Aft blocks to a single block, and limit
884      things further in the profitablility checks of the unroll and jam pass.
885  
886      Because of the way we rearrange basic blocks, we also require that
887      the Fore blocks of L on all unrolled iterations are safe to move before the
888      blocks of the direct child of L of all iterations. So we require that the
889      phi node looping operands of ForeHeader can be moved to at least the end of
890      ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
891      match up Phi's correctly.
892  
893      i.e. The old order of blocks used to be
894             (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
895           It needs to be safe to transform this to
896             (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
897  
898      There are then a number of checks along the lines of no calls, no
899      exceptions, inner loop IV is consistent, etc. Note that for loops requiring
900      runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
901      UnrollAndJamLoop if the trip count cannot be easily calculated.
902    */
903  
904    // Split blocks into Fore/SubLoop/Aft based on dominators
905    Loop *JamLoop = getInnerMostLoop(L);
906    BasicBlockSet SubLoopBlocks;
907    DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
908    DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
909    if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
910                                  AftBlocksMap, DT)) {
911      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
912      return false;
913    }
914  
915    // Aft blocks may need to move instructions to fore blocks, which becomes more
916    // difficult if there are multiple (potentially conditionally executed)
917    // blocks. For now we just exclude loops with multiple aft blocks.
918    if (AftBlocksMap[L].size() != 1) {
919      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
920                           "multiple blocks after the loop\n");
921      return false;
922    }
923  
924    // Check inner loop backedge count is consistent on all iterations of the
925    // outer loop
926    if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
927          return !hasIterationCountInvariantInParent(SubLoop, SE);
928        })) {
929      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
930                           "not consistent on each iteration\n");
931      return false;
932    }
933  
934    // Check the loop safety info for exceptions.
935    SimpleLoopSafetyInfo LSI;
936    LSI.computeLoopSafetyInfo(L);
937    if (LSI.anyBlockMayThrow()) {
938      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
939      return false;
940    }
941  
942    // We've ruled out the easy stuff and now need to check that there are no
943    // interdependencies which may prevent us from moving the:
944    //  ForeBlocks before Subloop and AftBlocks.
945    //  Subloop before AftBlocks.
946    //  ForeBlock phi operands before the subloop
947  
948    // Make sure we can move all instructions we need to before the subloop
949    BasicBlock *Header = L->getHeader();
950    BasicBlock *Latch = L->getLoopLatch();
951    BasicBlockSet AftBlocks = AftBlocksMap[L];
952    Loop *SubLoop = L->getSubLoops()[0];
953    if (!processHeaderPhiOperands(
954            Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
955              if (SubLoop->contains(I->getParent()))
956                return false;
957              if (AftBlocks.count(I->getParent())) {
958                // If we hit a phi node in afts we know we are done (probably
959                // LCSSA)
960                if (isa<PHINode>(I))
961                  return false;
962                // Can't move instructions with side effects or memory
963                // reads/writes
964                if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
965                  return false;
966              }
967              // Keep going
968              return true;
969            })) {
970      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
971                           "instructions after subloop to before it\n");
972      return false;
973    }
974  
975    // Check for memory dependencies which prohibit the unrolling we are doing.
976    // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
977    // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
978    if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
979                           LI)) {
980      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
981      return false;
982    }
983  
984    return true;
985  }
986