xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/Analysis/AssumptionCache.h"
23 #include "llvm/Analysis/DependenceAnalysis.h"
24 #include "llvm/Analysis/DomTreeUpdater.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/LoopIterator.h"
27 #include "llvm/Analysis/MustExecute.h"
28 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
29 #include "llvm/Analysis/ScalarEvolution.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/DebugInfoMetadata.h"
32 #include "llvm/IR/DebugLoc.h"
33 #include "llvm/IR/DiagnosticInfo.h"
34 #include "llvm/IR/Dominators.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/IntrinsicInst.h"
39 #include "llvm/IR/User.h"
40 #include "llvm/IR/Value.h"
41 #include "llvm/IR/ValueHandle.h"
42 #include "llvm/IR/ValueMap.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Support/GenericDomTree.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
49 #include "llvm/Transforms/Utils/Cloning.h"
50 #include "llvm/Transforms/Utils/LoopUtils.h"
51 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
52 #include "llvm/Transforms/Utils/UnrollLoop.h"
53 #include "llvm/Transforms/Utils/ValueMapper.h"
54 #include <assert.h>
55 #include <memory>
56 #include <type_traits>
57 #include <vector>
58 
59 using namespace llvm;
60 
61 #define DEBUG_TYPE "loop-unroll-and-jam"
62 
63 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
64 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
65 
66 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
67 
68 // Partition blocks in an outer/inner loop pair into blocks before and after
69 // the loop
70 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
71                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
72   Loop *SubLoop = L.getSubLoops()[0];
73   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
74 
75   for (BasicBlock *BB : L.blocks()) {
76     if (!SubLoop->contains(BB)) {
77       if (DT.dominates(SubLoopLatch, BB))
78         AftBlocks.insert(BB);
79       else
80         ForeBlocks.insert(BB);
81     }
82   }
83 
84   // Check that all blocks in ForeBlocks together dominate the subloop
85   // TODO: This might ideally be done better with a dominator/postdominators.
86   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
87   for (BasicBlock *BB : ForeBlocks) {
88     if (BB == SubLoopPreHeader)
89       continue;
90     Instruction *TI = BB->getTerminator();
91     for (BasicBlock *Succ : successors(TI))
92       if (!ForeBlocks.count(Succ))
93         return false;
94   }
95 
96   return true;
97 }
98 
99 /// Partition blocks in a loop nest into blocks before and after each inner
100 /// loop.
101 static bool partitionOuterLoopBlocks(
102     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
103     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
104     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
105   JamLoopBlocks.insert_range(JamLoop.blocks());
106 
107   for (Loop *L : Root.getLoopsInPreorder()) {
108     if (L == &JamLoop)
109       break;
110 
111     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
112       return false;
113   }
114 
115   return true;
116 }
117 
118 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
119 // than 2 levels loop.
120 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
121                                      BasicBlockSet &ForeBlocks,
122                                      BasicBlockSet &SubLoopBlocks,
123                                      BasicBlockSet &AftBlocks,
124                                      DominatorTree *DT) {
125   SubLoopBlocks.insert_range(SubLoop->blocks());
126   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
127 }
128 
129 // Looks at the phi nodes in Header for values coming from Latch. For these
130 // instructions and all their operands calls Visit on them, keeping going for
131 // all the operands in AftBlocks. Returns false if Visit returns false,
132 // otherwise returns true. This is used to process the instructions in the
133 // Aft blocks that need to be moved before the subloop. It is used in two
134 // places. One to check that the required set of instructions can be moved
135 // before the loop. Then to collect the instructions to actually move in
136 // moveHeaderPhiOperandsToForeBlocks.
137 template <typename T>
138 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
139                                      BasicBlockSet &AftBlocks, T Visit) {
140   SmallPtrSet<Instruction *, 8> VisitedInstr;
141 
142   std::function<bool(Instruction * I)> ProcessInstr = [&](Instruction *I) {
143     if (!VisitedInstr.insert(I).second)
144       return true;
145 
146     if (AftBlocks.count(I->getParent()))
147       for (auto &U : I->operands())
148         if (Instruction *II = dyn_cast<Instruction>(U))
149           if (!ProcessInstr(II))
150             return false;
151 
152     return Visit(I);
153   };
154 
155   for (auto &Phi : Header->phis()) {
156     Value *V = Phi.getIncomingValueForBlock(Latch);
157     if (Instruction *I = dyn_cast<Instruction>(V))
158       if (!ProcessInstr(I))
159         return false;
160   }
161 
162   return true;
163 }
164 
165 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
166 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
167                                               BasicBlock *Latch,
168                                               BasicBlock::iterator InsertLoc,
169                                               BasicBlockSet &AftBlocks) {
170   // We need to ensure we move the instructions in the correct order,
171   // starting with the earliest required instruction and moving forward.
172   processHeaderPhiOperands(Header, Latch, AftBlocks,
173                            [&AftBlocks, &InsertLoc](Instruction *I) {
174                              if (AftBlocks.count(I->getParent()))
175                                I->moveBefore(InsertLoc);
176                              return true;
177                            });
178 }
179 
180 /*
181   This method performs Unroll and Jam. For a simple loop like:
182   for (i = ..)
183     Fore(i)
184     for (j = ..)
185       SubLoop(i, j)
186     Aft(i)
187 
188   Instead of doing normal inner or outer unrolling, we do:
189   for (i = .., i+=2)
190     Fore(i)
191     Fore(i+1)
192     for (j = ..)
193       SubLoop(i, j)
194       SubLoop(i+1, j)
195     Aft(i)
196     Aft(i+1)
197 
198   So the outer loop is essetially unrolled and then the inner loops are fused
199   ("jammed") together into a single loop. This can increase speed when there
200   are loads in SubLoop that are invariant to i, as they become shared between
201   the now jammed inner loops.
202 
203   We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
204   Fore blocks are those before the inner loop, Aft are those after. Normal
205   Unroll code is used to copy each of these sets of blocks and the results are
206   combined together into the final form above.
207 
208   isSafeToUnrollAndJam should be used prior to calling this to make sure the
209   unrolling will be valid. Checking profitablility is also advisable.
210 
211   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
212   necessary to create one and not fully unrolled).
213 */
214 LoopUnrollResult
215 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
216                        unsigned TripMultiple, bool UnrollRemainder,
217                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
218                        AssumptionCache *AC, const TargetTransformInfo *TTI,
219                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
220 
221   // When we enter here we should have already checked that it is safe
222   BasicBlock *Header = L->getHeader();
223   assert(Header && "No header.");
224   assert(L->getSubLoops().size() == 1);
225   Loop *SubLoop = *L->begin();
226 
227   // Don't enter the unroll code if there is nothing to do.
228   if (TripCount == 0 && Count < 2) {
229     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
230     return LoopUnrollResult::Unmodified;
231   }
232 
233   assert(Count > 0);
234   assert(TripMultiple > 0);
235   assert(TripCount == 0 || TripCount % TripMultiple == 0);
236 
237   // Are we eliminating the loop control altogether?
238   bool CompletelyUnroll = (Count == TripCount);
239 
240   // We use the runtime remainder in cases where we don't know trip multiple
241   if (TripMultiple % Count != 0) {
242     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
243                                     /*UseEpilogRemainder*/ true,
244                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
245                                     LI, SE, DT, AC, TTI, true,
246                                     SCEVCheapExpansionBudget, EpilogueLoop)) {
247       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
248                            "generated when assuming runtime trip count\n");
249       return LoopUnrollResult::Unmodified;
250     }
251   }
252 
253   // Notify ScalarEvolution that the loop will be substantially changed,
254   // if not outright eliminated.
255   if (SE) {
256     SE->forgetLoop(L);
257     SE->forgetBlockAndLoopDispositions();
258   }
259 
260   using namespace ore;
261   // Report the unrolling decision.
262   if (CompletelyUnroll) {
263     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
264                       << Header->getName() << " with trip count " << TripCount
265                       << "!\n");
266     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
267                                  L->getHeader())
268               << "completely unroll and jammed loop with "
269               << NV("UnrollCount", TripCount) << " iterations");
270   } else {
271     auto DiagBuilder = [&]() {
272       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
273                               L->getHeader());
274       return Diag << "unroll and jammed loop by a factor of "
275                   << NV("UnrollCount", Count);
276     };
277 
278     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
279                       << " by " << Count);
280     if (TripMultiple != 1) {
281       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
282       ORE->emit([&]() {
283         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
284                              << " trips per branch";
285       });
286     } else {
287       LLVM_DEBUG(dbgs() << " with run-time trip count");
288       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
289     }
290     LLVM_DEBUG(dbgs() << "!\n");
291   }
292 
293   BasicBlock *Preheader = L->getLoopPreheader();
294   BasicBlock *LatchBlock = L->getLoopLatch();
295   assert(Preheader && "No preheader");
296   assert(LatchBlock && "No latch block");
297   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
298   assert(BI && !BI->isUnconditional());
299   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
300   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
301   bool SubLoopContinueOnTrue = SubLoop->contains(
302       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
303 
304   // Partition blocks in an outer/inner loop pair into blocks before and after
305   // the loop
306   BasicBlockSet SubLoopBlocks;
307   BasicBlockSet ForeBlocks;
308   BasicBlockSet AftBlocks;
309   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
310                            DT);
311 
312   // We keep track of the entering/first and exiting/last block of each of
313   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
314   // blocks easier.
315   std::vector<BasicBlock *> ForeBlocksFirst;
316   std::vector<BasicBlock *> ForeBlocksLast;
317   std::vector<BasicBlock *> SubLoopBlocksFirst;
318   std::vector<BasicBlock *> SubLoopBlocksLast;
319   std::vector<BasicBlock *> AftBlocksFirst;
320   std::vector<BasicBlock *> AftBlocksLast;
321   ForeBlocksFirst.push_back(Header);
322   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
323   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
324   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
325   AftBlocksFirst.push_back(SubLoop->getExitBlock());
326   AftBlocksLast.push_back(L->getExitingBlock());
327   // Maps Blocks[0] -> Blocks[It]
328   ValueToValueMapTy LastValueMap;
329 
330   // Move any instructions from fore phi operands from AftBlocks into Fore.
331   moveHeaderPhiOperandsToForeBlocks(
332       Header, LatchBlock, ForeBlocksLast[0]->getTerminator()->getIterator(),
333       AftBlocks);
334 
335   // The current on-the-fly SSA update requires blocks to be processed in
336   // reverse postorder so that LastValueMap contains the correct value at each
337   // exit.
338   LoopBlocksDFS DFS(L);
339   DFS.perform(LI);
340   // Stash the DFS iterators before adding blocks to the loop.
341   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
342   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
343 
344   // When a FSDiscriminator is enabled, we don't need to add the multiply
345   // factors to the discriminators.
346   if (Header->getParent()->shouldEmitDebugInfoForProfiling() &&
347       !EnableFSDiscriminator)
348     for (BasicBlock *BB : L->getBlocks())
349       for (Instruction &I : *BB)
350         if (!I.isDebugOrPseudoInst())
351           if (const DILocation *DIL = I.getDebugLoc()) {
352             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
353             if (NewDIL)
354               I.setDebugLoc(*NewDIL);
355             else
356               LLVM_DEBUG(dbgs()
357                          << "Failed to create new discriminator: "
358                          << DIL->getFilename() << " Line: " << DIL->getLine());
359           }
360 
361   // Copy all blocks
362   for (unsigned It = 1; It != Count; ++It) {
363     SmallVector<BasicBlock *, 8> NewBlocks;
364     // Maps Blocks[It] -> Blocks[It-1]
365     DenseMap<Value *, Value *> PrevItValueMap;
366     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
367     NewLoops[L] = L;
368     NewLoops[SubLoop] = SubLoop;
369 
370     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
371       ValueToValueMapTy VMap;
372       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
373       Header->getParent()->insert(Header->getParent()->end(), New);
374 
375       // Tell LI about New.
376       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
377 
378       if (ForeBlocks.count(*BB)) {
379         if (*BB == ForeBlocksFirst[0])
380           ForeBlocksFirst.push_back(New);
381         if (*BB == ForeBlocksLast[0])
382           ForeBlocksLast.push_back(New);
383       } else if (SubLoopBlocks.count(*BB)) {
384         if (*BB == SubLoopBlocksFirst[0])
385           SubLoopBlocksFirst.push_back(New);
386         if (*BB == SubLoopBlocksLast[0])
387           SubLoopBlocksLast.push_back(New);
388       } else if (AftBlocks.count(*BB)) {
389         if (*BB == AftBlocksFirst[0])
390           AftBlocksFirst.push_back(New);
391         if (*BB == AftBlocksLast[0])
392           AftBlocksLast.push_back(New);
393       } else {
394         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
395       }
396 
397       // Update our running maps of newest clones
398       auto &Last = LastValueMap[*BB];
399       PrevItValueMap[New] = (It == 1 ? *BB : Last);
400       Last = New;
401       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
402            VI != VE; ++VI) {
403         auto &LVM = LastValueMap[VI->first];
404         PrevItValueMap[VI->second] =
405             const_cast<Value *>(It == 1 ? VI->first : LVM);
406         LVM = VI->second;
407       }
408 
409       NewBlocks.push_back(New);
410 
411       // Update DomTree:
412       if (*BB == ForeBlocksFirst[0])
413         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
414       else if (*BB == SubLoopBlocksFirst[0])
415         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
416       else if (*BB == AftBlocksFirst[0])
417         DT->addNewBlock(New, AftBlocksLast[It - 1]);
418       else {
419         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
420         // structure.
421         auto BBDomNode = DT->getNode(*BB);
422         auto BBIDom = BBDomNode->getIDom();
423         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
424         assert(OriginalBBIDom);
425         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
426         DT->addNewBlock(
427             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
428       }
429     }
430 
431     // Remap all instructions in the most recent iteration
432     remapInstructionsInBlocks(NewBlocks, LastValueMap);
433     for (BasicBlock *NewBlock : NewBlocks) {
434       for (Instruction &I : *NewBlock) {
435         if (auto *II = dyn_cast<AssumeInst>(&I))
436           AC->registerAssumption(II);
437       }
438     }
439 
440     // Alter the ForeBlocks phi's, pointing them at the latest version of the
441     // value from the previous iteration's phis
442     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
443       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
444       assert(OldValue && "should have incoming edge from Aft[It]");
445       Value *NewValue = OldValue;
446       if (Value *PrevValue = PrevItValueMap[OldValue])
447         NewValue = PrevValue;
448 
449       assert(Phi.getNumOperands() == 2);
450       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
451       Phi.setIncomingValue(0, NewValue);
452       Phi.removeIncomingValue(1);
453     }
454   }
455 
456   // Now that all the basic blocks for the unrolled iterations are in place,
457   // finish up connecting the blocks and phi nodes. At this point LastValueMap
458   // is the last unrolled iterations values.
459 
460   // Update Phis in BB from OldBB to point to NewBB and use the latest value
461   // from LastValueMap
462   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
463                                      BasicBlock *NewBB,
464                                      ValueToValueMapTy &LastValueMap) {
465     for (PHINode &Phi : BB->phis()) {
466       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
467         if (Phi.getIncomingBlock(b) == OldBB) {
468           Value *OldValue = Phi.getIncomingValue(b);
469           if (Value *LastValue = LastValueMap[OldValue])
470             Phi.setIncomingValue(b, LastValue);
471           Phi.setIncomingBlock(b, NewBB);
472           break;
473         }
474       }
475     }
476   };
477   // Move all the phis from Src into Dest
478   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
479     BasicBlock::iterator insertPoint = Dest->getFirstNonPHIIt();
480     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
481       Phi->moveBefore(*Dest, insertPoint);
482   };
483 
484   // Update the PHI values outside the loop to point to the last block
485   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
486                            LastValueMap);
487 
488   // Update ForeBlocks successors and phi nodes
489   BranchInst *ForeTerm =
490       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
491   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
492   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
493 
494   if (CompletelyUnroll) {
495     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
496       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
497       Phi->eraseFromParent();
498     }
499   } else {
500     // Update the PHI values to point to the last aft block
501     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
502                              AftBlocksLast.back(), LastValueMap);
503   }
504 
505   for (unsigned It = 1; It != Count; It++) {
506     // Remap ForeBlock successors from previous iteration to this
507     BranchInst *ForeTerm =
508         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
509     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
510     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
511   }
512 
513   // Subloop successors and phis
514   BranchInst *SubTerm =
515       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
516   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
517   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
518   SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
519                                             ForeBlocksLast.back());
520   SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
521                                             SubLoopBlocksLast.back());
522 
523   for (unsigned It = 1; It != Count; It++) {
524     // Replace the conditional branch of the previous iteration subloop with an
525     // unconditional one to this one
526     BranchInst *SubTerm =
527         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
528     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm->getIterator());
529     SubTerm->eraseFromParent();
530 
531     SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
532                                                ForeBlocksLast.back());
533     SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
534                                                SubLoopBlocksLast.back());
535     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
536   }
537 
538   // Aft blocks successors and phis
539   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
540   if (CompletelyUnroll) {
541     BranchInst::Create(LoopExit, AftTerm->getIterator());
542     AftTerm->eraseFromParent();
543   } else {
544     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
545     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
546            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
547   }
548   AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
549                                         SubLoopBlocksLast.back());
550 
551   for (unsigned It = 1; It != Count; It++) {
552     // Replace the conditional branch of the previous iteration subloop with an
553     // unconditional one to this one
554     BranchInst *AftTerm =
555         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
556     BranchInst::Create(AftBlocksFirst[It], AftTerm->getIterator());
557     AftTerm->eraseFromParent();
558 
559     AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
560                                            SubLoopBlocksLast.back());
561     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
562   }
563 
564   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
565   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
566   // new ones required.
567   if (Count != 1) {
568     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
569     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
570                            SubLoopBlocksFirst[0]);
571     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
572                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
573 
574     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
575                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
576     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
577                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
578     DTU.applyUpdatesPermissive(DTUpdates);
579   }
580 
581   // Merge adjacent basic blocks, if possible.
582   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
583   MergeBlocks.insert_range(ForeBlocksLast);
584   MergeBlocks.insert_range(SubLoopBlocksLast);
585   MergeBlocks.insert_range(AftBlocksLast);
586 
587   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
588 
589   // Apply updates to the DomTree.
590   DT = &DTU.getDomTree();
591 
592   // At this point, the code is well formed.  We now do a quick sweep over the
593   // inserted code, doing constant propagation and dead code elimination as we
594   // go.
595   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
596   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
597                           TTI);
598 
599   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
600   ++NumUnrolledAndJammed;
601 
602   // Update LoopInfo if the loop is completely removed.
603   if (CompletelyUnroll)
604     LI->erase(L);
605 
606 #ifndef NDEBUG
607   // We shouldn't have done anything to break loop simplify form or LCSSA.
608   Loop *OutestLoop = SubLoop->getParentLoop()
609                          ? SubLoop->getParentLoop()->getParentLoop()
610                                ? SubLoop->getParentLoop()->getParentLoop()
611                                : SubLoop->getParentLoop()
612                          : SubLoop;
613   assert(DT->verify());
614   LI->verify(*DT);
615   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
616   if (!CompletelyUnroll)
617     assert(L->isLoopSimplifyForm());
618   assert(SubLoop->isLoopSimplifyForm());
619   SE->verify();
620 #endif
621 
622   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
623                           : LoopUnrollResult::PartiallyUnrolled;
624 }
625 
626 static bool getLoadsAndStores(BasicBlockSet &Blocks,
627                               SmallVector<Instruction *, 4> &MemInstr) {
628   // Scan the BBs and collect legal loads and stores.
629   // Returns false if non-simple loads/stores are found.
630   for (BasicBlock *BB : Blocks) {
631     for (Instruction &I : *BB) {
632       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
633         if (!Ld->isSimple())
634           return false;
635         MemInstr.push_back(&I);
636       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
637         if (!St->isSimple())
638           return false;
639         MemInstr.push_back(&I);
640       } else if (I.mayReadOrWriteMemory()) {
641         return false;
642       }
643     }
644   }
645   return true;
646 }
647 
648 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
649                                        unsigned UnrollLevel, unsigned JamLevel,
650                                        bool Sequentialized, Dependence *D) {
651   // UnrollLevel might carry the dependency Src --> Dst
652   // Does a different loop after unrolling?
653   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
654        ++CurLoopDepth) {
655     auto JammedDir = D->getDirection(CurLoopDepth);
656     if (JammedDir == Dependence::DVEntry::LT)
657       return true;
658 
659     if (JammedDir & Dependence::DVEntry::GT)
660       return false;
661   }
662 
663   return true;
664 }
665 
666 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
667                                         unsigned UnrollLevel, unsigned JamLevel,
668                                         bool Sequentialized, Dependence *D) {
669   // UnrollLevel might carry the dependency Dst --> Src
670   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
671        ++CurLoopDepth) {
672     auto JammedDir = D->getDirection(CurLoopDepth);
673     if (JammedDir == Dependence::DVEntry::GT)
674       return true;
675 
676     if (JammedDir & Dependence::DVEntry::LT)
677       return false;
678   }
679 
680   // Backward dependencies are only preserved if not interleaved.
681   return Sequentialized;
682 }
683 
684 // Check whether it is semantically safe Src and Dst considering any potential
685 // dependency between them.
686 //
687 // @param UnrollLevel The level of the loop being unrolled
688 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
689 // different levels, the outermost common loop counts as jammed level
690 //
691 // @return true if is safe and false if there is a dependency violation.
692 static bool checkDependency(Instruction *Src, Instruction *Dst,
693                             unsigned UnrollLevel, unsigned JamLevel,
694                             bool Sequentialized, DependenceInfo &DI) {
695   assert(UnrollLevel <= JamLevel &&
696          "Expecting JamLevel to be at least UnrollLevel");
697 
698   if (Src == Dst)
699     return true;
700   // Ignore Input dependencies.
701   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
702     return true;
703 
704   // Check whether unroll-and-jam may violate a dependency.
705   // By construction, every dependency will be lexicographically non-negative
706   // (if it was, it would violate the current execution order), such as
707   //   (0,0,>,*,*)
708   // Unroll-and-jam changes the GT execution of two executions to the same
709   // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
710   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
711   //   (0,0,>=,*,*)
712   // Now, the dependency is not necessarily non-negative anymore, i.e.
713   // unroll-and-jam may violate correctness.
714   std::unique_ptr<Dependence> D = DI.depends(Src, Dst);
715   if (!D)
716     return true;
717   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
718 
719   if (D->isConfused()) {
720     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
721                       << "  " << *Src << "\n"
722                       << "  " << *Dst << "\n");
723     return false;
724   }
725 
726   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
727   // non-equal direction, then the locations accessed in the inner levels cannot
728   // overlap in memory. We assumes the indexes never overlap into neighboring
729   // dimensions.
730   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
731     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
732       return true;
733 
734   auto UnrollDirection = D->getDirection(UnrollLevel);
735 
736   // If the distance carried by the unrolled loop is 0, then after unrolling
737   // that distance will become non-zero resulting in non-overlapping accesses in
738   // the inner loops.
739   if (UnrollDirection == Dependence::DVEntry::EQ)
740     return true;
741 
742   if (UnrollDirection & Dependence::DVEntry::LT &&
743       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
744                                   Sequentialized, D.get()))
745     return false;
746 
747   if (UnrollDirection & Dependence::DVEntry::GT &&
748       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
749                                    Sequentialized, D.get()))
750     return false;
751 
752   return true;
753 }
754 
755 static bool
756 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
757                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
758                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
759                   DependenceInfo &DI, LoopInfo &LI) {
760   SmallVector<BasicBlockSet, 8> AllBlocks;
761   for (Loop *L : Root.getLoopsInPreorder())
762     if (ForeBlocksMap.contains(L))
763       AllBlocks.push_back(ForeBlocksMap.lookup(L));
764   AllBlocks.push_back(SubLoopBlocks);
765   for (Loop *L : Root.getLoopsInPreorder())
766     if (AftBlocksMap.contains(L))
767       AllBlocks.push_back(AftBlocksMap.lookup(L));
768 
769   unsigned LoopDepth = Root.getLoopDepth();
770   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
771   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
772   for (BasicBlockSet &Blocks : AllBlocks) {
773     CurrentLoadsAndStores.clear();
774     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
775       return false;
776 
777     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
778     unsigned CurLoopDepth = CurLoop->getLoopDepth();
779 
780     for (auto *Earlier : EarlierLoadsAndStores) {
781       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
782       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
783       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
784       for (auto *Later : CurrentLoadsAndStores) {
785         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
786                              DI))
787           return false;
788       }
789     }
790 
791     size_t NumInsts = CurrentLoadsAndStores.size();
792     for (size_t I = 0; I < NumInsts; ++I) {
793       for (size_t J = I; J < NumInsts; ++J) {
794         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
795                              LoopDepth, CurLoopDepth, true, DI))
796           return false;
797       }
798     }
799 
800     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
801                                  CurrentLoadsAndStores.end());
802   }
803   return true;
804 }
805 
806 static bool isEligibleLoopForm(const Loop &Root) {
807   // Root must have a child.
808   if (Root.getSubLoops().size() != 1)
809     return false;
810 
811   const Loop *L = &Root;
812   do {
813     // All loops in Root need to be in simplify and rotated form.
814     if (!L->isLoopSimplifyForm())
815       return false;
816 
817     if (!L->isRotatedForm())
818       return false;
819 
820     if (L->getHeader()->hasAddressTaken()) {
821       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
822       return false;
823     }
824 
825     unsigned SubLoopsSize = L->getSubLoops().size();
826     if (SubLoopsSize == 0)
827       return true;
828 
829     // Only one child is allowed.
830     if (SubLoopsSize != 1)
831       return false;
832 
833     // Only loops with a single exit block can be unrolled and jammed.
834     // The function getExitBlock() is used for this check, rather than
835     // getUniqueExitBlock() to ensure loops with mulitple exit edges are
836     // disallowed.
837     if (!L->getExitBlock()) {
838       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
839                            "blocks can be unrolled and jammed.\n");
840       return false;
841     }
842 
843     // Only loops with a single exiting block can be unrolled and jammed.
844     if (!L->getExitingBlock()) {
845       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
846                            "exiting blocks can be unrolled and jammed.\n");
847       return false;
848     }
849 
850     L = L->getSubLoops()[0];
851   } while (L);
852 
853   return true;
854 }
855 
856 static Loop *getInnerMostLoop(Loop *L) {
857   while (!L->getSubLoops().empty())
858     L = L->getSubLoops()[0];
859   return L;
860 }
861 
862 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
863                                 DependenceInfo &DI, LoopInfo &LI) {
864   if (!isEligibleLoopForm(*L)) {
865     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
866     return false;
867   }
868 
869   /* We currently handle outer loops like this:
870         |
871     ForeFirst    <------\   }
872      Blocks             |   } ForeBlocks of L
873     ForeLast            |   }
874         |               |
875        ...              |
876         |               |
877     ForeFirst    <----\ |   }
878      Blocks           | |   } ForeBlocks of a inner loop of L
879     ForeLast          | |   }
880         |             | |
881     JamLoopFirst  <\  | |   }
882      Blocks        |  | |   } JamLoopBlocks of the innermost loop
883     JamLoopLast   -/  | |   }
884         |             | |
885     AftFirst          | |   }
886      Blocks           | |   } AftBlocks of a inner loop of L
887     AftLast     ------/ |   }
888         |               |
889        ...              |
890         |               |
891     AftFirst            |   }
892      Blocks             |   } AftBlocks of L
893     AftLast     --------/   }
894         |
895 
896     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
897     and AftBlocks, providing that there is one edge from Fores to SubLoops,
898     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
899     In practice we currently limit Aft blocks to a single block, and limit
900     things further in the profitablility checks of the unroll and jam pass.
901 
902     Because of the way we rearrange basic blocks, we also require that
903     the Fore blocks of L on all unrolled iterations are safe to move before the
904     blocks of the direct child of L of all iterations. So we require that the
905     phi node looping operands of ForeHeader can be moved to at least the end of
906     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
907     match up Phi's correctly.
908 
909     i.e. The old order of blocks used to be
910            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
911          It needs to be safe to transform this to
912            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
913 
914     There are then a number of checks along the lines of no calls, no
915     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
916     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
917     UnrollAndJamLoop if the trip count cannot be easily calculated.
918   */
919 
920   // Split blocks into Fore/SubLoop/Aft based on dominators
921   Loop *JamLoop = getInnerMostLoop(L);
922   BasicBlockSet SubLoopBlocks;
923   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
924   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
925   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
926                                 AftBlocksMap, DT)) {
927     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
928     return false;
929   }
930 
931   // Aft blocks may need to move instructions to fore blocks, which becomes more
932   // difficult if there are multiple (potentially conditionally executed)
933   // blocks. For now we just exclude loops with multiple aft blocks.
934   if (AftBlocksMap[L].size() != 1) {
935     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
936                          "multiple blocks after the loop\n");
937     return false;
938   }
939 
940   // Check inner loop backedge count is consistent on all iterations of the
941   // outer loop
942   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
943         return !hasIterationCountInvariantInParent(SubLoop, SE);
944       })) {
945     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
946                          "not consistent on each iteration\n");
947     return false;
948   }
949 
950   // Check the loop safety info for exceptions.
951   SimpleLoopSafetyInfo LSI;
952   LSI.computeLoopSafetyInfo(L);
953   if (LSI.anyBlockMayThrow()) {
954     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
955     return false;
956   }
957 
958   // We've ruled out the easy stuff and now need to check that there are no
959   // interdependencies which may prevent us from moving the:
960   //  ForeBlocks before Subloop and AftBlocks.
961   //  Subloop before AftBlocks.
962   //  ForeBlock phi operands before the subloop
963 
964   // Make sure we can move all instructions we need to before the subloop
965   BasicBlock *Header = L->getHeader();
966   BasicBlock *Latch = L->getLoopLatch();
967   BasicBlockSet AftBlocks = AftBlocksMap[L];
968   Loop *SubLoop = L->getSubLoops()[0];
969   if (!processHeaderPhiOperands(
970           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
971             if (SubLoop->contains(I->getParent()))
972               return false;
973             if (AftBlocks.count(I->getParent())) {
974               // If we hit a phi node in afts we know we are done (probably
975               // LCSSA)
976               if (isa<PHINode>(I))
977                 return false;
978               // Can't move instructions with side effects or memory
979               // reads/writes
980               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
981                 return false;
982             }
983             // Keep going
984             return true;
985           })) {
986     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
987                          "instructions after subloop to before it\n");
988     return false;
989   }
990 
991   // Check for memory dependencies which prohibit the unrolling we are doing.
992   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
993   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
994   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
995                          LI)) {
996     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
997     return false;
998   }
999 
1000   return true;
1001 }
1002