xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision 79ac3c12a714bcd3f2354c52d948aed9575c46d6)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/iterator_range.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/DependenceAnalysis.h"
27 #include "llvm/Analysis/DomTreeUpdater.h"
28 #include "llvm/Analysis/LoopInfo.h"
29 #include "llvm/Analysis/LoopIterator.h"
30 #include "llvm/Analysis/MustExecute.h"
31 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
32 #include "llvm/Analysis/ScalarEvolution.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/DebugLoc.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/Use.h"
43 #include "llvm/IR/User.h"
44 #include "llvm/IR/Value.h"
45 #include "llvm/IR/ValueHandle.h"
46 #include "llvm/IR/ValueMap.h"
47 #include "llvm/Support/Casting.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/ErrorHandling.h"
50 #include "llvm/Support/GenericDomTree.h"
51 #include "llvm/Support/raw_ostream.h"
52 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53 #include "llvm/Transforms/Utils/Cloning.h"
54 #include "llvm/Transforms/Utils/LoopUtils.h"
55 #include "llvm/Transforms/Utils/UnrollLoop.h"
56 #include "llvm/Transforms/Utils/ValueMapper.h"
57 #include <assert.h>
58 #include <memory>
59 #include <type_traits>
60 #include <vector>
61 
62 using namespace llvm;
63 
64 #define DEBUG_TYPE "loop-unroll-and-jam"
65 
66 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
67 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
68 
69 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
70 
71 // Partition blocks in an outer/inner loop pair into blocks before and after
72 // the loop
73 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
74                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
75   Loop *SubLoop = L.getSubLoops()[0];
76   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
77 
78   for (BasicBlock *BB : L.blocks()) {
79     if (!SubLoop->contains(BB)) {
80       if (DT.dominates(SubLoopLatch, BB))
81         AftBlocks.insert(BB);
82       else
83         ForeBlocks.insert(BB);
84     }
85   }
86 
87   // Check that all blocks in ForeBlocks together dominate the subloop
88   // TODO: This might ideally be done better with a dominator/postdominators.
89   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
90   for (BasicBlock *BB : ForeBlocks) {
91     if (BB == SubLoopPreHeader)
92       continue;
93     Instruction *TI = BB->getTerminator();
94     for (BasicBlock *Succ : successors(TI))
95       if (!ForeBlocks.count(Succ))
96         return false;
97   }
98 
99   return true;
100 }
101 
102 /// Partition blocks in a loop nest into blocks before and after each inner
103 /// loop.
104 static bool partitionOuterLoopBlocks(
105     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
106     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
107     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
108   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
109 
110   for (Loop *L : Root.getLoopsInPreorder()) {
111     if (L == &JamLoop)
112       break;
113 
114     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
115       return false;
116   }
117 
118   return true;
119 }
120 
121 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
122 // than 2 levels loop.
123 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
124                                      BasicBlockSet &ForeBlocks,
125                                      BasicBlockSet &SubLoopBlocks,
126                                      BasicBlockSet &AftBlocks,
127                                      DominatorTree *DT) {
128   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
129   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
130 }
131 
132 // Looks at the phi nodes in Header for values coming from Latch. For these
133 // instructions and all their operands calls Visit on them, keeping going for
134 // all the operands in AftBlocks. Returns false if Visit returns false,
135 // otherwise returns true. This is used to process the instructions in the
136 // Aft blocks that need to be moved before the subloop. It is used in two
137 // places. One to check that the required set of instructions can be moved
138 // before the loop. Then to collect the instructions to actually move in
139 // moveHeaderPhiOperandsToForeBlocks.
140 template <typename T>
141 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
142                                      BasicBlockSet &AftBlocks, T Visit) {
143   SmallVector<Instruction *, 8> Worklist;
144   for (auto &Phi : Header->phis()) {
145     Value *V = Phi.getIncomingValueForBlock(Latch);
146     if (Instruction *I = dyn_cast<Instruction>(V))
147       Worklist.push_back(I);
148   }
149 
150   while (!Worklist.empty()) {
151     Instruction *I = Worklist.pop_back_val();
152     if (!Visit(I))
153       return false;
154 
155     if (AftBlocks.count(I->getParent()))
156       for (auto &U : I->operands())
157         if (Instruction *II = dyn_cast<Instruction>(U))
158           Worklist.push_back(II);
159   }
160 
161   return true;
162 }
163 
164 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
165 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
166                                               BasicBlock *Latch,
167                                               Instruction *InsertLoc,
168                                               BasicBlockSet &AftBlocks) {
169   // We need to ensure we move the instructions in the correct order,
170   // starting with the earliest required instruction and moving forward.
171   std::vector<Instruction *> Visited;
172   processHeaderPhiOperands(Header, Latch, AftBlocks,
173                            [&Visited, &AftBlocks](Instruction *I) {
174                              if (AftBlocks.count(I->getParent()))
175                                Visited.push_back(I);
176                              return true;
177                            });
178 
179   // Move all instructions in program order to before the InsertLoc
180   BasicBlock *InsertLocBB = InsertLoc->getParent();
181   for (Instruction *I : reverse(Visited)) {
182     if (I->getParent() != InsertLocBB)
183       I->moveBefore(InsertLoc);
184   }
185 }
186 
187 /*
188   This method performs Unroll and Jam. For a simple loop like:
189   for (i = ..)
190     Fore(i)
191     for (j = ..)
192       SubLoop(i, j)
193     Aft(i)
194 
195   Instead of doing normal inner or outer unrolling, we do:
196   for (i = .., i+=2)
197     Fore(i)
198     Fore(i+1)
199     for (j = ..)
200       SubLoop(i, j)
201       SubLoop(i+1, j)
202     Aft(i)
203     Aft(i+1)
204 
205   So the outer loop is essentially unrolled and then the inner loops are fused
206   ("jammed") together into a single loop. This can increase speed when there
207   are loads in SubLoop that are invariant to i, as they become shared between
208   the now jammed inner loops.
209 
210   We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
211   Fore blocks are those before the inner loop, Aft are those after. Normal
212   Unroll code is used to copy each of these sets of blocks and the results are
213   combined together into the final form above.
214 
215   isSafeToUnrollAndJam should be used prior to calling this to make sure the
216   unrolling will be valid. Checking profitability is also advisable.
217 
218   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
219   necessary to create one and not fully unrolled).
220 */
LoopUnrollResult
llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
                       unsigned TripMultiple, bool UnrollRemainder,
                       LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
                       AssumptionCache *AC, const TargetTransformInfo *TTI,
                       OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  // ContinueOnTrue is true when successor 0 of the latch branch stays inside
  // the loop; the other successor is then the loop exit.
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // For profiling-oriented debug info, scale each debug location's
  // duplication factor by Count so the unrolled copies can be told apart.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks. Iteration 0 is the original code; each It clones every
  // block once, recording the first/last clones of each Fore/Sub/Aft set.
  for (unsigned It = 1; It != Count; ++It) {
    SmallVector<BasicBlock *, 8> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;
    SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
    NewLoops[L] = L;
    NewLoops[SubLoop] = SubLoop;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Tell LI about New.
      addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);

      if (ForeBlocks.count(*BB)) {
        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    remapInstructionsInBlocks(NewBlocks, LastValueMap);
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        // Cloned assumes must be re-registered with the assumption cache.
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      // Collapse the cloned phi to a single incoming edge from the previous
      // iteration's last fore block.
      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
  ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);

  if (CompletelyUnroll) {
    // Loop control is gone: each header phi takes its preheader value.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
    ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
                                            ForeBlocksLast.back());
  SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                            SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
                                               ForeBlocksLast.back());
    SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                               SubLoopBlocksLast.back());
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, AftTerm);
    AftTerm->eraseFromParent();
  } else {
    AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
    assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
           "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
  }
  AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                        SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                           SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());

  MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);

  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
                          TTI);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  // OutestLoop is up to two levels above SubLoop: grandparent if it exists,
  // else parent, else SubLoop itself.
  Loop *OutestLoop = SubLoop->getParentLoop()
                         ? SubLoop->getParentLoop()->getParentLoop()
                               ? SubLoop->getParentLoop()->getParentLoop()
                               : SubLoop->getParentLoop()
                         : SubLoop;
  assert(DT->verify());
  LI->verify(*DT);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  SE->verify();
