xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision c66ec88fed842fbaad62c30d510644ceb7bd2d71)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/iterator_range.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/DependenceAnalysis.h"
27 #include "llvm/Analysis/DomTreeUpdater.h"
28 #include "llvm/Analysis/LoopInfo.h"
29 #include "llvm/Analysis/LoopIterator.h"
30 #include "llvm/Analysis/MustExecute.h"
31 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
32 #include "llvm/Analysis/ScalarEvolution.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/DebugLoc.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/Use.h"
43 #include "llvm/IR/User.h"
44 #include "llvm/IR/Value.h"
45 #include "llvm/IR/ValueHandle.h"
46 #include "llvm/IR/ValueMap.h"
47 #include "llvm/Support/Casting.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/ErrorHandling.h"
50 #include "llvm/Support/GenericDomTree.h"
51 #include "llvm/Support/raw_ostream.h"
52 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53 #include "llvm/Transforms/Utils/Cloning.h"
54 #include "llvm/Transforms/Utils/LoopUtils.h"
55 #include "llvm/Transforms/Utils/UnrollLoop.h"
56 #include "llvm/Transforms/Utils/ValueMapper.h"
57 #include <assert.h>
58 #include <memory>
59 #include <type_traits>
60 #include <vector>
61 
62 using namespace llvm;
63 
64 #define DEBUG_TYPE "loop-unroll-and-jam"
65 
66 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
67 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops fully unroll and jammed");
68 
69 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
70 
71 // Partition blocks in an outer/inner loop pair into blocks before and after
72 // the inner loop.
73 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
74                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
75   Loop *SubLoop = L.getSubLoops()[0];
76   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
77 
78   for (BasicBlock *BB : L.blocks()) {
79     if (!SubLoop->contains(BB)) {
80       if (DT.dominates(SubLoopLatch, BB))
81         AftBlocks.insert(BB);
82       else
83         ForeBlocks.insert(BB);
84     }
85   }
86 
87   // Check that all blocks in ForeBlocks together dominate the subloop
88   // TODO: This might ideally be done better with dominators/postdominators.
89   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
90   for (BasicBlock *BB : ForeBlocks) {
91     if (BB == SubLoopPreHeader)
92       continue;
93     Instruction *TI = BB->getTerminator();
94     for (BasicBlock *Succ : successors(TI))
95       if (!ForeBlocks.count(Succ))
96         return false;
97   }
98 
99   return true;
100 }
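// For example, in a hypothetical two-deep nest such as
//   for (i = ..) {
//     x = f(i);         // block not dominated by the inner latch -> Fore
//     for (j = ..)
//       use(x, i, j);   // inner loop body (SubLoop)
//     g(x, i);          // block dominated by the inner latch -> Aft
//   }
// the block computing x is placed in ForeBlocks and the block containing
// g(x, i) in AftBlocks; the partition fails if some Fore block can branch
// anywhere other than another Fore block or the subloop preheader.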
101 
102 /// Partition blocks in a loop nest into blocks before and after each inner
103 /// loop.
104 static bool partitionOuterLoopBlocks(
105     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
106     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
107     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
108   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
109 
110   for (Loop *L : Root.getLoopsInPreorder()) {
111     if (L == &JamLoop)
112       break;
113 
114     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
115       return false;
116   }
117 
118   return true;
119 }
120 
121 // TODO: Remove when UnrollAndJamLoop is changed to support unroll and jamming
122 // more than two loop levels.
123 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
124                                      BasicBlockSet &ForeBlocks,
125                                      BasicBlockSet &SubLoopBlocks,
126                                      BasicBlockSet &AftBlocks,
127                                      DominatorTree *DT) {
128   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
129   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
130 }
131 
132 // Looks at the phi nodes in Header for values coming from Latch. For each such
133 // instruction and all of its operands, calls Visit on them, continuing the walk
134 // through operands that are defined in AftBlocks. Returns false if Visit returns
135 // false, otherwise returns true. This is used to process the instructions in the
136 // Aft blocks that need to be moved before the subloop. It is used in two places:
137 // once to check that the required set of instructions can be moved before the
138 // loop, and again to collect the instructions to actually move in
139 // moveHeaderPhiOperandsToForeBlocks.
140 template <typename T>
141 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
142                                      BasicBlockSet &AftBlocks, T Visit) {
143   SmallVector<Instruction *, 8> Worklist;
144   for (auto &Phi : Header->phis()) {
145     Value *V = Phi.getIncomingValueForBlock(Latch);
146     if (Instruction *I = dyn_cast<Instruction>(V))
147       Worklist.push_back(I);
148   }
149 
150   while (!Worklist.empty()) {
151     Instruction *I = Worklist.back();
152     Worklist.pop_back();
153     if (!Visit(I))
154       return false;
155 
156     if (AftBlocks.count(I->getParent()))
157       for (auto &U : I->operands())
158         if (Instruction *II = dyn_cast<Instruction>(U))
159           Worklist.push_back(II);
160   }
161 
162   return true;
163 }
164 
165 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
166 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
167                                               BasicBlock *Latch,
168                                               Instruction *InsertLoc,
169                                               BasicBlockSet &AftBlocks) {
170   // We need to ensure we move the instructions in the correct order,
171   // starting with the earliest required instruction and moving forward.
172   std::vector<Instruction *> Visited;
173   processHeaderPhiOperands(Header, Latch, AftBlocks,
174                            [&Visited, &AftBlocks](Instruction *I) {
175                              if (AftBlocks.count(I->getParent()))
176                                Visited.push_back(I);
177                              return true;
178                            });
179 
180   // Move all instructions in program order to before the InsertLoc
181   BasicBlock *InsertLocBB = InsertLoc->getParent();
182   for (Instruction *I : reverse(Visited)) {
183     if (I->getParent() != InsertLocBB)
184       I->moveBefore(InsertLoc);
185   }
186 }
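// As an illustrative (hypothetical) example, if the outer Aft/latch block is
//   aft:
//     %i.next = add i32 %i, 1
//     %cmp = icmp ult i32 %i.next, %n
//     br i1 %cmp, label %header, label %exit
// and the header contains %i = phi i32 [ 0, %preheader ], [ %i.next, %aft ],
// then %i.next is collected by processHeaderPhiOperands and moved to just
// before the terminator of the subloop preheader (the last Fore block), while
// %cmp and the branch stay behind in the Aft block.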
187 
188 /*
189   This method performs Unroll and Jam. For a simple loop like:
190   for (i = ..)
191     Fore(i)
192     for (j = ..)
193       SubLoop(i, j)
194     Aft(i)
195 
196   Instead of doing normal inner or outer unrolling, we do:
197   for (i = .., i+=2)
198     Fore(i)
199     Fore(i+1)
200     for (j = ..)
201       SubLoop(i, j)
202       SubLoop(i+1, j)
203     Aft(i)
204     Aft(i+1)
205 
206   So the outer loop is essentially unrolled and then the inner loops are fused
207   ("jammed") together into a single loop. This can increase speed when there
208   are loads in SubLoop that are invariant to i, as they become shared between
209   the now jammed inner loops.
210 
211   We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
212   Fore blocks are those before the inner loop, Aft are those after. Normal
213   Unroll code is used to copy each of these sets of blocks and the results are
214   combined together into the final form above.
215 
216   isSafeToUnrollAndJam should be used prior to calling this to make sure the
217   unrolling will be valid. Checking profitability is also advisable.
218 
219   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
220   necessary to create one and not fully unrolled).
221 */
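/*
  A rough sketch of the remainder handling below (illustrative only): with
  Count == 2 and a trip count N that is not known to be a multiple of 2,
  UnrollRuntimeLoopRemainder splits off a runtime epilogue roughly like

  for (i = ..; i + 1 < N; i += 2)   // main unroll-and-jammed loop
    ...
  for (; i < N; ++i)                // epilogue, at most Count-1 iterations
    Fore(i); for (j = ..) SubLoop(i, j); Aft(i)

  and, if requested, the epilogue loop is returned through EpilogueLoop.
*/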
222 LoopUnrollResult
223 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
224                        unsigned TripMultiple, bool UnrollRemainder,
225                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
226                        AssumptionCache *AC, const TargetTransformInfo *TTI,
227                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
228 
229   // When we enter here we should have already checked that it is safe
230   BasicBlock *Header = L->getHeader();
231   assert(Header && "No header.");
232   assert(L->getSubLoops().size() == 1);
233   Loop *SubLoop = *L->begin();
234 
235   // Don't enter the unroll code if there is nothing to do.
236   if (TripCount == 0 && Count < 2) {
237     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
238     return LoopUnrollResult::Unmodified;
239   }
240 
241   assert(Count > 0);
242   assert(TripMultiple > 0);
243   assert(TripCount == 0 || TripCount % TripMultiple == 0);
244 
245   // Are we eliminating the loop control altogether?
246   bool CompletelyUnroll = (Count == TripCount);
247 
248   // We use the runtime remainder in cases where we don't know trip multiple
249   if (TripMultiple == 1 || TripMultiple % Count != 0) {
250     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
251                                     /*UseEpilogRemainder*/ true,
252                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
253                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
254       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
255                            "generated when assuming runtime trip count\n");
256       return LoopUnrollResult::Unmodified;
257     }
258   }
259 
260   // Notify ScalarEvolution that the loop will be substantially changed,
261   // if not outright eliminated.
262   if (SE) {
263     SE->forgetLoop(L);
264     SE->forgetLoop(SubLoop);
265   }
266 
267   using namespace ore;
268   // Report the unrolling decision.
269   if (CompletelyUnroll) {
270     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
271                       << Header->getName() << " with trip count " << TripCount
272                       << "!\n");
273     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
274                                  L->getHeader())
275               << "completely unroll and jammed loop with "
276               << NV("UnrollCount", TripCount) << " iterations");
277   } else {
278     auto DiagBuilder = [&]() {
279       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
280                               L->getHeader());
281       return Diag << "unroll and jammed loop by a factor of "
282                   << NV("UnrollCount", Count);
283     };
284 
285     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
286                       << " by " << Count);
287     if (TripMultiple != 1) {
288       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
289       ORE->emit([&]() {
290         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
291                              << " trips per branch";
292       });
293     } else {
294       LLVM_DEBUG(dbgs() << " with run-time trip count");
295       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
296     }
297     LLVM_DEBUG(dbgs() << "!\n");
298   }
299 
300   BasicBlock *Preheader = L->getLoopPreheader();
301   BasicBlock *LatchBlock = L->getLoopLatch();
302   assert(Preheader && "No preheader");
303   assert(LatchBlock && "No latch block");
304   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
305   assert(BI && !BI->isUnconditional());
306   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
307   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
308   bool SubLoopContinueOnTrue = SubLoop->contains(
309       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
310 
311   // Partition blocks in an outer/inner loop pair into blocks before and after
312   // the inner loop.
313   BasicBlockSet SubLoopBlocks;
314   BasicBlockSet ForeBlocks;
315   BasicBlockSet AftBlocks;
316   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
317                            DT);
318 
319   // We keep track of the entering/first and exiting/last block of each of
320   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
321   // blocks easier.
322   std::vector<BasicBlock *> ForeBlocksFirst;
323   std::vector<BasicBlock *> ForeBlocksLast;
324   std::vector<BasicBlock *> SubLoopBlocksFirst;
325   std::vector<BasicBlock *> SubLoopBlocksLast;
326   std::vector<BasicBlock *> AftBlocksFirst;
327   std::vector<BasicBlock *> AftBlocksLast;
328   ForeBlocksFirst.push_back(Header);
329   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
330   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
331   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
332   AftBlocksFirst.push_back(SubLoop->getExitBlock());
333   AftBlocksLast.push_back(L->getExitingBlock());
334   // Maps Blocks[0] -> Blocks[It]
335   ValueToValueMapTy LastValueMap;
336 
337   // Move any instructions feeding the Fore header phis from AftBlocks into Fore.
338   moveHeaderPhiOperandsToForeBlocks(
339       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
340 
341   // The current on-the-fly SSA update requires blocks to be processed in
342   // reverse postorder so that LastValueMap contains the correct value at each
343   // exit.
344   LoopBlocksDFS DFS(L);
345   DFS.perform(LI);
346   // Stash the DFS iterators before adding blocks to the loop.
347   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
348   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
349 
350   if (Header->getParent()->isDebugInfoForProfiling())
351     for (BasicBlock *BB : L->getBlocks())
352       for (Instruction &I : *BB)
353         if (!isa<DbgInfoIntrinsic>(&I))
354           if (const DILocation *DIL = I.getDebugLoc()) {
355             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
356             if (NewDIL)
357               I.setDebugLoc(NewDIL.getValue());
358             else
359               LLVM_DEBUG(dbgs()
360                          << "Failed to create new discriminator: "
361                          << DIL->getFilename() << " Line: " << DIL->getLine());
362           }
363 
364   // Copy all blocks
365   for (unsigned It = 1; It != Count; ++It) {
366     SmallVector<BasicBlock *, 8> NewBlocks;
367     // Maps Blocks[It] -> Blocks[It-1]
368     DenseMap<Value *, Value *> PrevItValueMap;
369     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
370     NewLoops[L] = L;
371     NewLoops[SubLoop] = SubLoop;
372 
373     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
374       ValueToValueMapTy VMap;
375       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
376       Header->getParent()->getBasicBlockList().push_back(New);
377 
378       // Tell LI about New.
379       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
380 
381       if (ForeBlocks.count(*BB)) {
382         if (*BB == ForeBlocksFirst[0])
383           ForeBlocksFirst.push_back(New);
384         if (*BB == ForeBlocksLast[0])
385           ForeBlocksLast.push_back(New);
386       } else if (SubLoopBlocks.count(*BB)) {
387         if (*BB == SubLoopBlocksFirst[0])
388           SubLoopBlocksFirst.push_back(New);
389         if (*BB == SubLoopBlocksLast[0])
390           SubLoopBlocksLast.push_back(New);
391       } else if (AftBlocks.count(*BB)) {
392         if (*BB == AftBlocksFirst[0])
393           AftBlocksFirst.push_back(New);
394         if (*BB == AftBlocksLast[0])
395           AftBlocksLast.push_back(New);
396       } else {
397         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
398       }
399 
400       // Update our running maps of newest clones
401       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
402       LastValueMap[*BB] = New;
403       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
404            VI != VE; ++VI) {
405         PrevItValueMap[VI->second] =
406             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
407         LastValueMap[VI->first] = VI->second;
408       }
409 
410       NewBlocks.push_back(New);
411 
412       // Update DomTree:
413       if (*BB == ForeBlocksFirst[0])
414         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
415       else if (*BB == SubLoopBlocksFirst[0])
416         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
417       else if (*BB == AftBlocksFirst[0])
418         DT->addNewBlock(New, AftBlocksLast[It - 1]);
419       else {
420         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
421         // structure.
422         auto BBDomNode = DT->getNode(*BB);
423         auto BBIDom = BBDomNode->getIDom();
424         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
425         assert(OriginalBBIDom);
426         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
427         DT->addNewBlock(
428             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
429       }
430     }
431 
432     // Remap all instructions in the most recent iteration
433     remapInstructionsInBlocks(NewBlocks, LastValueMap);
434     for (BasicBlock *NewBlock : NewBlocks) {
435       for (Instruction &I : *NewBlock) {
436         if (auto *II = dyn_cast<IntrinsicInst>(&I))
437           if (II->getIntrinsicID() == Intrinsic::assume)
438             AC->registerAssumption(II);
439       }
440     }
441 
442     // Alter the ForeBlocks phis, pointing them at the latest version of the
443     // value from the previous iteration's phis
444     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
445       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
446       assert(OldValue && "should have incoming edge from Aft[It]");
447       Value *NewValue = OldValue;
448       if (Value *PrevValue = PrevItValueMap[OldValue])
449         NewValue = PrevValue;
450 
451       assert(Phi.getNumOperands() == 2);
452       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
453       Phi.setIncomingValue(0, NewValue);
454       Phi.removeIncomingValue(1);
455     }
456   }
457 
458   // Now that all the basic blocks for the unrolled iterations are in place,
459   // finish up connecting the blocks and phi nodes. At this point LastValueMap
460   // holds the values of the last unrolled iteration.
461 
462   // Update Phis in BB from OldBB to point to NewBB
463   auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
464                             BasicBlock *NewBB) {
465     for (PHINode &Phi : BB->phis()) {
466       int I = Phi.getBasicBlockIndex(OldBB);
467       Phi.setIncomingBlock(I, NewBB);
468     }
469   };
470   // Update Phis in BB from OldBB to point to NewBB and use the latest value
471   // from LastValueMap
472   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
473                                      BasicBlock *NewBB,
474                                      ValueToValueMapTy &LastValueMap) {
475     for (PHINode &Phi : BB->phis()) {
476       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
477         if (Phi.getIncomingBlock(b) == OldBB) {
478           Value *OldValue = Phi.getIncomingValue(b);
479           if (Value *LastValue = LastValueMap[OldValue])
480             Phi.setIncomingValue(b, LastValue);
481           Phi.setIncomingBlock(b, NewBB);
482           break;
483         }
484       }
485     }
486   };
487   // Move all the phis from Src into Dest
488   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
489     Instruction *insertPoint = Dest->getFirstNonPHI();
490     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
491       Phi->moveBefore(insertPoint);
492   };
493 
494   // Update the PHI values outside the loop to point to the last block
495   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
496                            LastValueMap);
497 
498   // Update ForeBlocks successors and phi nodes
499   BranchInst *ForeTerm =
500       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
501   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
502   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
503 
504   if (CompletelyUnroll) {
505     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
506       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
507       Phi->getParent()->getInstList().erase(Phi);
508     }
509   } else {
510     // Update the PHI values to point to the last aft block
511     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
512                              AftBlocksLast.back(), LastValueMap);
513   }
514 
515   for (unsigned It = 1; It != Count; It++) {
516     // Remap ForeBlock successors from previous iteration to this
517     BranchInst *ForeTerm =
518         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
519     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
520     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
521   }
522 
523   // Subloop successors and phis
524   BranchInst *SubTerm =
525       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
526   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
527   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
528   updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
529                   ForeBlocksLast.back());
530   updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
531                   SubLoopBlocksLast.back());
532 
533   for (unsigned It = 1; It != Count; It++) {
534     // Replace the conditional branch of the previous iteration's subloop with
535     // an unconditional branch to this iteration's subloop
536     BranchInst *SubTerm =
537         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
538     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
539     SubTerm->eraseFromParent();
540 
541     updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
542                     ForeBlocksLast.back());
543     updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
544                     SubLoopBlocksLast.back());
545     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
546   }
547 
548   // Aft blocks successors and phis
549   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
550   if (CompletelyUnroll) {
551     BranchInst::Create(LoopExit, AftTerm);
552     AftTerm->eraseFromParent();
553   } else {
554     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
555     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
556            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
557   }
558   updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
559                   SubLoopBlocksLast.back());
560 
561   for (unsigned It = 1; It != Count; It++) {
562     // Replace the conditional branch of the previous iteration's Aft blocks
563     // with an unconditional branch to this iteration's Aft blocks
564     BranchInst *AftTerm =
565         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
566     BranchInst::Create(AftBlocksFirst[It], AftTerm);
567     AftTerm->eraseFromParent();
568 
569     updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
570                     SubLoopBlocksLast.back());
571     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
572   }
573 
574   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
575   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
576   // new ones required.
577   if (Count != 1) {
578     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
579     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
580                            SubLoopBlocksFirst[0]);
581     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
582                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
583 
584     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
585                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
586     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
587                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
588     DTU.applyUpdatesPermissive(DTUpdates);
589   }
590 
591   // Merge adjacent basic blocks, if possible.
592   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
593   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
594   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
595   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
596 
597   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
598 
599   // Apply updates to the DomTree.
600   DT = &DTU.getDomTree();
601 
602   // At this point, the code is well formed.  We now do a quick sweep over the
603   // inserted code, doing constant propagation and dead code elimination as we
604   // go.
605   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
606   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
607                           TTI);
608 
609   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
610   ++NumUnrolledAndJammed;
611 
612   // Update LoopInfo if the loop is completely removed.
613   if (CompletelyUnroll)
614     LI->erase(L);
615 
616 #ifndef NDEBUG
617   // We shouldn't have done anything to break loop simplify form or LCSSA.
618   Loop *OutestLoop = SubLoop->getParentLoop()
619                          ? SubLoop->getParentLoop()->getParentLoop()
620                                ? SubLoop->getParentLoop()->getParentLoop()
621                                : SubLoop->getParentLoop()
622                          : SubLoop;
623   assert(DT->verify());
624   LI->verify(*DT);
625   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
626   if (!CompletelyUnroll)
627     assert(L->isLoopSimplifyForm());
628   assert(SubLoop->isLoopSimplifyForm());
629   SE->verify();
630 #endif
631 
632   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
633                           : LoopUnrollResult::PartiallyUnrolled;
634 }
635 
636 static bool getLoadsAndStores(BasicBlockSet &Blocks,
637                               SmallVector<Instruction *, 4> &MemInstr) {
638   // Scan the BBs and collect legal loads and stores.
639   // Returns false if non-simple loads/stores are found.
640   for (BasicBlock *BB : Blocks) {
641     for (Instruction &I : *BB) {
642       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
643         if (!Ld->isSimple())
644           return false;
645         MemInstr.push_back(&I);
646       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
647         if (!St->isSimple())
648           return false;
649         MemInstr.push_back(&I);
650       } else if (I.mayReadOrWriteMemory()) {
651         return false;
652       }
653     }
654   }
655   return true;
656 }
657 
658 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
659                                        unsigned UnrollLevel, unsigned JamLevel,
660                                        bool Sequentialized, Dependence *D) {
661   // UnrollLevel might carry the dependency Src --> Dst.
662   // Does a different loop carry it after unrolling?
663   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
664        ++CurLoopDepth) {
665     auto JammedDir = D->getDirection(CurLoopDepth);
666     if (JammedDir == Dependence::DVEntry::LT)
667       return true;
668 
669     if (JammedDir & Dependence::DVEntry::GT)
670       return false;
671   }
672 
673   return true;
674 }
675 
676 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
677                                         unsigned UnrollLevel, unsigned JamLevel,
678                                         bool Sequentialized, Dependence *D) {
679   // UnrollLevel might carry the dependency Dst --> Src
680   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
681        ++CurLoopDepth) {
682     auto JammedDir = D->getDirection(CurLoopDepth);
683     if (JammedDir == Dependence::DVEntry::GT)
684       return true;
685 
686     if (JammedDir & Dependence::DVEntry::LT)
687       return false;
688   }
689 
690   // Backward dependencies are only preserved if not interleaved.
691   return Sequentialized;
692 }
693 
694 // Check whether unroll-and-jam is semantically safe for Src and Dst, considering
695 // any potential dependency between them.
696 //
697 // @param UnrollLevel The level of the loop being unrolled
698 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
699 // different levels, the outermost common loop counts as jammed level
700 //
701 // @return true if it is safe and false if there is a dependency violation.
702 static bool checkDependency(Instruction *Src, Instruction *Dst,
703                             unsigned UnrollLevel, unsigned JamLevel,
704                             bool Sequentialized, DependenceInfo &DI) {
705   assert(UnrollLevel <= JamLevel &&
706          "Expecting JamLevel to be at least UnrollLevel");
707 
708   if (Src == Dst)
709     return true;
710   // Ignore Input dependencies.
711   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
712     return true;
713 
714   // Check whether unroll-and-jam may violate a dependency.
715   // By construction, every dependency will be lexicographically non-negative
716   // (if it were not, it would violate the current execution order), such as
717   //   (0,0,>,*,*)
718   // Unroll-and-jam moves the two executions of a GT dependence into the same
719   // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
720   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
721   //   (0,0,>=,*,*)
722   // Now, the dependency is not necessarily non-negative anymore, i.e.
723   // unroll-and-jam may violate correctness.
724   std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
725   if (!D)
726     return true;
727   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
728 
729   if (D->isConfused()) {
730     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
731                       << "  " << *Src << "\n"
732                       << "  " << *Dst << "\n");
733     return false;
734   }
735 
736   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
737   // non-equal direction, then the locations accessed in the inner levels cannot
738 // overlap in memory. We assume the indexes never overlap into neighboring
739   // dimensions.
740   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
741     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
742       return true;
743 
744   auto UnrollDirection = D->getDirection(UnrollLevel);
745 
746   // If the distance carried by the unrolled loop is 0, then after unrolling
747   // that distance will become non-zero resulting in non-overlapping accesses in
748   // the inner loops.
749   if (UnrollDirection == Dependence::DVEntry::EQ)
750     return true;
751 
752   if (UnrollDirection & Dependence::DVEntry::LT &&
753       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
754                                   Sequentialized, D.get()))
755     return false;
756 
757   if (UnrollDirection & Dependence::DVEntry::GT &&
758       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
759                                    Sequentialized, D.get()))
760     return false;
761 
762   return true;
763 }
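// For intuition, a hypothetical source-level example with the i-loop being
// unroll-and-jammed and the j-loop being the jammed level:
//
//   for (i)
//     for (j)
//       A[i][j+1] = A[i-1][j];   // direction (<,<): the dependence stays
//                                // forward in j, so it survives jamming
//
//   for (i)
//     for (j)
//       A[i+1][j-1] = A[i][j];   // direction (<,>): after jamming, the read
//                                // of A[i+1][j-1] in the next unrolled copy
//                                // would run before the write that feeds it,
//                                // so checkDependency rejects it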
764 
765 static bool
766 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
767                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
768                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
769                   DependenceInfo &DI, LoopInfo &LI) {
770   SmallVector<BasicBlockSet, 8> AllBlocks;
771   for (Loop *L : Root.getLoopsInPreorder())
772     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
773       AllBlocks.push_back(ForeBlocksMap.lookup(L));
774   AllBlocks.push_back(SubLoopBlocks);
775   for (Loop *L : Root.getLoopsInPreorder())
776     if (AftBlocksMap.find(L) != AftBlocksMap.end())
777       AllBlocks.push_back(AftBlocksMap.lookup(L));
778 
779   unsigned LoopDepth = Root.getLoopDepth();
780   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
781   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
782   for (BasicBlockSet &Blocks : AllBlocks) {
783     CurrentLoadsAndStores.clear();
784     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
785       return false;
786 
787     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
788     unsigned CurLoopDepth = CurLoop->getLoopDepth();
789 
790     for (auto *Earlier : EarlierLoadsAndStores) {
791       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
792       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
793       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
794       for (auto *Later : CurrentLoadsAndStores) {
795         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
796                              DI))
797           return false;
798       }
799     }
800 
801     size_t NumInsts = CurrentLoadsAndStores.size();
802     for (size_t I = 0; I < NumInsts; ++I) {
803       for (size_t J = I; J < NumInsts; ++J) {
804         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
805                              LoopDepth, CurLoopDepth, true, DI))
806           return false;
807       }
808     }
809 
810     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
811                                  CurrentLoadsAndStores.end());
812   }
813   return true;
814 }
815 
816 static bool isEligibleLoopForm(const Loop &Root) {
817   // Root must have a child.
818   if (Root.getSubLoops().size() != 1)
819     return false;
820 
821   const Loop *L = &Root;
822   do {
823     // All loops in Root need to be in simplify and rotated form.
824     if (!L->isLoopSimplifyForm())
825       return false;
826 
827     if (!L->isRotatedForm())
828       return false;
829 
830     if (L->getHeader()->hasAddressTaken()) {
831       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
832       return false;
833     }
834 
835     unsigned SubLoopsSize = L->getSubLoops().size();
836     if (SubLoopsSize == 0)
837       return true;
838 
839     // Only one child is allowed.
840     if (SubLoopsSize != 1)
841       return false;
842 
843     L = L->getSubLoops()[0];
844   } while (L);
845 
846   return true;
847 }
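// For example (hypothetical), a nest of the form
//   for (i) { for (j) { ... } }
// is eligible provided each loop is in simplified and rotated form, whereas
//   for (i) { for (j) { ... } for (k) { ... } }
// is rejected because a loop with two immediate subloops cannot currently be
// unroll-and-jammed.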
848 
849 static Loop *getInnerMostLoop(Loop *L) {
850   while (!L->getSubLoops().empty())
851     L = L->getSubLoops()[0];
852   return L;
853 }
854 
855 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
856                                 DependenceInfo &DI, LoopInfo &LI) {
857   if (!isEligibleLoopForm(*L)) {
858     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
859     return false;
860   }
861 
862   /* We currently handle outer loops like this:
863         |
864     ForeFirst    <------\   }
865      Blocks             |   } ForeBlocks of L
866     ForeLast            |   }
867         |               |
868        ...              |
869         |               |
870     ForeFirst    <----\ |   }
871      Blocks           | |   } ForeBlocks of an inner loop of L
872     ForeLast          | |   }
873         |             | |
874     JamLoopFirst  <\  | |   }
875      Blocks        |  | |   } JamLoopBlocks of the innermost loop
876     JamLoopLast   -/  | |   }
877         |             | |
878     AftFirst          | |   }
879      Blocks           | |   } AftBlocks of an inner loop of L
880     AftLast     ------/ |   }
881         |               |
882        ...              |
883         |               |
884     AftFirst            |   }
885      Blocks             |   } AftBlocks of L
886     AftLast     --------/   }
887         |
888 
889     There can be (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
890     and AftBlocks, provided that there is one edge from Fores to SubLoops,
891     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
892     In practice we currently limit Aft blocks to a single block, and limit
893     things further in the profitability checks of the unroll and jam pass.
894 
895     Because of the way we rearrange basic blocks, we also require that
896     the Fore blocks of L on all unrolled iterations are safe to move before the
897     blocks of the direct child of L of all iterations. So we require that the
898     phi node looping operands of ForeHeader can be moved to at least the end of
899     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
900     match up the phis correctly.
901 
902     i.e. The old order of blocks used to be
903            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
904          It needs to be safe to transform this to
905            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
906 
907     There are then a number of checks along the lines of no calls, no
908     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
909     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
910     UnrollAndJamLoop if the trip count cannot be easily calculated.
911   */
912 
913   // Split blocks into Fore/SubLoop/Aft based on dominators
914   Loop *JamLoop = getInnerMostLoop(L);
915   BasicBlockSet SubLoopBlocks;
916   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
917   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
918   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
919                                 AftBlocksMap, DT)) {
920     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
921     return false;
922   }
923 
924   // Aft blocks may need to move instructions to fore blocks, which becomes more
925   // difficult if there are multiple (potentially conditionally executed)
926   // blocks. For now we just exclude loops with multiple aft blocks.
927   if (AftBlocksMap[L].size() != 1) {
928     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
929                          "multiple blocks after the loop\n");
930     return false;
931   }
932 
933   // Check that the inner loop's backedge count is consistent on all iterations
934   // of the outer loop
935   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
936         return !hasIterationCountInvariantInParent(SubLoop, SE);
937       })) {
938     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
939                          "not consistent on each iteration\n");
940     return false;
941   }
942 
943   // Check the loop safety info for exceptions.
944   SimpleLoopSafetyInfo LSI;
945   LSI.computeLoopSafetyInfo(L);
946   if (LSI.anyBlockMayThrow()) {
947     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
948     return false;
949   }
950 
951   // We've ruled out the easy stuff and now need to check that there are no
952   //  interdependencies which may prevent us from moving:
953   //  ForeBlocks before Subloop and AftBlocks.
954   //  Subloop before AftBlocks.
955   //  ForeBlock phi operands before the subloop
956 
957   // Make sure we can move all instructions we need to before the subloop
958   BasicBlock *Header = L->getHeader();
959   BasicBlock *Latch = L->getLoopLatch();
960   BasicBlockSet AftBlocks = AftBlocksMap[L];
961   Loop *SubLoop = L->getSubLoops()[0];
962   if (!processHeaderPhiOperands(
963           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
964             if (SubLoop->contains(I->getParent()))
965               return false;
966             if (AftBlocks.count(I->getParent())) {
967               // If we hit a phi node in afts we know we are done (probably
968               // LCSSA)
969               if (isa<PHINode>(I))
970                 return false;
971               // Can't move instructions with side effects or memory
972               // reads/writes
973               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
974                 return false;
975             }
976             // Keep going
977             return true;
978           })) {
979     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
980                          "instructions after subloop to before it\n");
981     return false;
982   }
983 
984   // Check for memory dependencies which prohibit the unrolling we are doing.
985   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
986   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
987   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
988                          LI)) {
989     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
990     return false;
991   }
992 
993   return true;
994 }
995