#endif

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
626 
627 static bool getLoadsAndStores(BasicBlockSet &Blocks,
628                               SmallVector<Instruction *, 4> &MemInstr) {
629   // Scan the BBs and collect legal loads and stores.
630   // Returns false if non-simple loads/stores are found.
631   for (BasicBlock *BB : Blocks) {
632     for (Instruction &I : *BB) {
633       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
634         if (!Ld->isSimple())
635           return false;
636         MemInstr.push_back(&I);
637       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
638         if (!St->isSimple())
639           return false;
640         MemInstr.push_back(&I);
641       } else if (I.mayReadOrWriteMemory()) {
642         return false;
643       }
644     }
645   }
646   return true;
647 }
648 
649 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
650                                        unsigned UnrollLevel, unsigned JamLevel,
651                                        bool Sequentialized, Dependence *D) {
652   // UnrollLevel might carry the dependency Src --> Dst
653   // Does a different loop after unrolling?
654   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
655        ++CurLoopDepth) {
656     auto JammedDir = D->getDirection(CurLoopDepth);
657     if (JammedDir == Dependence::DVEntry::LT)
658       return true;
659 
660     if (JammedDir & Dependence::DVEntry::GT)
661       return false;
662   }
663 
664   return true;
665 }
666 
667 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
668                                         unsigned UnrollLevel, unsigned JamLevel,
669                                         bool Sequentialized, Dependence *D) {
670   // UnrollLevel might carry the dependency Dst --> Src
671   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
672        ++CurLoopDepth) {
673     auto JammedDir = D->getDirection(CurLoopDepth);
674     if (JammedDir == Dependence::DVEntry::GT)
675       return true;
676 
677     if (JammedDir & Dependence::DVEntry::LT)
678       return false;
679   }
680 
681   // Backward dependencies are only preserved if not interleaved.
682   return Sequentialized;
683 }
684 
// Check whether unroll-and-jam is semantically safe for Src and Dst,
// considering any potential dependency between them.
//
// @param UnrollLevel The level of the loop being unrolled
// @param JamLevel    The level of the loop being jammed; if Src and Dst are on
// different levels, the outermost common loop counts as jammed level
//
// @return true if is safe and false if there is a dependency violation.
static bool checkDependency(Instruction *Src, Instruction *Dst,
                            unsigned UnrollLevel, unsigned JamLevel,
                            bool Sequentialized, DependenceInfo &DI) {
  assert(UnrollLevel <= JamLevel &&
         "Expecting JamLevel to be at least UnrollLevel");

  // A dependence of an instruction on itself cannot be violated by reordering.
  if (Src == Dst)
    return true;
  // Ignore Input dependencies.
  if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
    return true;

  // Check whether unroll-and-jam may violate a dependency.
  // By construction, every dependency will be lexicographically non-negative
  // (if it was, it would violate the current execution order), such as
  //   (0,0,>,*,*)
  // Unroll-and-jam changes the GT execution of two executions to the same
  // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
  // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
  //   (0,0,>=,*,*)
  // Now, the dependency is not necessarily non-negative anymore, i.e.
  // unroll-and-jam may violate correctness.
  std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
  // No dependence found: nothing can be violated.
  if (!D)
    return true;
  assert(D->isOrdered() && "Expected an output, flow or anti dep.");

  // A confused dependence cannot be analyzed; conservatively reject.
  if (D->isConfused()) {
    LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
                      << "  " << *Src << "\n"
                      << "  " << *Dst << "\n");
    return false;
  }

  // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
  // non-equal direction, then the locations accessed in the inner levels cannot
  // overlap in memory. We assumes the indexes never overlap into neighboring
  // dimensions.
  for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
    if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
      return true;

  auto UnrollDirection = D->getDirection(UnrollLevel);

  // If the distance carried by the unrolled loop is 0, then after unrolling
  // that distance will become non-zero resulting in non-overlapping accesses in
  // the inner loops.
  if (UnrollDirection == Dependence::DVEntry::EQ)
    return true;

  if (UnrollDirection & Dependence::DVEntry::LT &&
      !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                  Sequentialized, D.get()))
    return false;

  if (UnrollDirection & Dependence::DVEntry::GT &&
      !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                   Sequentialized, D.get()))
    return false;

  return true;
}
755 
756 static bool
757 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
758                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
759                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
760                   DependenceInfo &DI, LoopInfo &LI) {
761   SmallVector<BasicBlockSet, 8> AllBlocks;
762   for (Loop *L : Root.getLoopsInPreorder())
763     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
764       AllBlocks.push_back(ForeBlocksMap.lookup(L));
765   AllBlocks.push_back(SubLoopBlocks);
766   for (Loop *L : Root.getLoopsInPreorder())
767     if (AftBlocksMap.find(L) != AftBlocksMap.end())
768       AllBlocks.push_back(AftBlocksMap.lookup(L));
769 
770   unsigned LoopDepth = Root.getLoopDepth();
771   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
772   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
773   for (BasicBlockSet &Blocks : AllBlocks) {
774     CurrentLoadsAndStores.clear();
775     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
776       return false;
777 
778     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
779     unsigned CurLoopDepth = CurLoop->getLoopDepth();
780 
781     for (auto *Earlier : EarlierLoadsAndStores) {
782       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
783       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
784       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
785       for (auto *Later : CurrentLoadsAndStores) {
786         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
787                              DI))
788           return false;
789       }
790     }
791 
792     size_t NumInsts = CurrentLoadsAndStores.size();
793     for (size_t I = 0; I < NumInsts; ++I) {
794       for (size_t J = I; J < NumInsts; ++J) {
795         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
796                              LoopDepth, CurLoopDepth, true, DI))
797           return false;
798       }
799     }
800 
801     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
802                                  CurrentLoadsAndStores.end());
803   }
804   return true;
805 }
806 
807 static bool isEligibleLoopForm(const Loop &Root) {
808   // Root must have a child.
809   if (Root.getSubLoops().size() != 1)
810     return false;
811 
812   const Loop *L = &Root;
813   do {
814     // All loops in Root need to be in simplify and rotated form.
815     if (!L->isLoopSimplifyForm())
816       return false;
817 
818     if (!L->isRotatedForm())
819       return false;
820 
821     if (L->getHeader()->hasAddressTaken()) {
822       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
823       return false;
824     }
825 
826     unsigned SubLoopsSize = L->getSubLoops().size();
827     if (SubLoopsSize == 0)
828       return true;
829 
830     // Only one child is allowed.
831     if (SubLoopsSize != 1)
832       return false;
833 
834     L = L->getSubLoops()[0];
835   } while (L);
836 
837   return true;
838 }
839 
840 static Loop *getInnerMostLoop(Loop *L) {
841   while (!L->getSubLoops().empty())
842     L = L->getSubLoops()[0];
843   return L;
844 }
845 
// Top-level legality check for unroll-and-jam: loop form, block layout,
// invariant inner trip counts, exception safety, movable phi operands and
// memory dependencies are all verified before the transform is allowed.
bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                                DependenceInfo &DI, LoopInfo &LI) {
  if (!isEligibleLoopForm(*L)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
    return false;
  }

  /* We currently handle outer loops like this:
        |
    ForeFirst    <------\   }
     Blocks             |   } ForeBlocks of L
    ForeLast            |   }
        |               |
       ...              |
        |               |
    ForeFirst    <----\ |   }
     Blocks           | |   } ForeBlocks of an inner loop of L
    ForeLast          | |   }
        |             | |
    JamLoopFirst  <\  | |   }
     Blocks        |  | |   } JamLoopBlocks of the innermost loop
    JamLoopLast   -/  | |   }
        |             | |
    AftFirst          | |   }
     Blocks           | |   } AftBlocks of an inner loop of L
    AftLast     ------/ |   }
        |               |
       ...              |
        |               |
    AftFirst            |   }
     Blocks             |   } AftBlocks of L
    AftLast     --------/   }
        |

    There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
    and AftBlocks, providing that there is one edge from Fores to SubLoops,
    one edge from SubLoops to Afts and a single outer loop exit (from Afts).
    In practice we currently limit Aft blocks to a single block, and limit
    things further in the profitability checks of the unroll and jam pass.

    Because of the way we rearrange basic blocks, we also require that
    the Fore blocks of L on all unrolled iterations are safe to move before the
    blocks of the direct child of L of all iterations. So we require that the
    phi node looping operands of ForeHeader can be moved to at least the end of
    ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
    match up Phi's correctly.

    i.e. The old order of blocks used to be
           (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
         It needs to be safe to transform this to
           (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.

    There are then a number of checks along the lines of no calls, no
    exceptions, inner loop IV is consistent, etc. Note that for loops requiring
    runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
    UnrollAndJamLoop if the trip count cannot be easily calculated.
  */

  // Split blocks into Fore/SubLoop/Aft based on dominators
  Loop *JamLoop = getInnerMostLoop(L);
  BasicBlockSet SubLoopBlocks;
  DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
  DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
  if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
                                AftBlocksMap, DT)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
    return false;
  }

  // Aft blocks may need to move instructions to fore blocks, which becomes more
  // difficult if there are multiple (potentially conditionally executed)
  // blocks. For now we just exclude loops with multiple aft blocks.
  if (AftBlocksMap[L].size() != 1) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
                         "multiple blocks after the loop\n");
    return false;
  }

  // Check inner loop backedge count is consistent on all iterations of the
  // outer loop
  if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
        return !hasIterationCountInvariantInParent(SubLoop, SE);
      })) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
                         "not consistent on each iteration\n");
    return false;
  }

  // Check the loop safety info for exceptions.
  SimpleLoopSafetyInfo LSI;
  LSI.computeLoopSafetyInfo(L);
  if (LSI.anyBlockMayThrow()) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
    return false;
  }

  // We've ruled out the easy stuff and now need to check that there are no
  // interdependencies which may prevent us from moving the:
  //  ForeBlocks before Subloop and AftBlocks.
  //  Subloop before AftBlocks.
  //  ForeBlock phi operands before the subloop

  // Make sure we can move all instructions we need to before the subloop
  BasicBlock *Header = L->getHeader();
  BasicBlock *Latch = L->getLoopLatch();
  BasicBlockSet AftBlocks = AftBlocksMap[L];
  Loop *SubLoop = L->getSubLoops()[0];
  if (!processHeaderPhiOperands(
          Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
            if (SubLoop->contains(I->getParent()))
              return false;
            if (AftBlocks.count(I->getParent())) {
              // If we hit a phi node in afts we know we are done (probably
              // LCSSA)
              if (isa<PHINode>(I))
                return false;
              // Can't move instructions with side effects or memory
              // reads/writes
              if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
                return false;
            }
            // Keep going
            return true;
          })) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
                         "instructions after subloop to before it\n");
    return false;
  }

  // Check for memory dependencies which prohibit the unrolling we are doing.
  // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
  // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
  if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
                         LI)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
    return false;
  }

  return true;
}
986