xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision 0b37c1590418417c894529d371800dfac71ef887)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/AssumptionCache.h"
17 #include "llvm/Analysis/DependenceAnalysis.h"
18 #include "llvm/Analysis/InstructionSimplify.h"
19 #include "llvm/Analysis/LoopAnalysisManager.h"
20 #include "llvm/Analysis/LoopIterator.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
23 #include "llvm/Analysis/ScalarEvolution.h"
24 #include "llvm/Analysis/Utils/Local.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/DebugInfoMetadata.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34 #include "llvm/Transforms/Utils/Cloning.h"
35 #include "llvm/Transforms/Utils/LoopSimplify.h"
36 #include "llvm/Transforms/Utils/LoopUtils.h"
37 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
38 #include "llvm/Transforms/Utils/UnrollLoop.h"
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "loop-unroll-and-jam"
42 
43 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
44 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
45 
46 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
47 
48 // Partition blocks in an outer/inner loop pair into blocks before and after
49 // the loop
50 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
51                                      BasicBlockSet &ForeBlocks,
52                                      BasicBlockSet &SubLoopBlocks,
53                                      BasicBlockSet &AftBlocks,
54                                      DominatorTree *DT) {
55   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
56   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
57 
58   for (BasicBlock *BB : L->blocks()) {
59     if (!SubLoop->contains(BB)) {
60       if (DT->dominates(SubLoopLatch, BB))
61         AftBlocks.insert(BB);
62       else
63         ForeBlocks.insert(BB);
64     }
65   }
66 
67   // Check that all blocks in ForeBlocks together dominate the subloop
68   // TODO: This might ideally be done better with a dominator/postdominators.
69   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
70   for (BasicBlock *BB : ForeBlocks) {
71     if (BB == SubLoopPreHeader)
72       continue;
73     Instruction *TI = BB->getTerminator();
74     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
75       if (!ForeBlocks.count(TI->getSuccessor(i)))
76         return false;
77   }
78 
79   return true;
80 }
81 
82 // Looks at the phi nodes in Header for values coming from Latch. For these
83 // instructions and all their operands calls Visit on them, keeping going for
84 // all the operands in AftBlocks. Returns false if Visit returns false,
85 // otherwise returns true. This is used to process the instructions in the
86 // Aft blocks that need to be moved before the subloop. It is used in two
87 // places. One to check that the required set of instructions can be moved
88 // before the loop. Then to collect the instructions to actually move in
89 // moveHeaderPhiOperandsToForeBlocks.
90 template <typename T>
91 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
92                                      BasicBlockSet &AftBlocks, T Visit) {
93   SmallVector<Instruction *, 8> Worklist;
94   for (auto &Phi : Header->phis()) {
95     Value *V = Phi.getIncomingValueForBlock(Latch);
96     if (Instruction *I = dyn_cast<Instruction>(V))
97       Worklist.push_back(I);
98   }
99 
100   while (!Worklist.empty()) {
101     Instruction *I = Worklist.back();
102     Worklist.pop_back();
103     if (!Visit(I))
104       return false;
105 
106     if (AftBlocks.count(I->getParent()))
107       for (auto &U : I->operands())
108         if (Instruction *II = dyn_cast<Instruction>(U))
109           Worklist.push_back(II);
110   }
111 
112   return true;
113 }
114 
115 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
116 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
117                                               BasicBlock *Latch,
118                                               Instruction *InsertLoc,
119                                               BasicBlockSet &AftBlocks) {
120   // We need to ensure we move the instructions in the correct order,
121   // starting with the earliest required instruction and moving forward.
122   std::vector<Instruction *> Visited;
123   processHeaderPhiOperands(Header, Latch, AftBlocks,
124                            [&Visited, &AftBlocks](Instruction *I) {
125                              if (AftBlocks.count(I->getParent()))
126                                Visited.push_back(I);
127                              return true;
128                            });
129 
130   // Move all instructions in program order to before the InsertLoc
131   BasicBlock *InsertLocBB = InsertLoc->getParent();
132   for (Instruction *I : reverse(Visited)) {
133     if (I->getParent() != InsertLocBB)
134       I->moveBefore(InsertLoc);
135   }
136 }
137 
/*
  This method performs Unroll and Jam. For a simple loop like:
  for (i = ..)
    Fore(i)
    for (j = ..)
      SubLoop(i, j)
    Aft(i)

  Instead of doing normal inner or outer unrolling, we do:
  for (i = .., i+=2)
    Fore(i)
    Fore(i+1)
    for (j = ..)
      SubLoop(i, j)
      SubLoop(i+1, j)
    Aft(i)
    Aft(i+1)

  So the outer loop is essentially unrolled and then the inner loops are fused
  ("jammed") together into a single loop. This can increase speed when there
  are loads in SubLoop that are invariant to i, as they become shared between
  the now jammed inner loops.

  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
  Fore blocks are those before the inner loop, Aft are those after. Normal
  Unroll code is used to copy each of these sets of blocks and the results are
  combined together into the final form above.

  isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.

  Parameters, as used below:
    Count           - the unroll factor. Count == TripCount means the outer
                      loop control is removed entirely (complete unroll).
    TripCount       - the constant outer trip count, or 0 if unknown. When
                      known it is asserted to be a multiple of TripMultiple.
    TripMultiple    - when it is 1, or not divisible by Count, a runtime
                      remainder loop is created with
                      UnrollRuntimeLoopRemainder.
    UnrollRemainder - forwarded to UnrollRuntimeLoopRemainder.

  If EpilogueLoop is non-null, it receives the epilogue loop (if it was
  necessary to create one and not fully unrolled).

  Returns:
    Unmodified        - TripCount is 0 and Count < 2, or the remainder loop
                        could not be generated.
    FullyUnrolled     - the outer loop was completely unrolled; L is erased
                        from LoopInfo.
    PartiallyUnrolled - otherwise.
*/
LoopUnrollResult llvm::UnrollAndJamLoop(
    Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,
    bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
    AssumptionCache *AC, OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  // If successor 0 of the outer latch branch stays in L, the exit is
  // successor 1, and vice versa.
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  // The same orientation for the inner loop's latch branch.
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop. The result is not checked here; isSafeToUnrollAndJam already
  // rejects layouts for which partitioning fails.
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier. Element 0 of each vector is the original block; element It
  // is appended when iteration It is cloned below.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, SubLoop->getLoopPreheader()->getTerminator(),
      AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // With profiling debug info enabled, multiply each instruction's
  // duplication factor by Count, because the loop body is about to be
  // replicated Count times.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    std::vector<BasicBlock *> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Register the clone with the loop it belongs to, and record it if it is
      // the first/last block of its Fore/SubLoop/Aft group.
      if (ForeBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        SubLoop->addBasicBlockToLoop(New, *LI);

        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones. For It == 1 the
      // "previous iteration" value is the original value itself.
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      // The first block of each group in iteration It is dominated by the last
      // block of the same group in iteration It - 1, matching the chained
      // layout built below.
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        ::remapInstruction(&I, LastValueMap);
        // NOTE(review): AC is dereferenced without a null check; callers are
        // presumably required to pass a non-null AssumptionCache -- confirm.
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis. Each cloned header phi is left
    // with a single incoming edge, from the previous iteration's last Fore
    // block.
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB
  // NOTE(review): the index from getBasicBlockIndex is used unchecked, so
  // every phi in BB is assumed to have an incoming entry for OldBB.
  auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
                            BasicBlock *NewBB) {
    for (PHINode &Phi : BB->phis()) {
      int I = Phi.getBasicBlockIndex(OldBB);
      Phi.setIncomingBlock(I, NewBB);
    }
  };
  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  BasicBlock *Dest = SubLoopBlocksFirst[0];
  ForeTerm->setSuccessor(0, Dest);

  if (CompletelyUnroll) {
    // No backedge remains, so each header phi is replaced by its value from
    // the preheader.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    BasicBlock *Dest = ForeBlocksFirst[It];
    ForeTerm->setSuccessor(0, Dest);
  }

  // Subloop successors and phis
  // The last subloop copy now holds the single inner backedge (to the first
  // copy's header) and the exit to the first Aft block.
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
                  ForeBlocksLast.back());
  updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
                    ForeBlocksLast.back());
    updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    // All inner-loop header phis end up in the first subloop header.
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *Term = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, Term);
    Term->eraseFromParent();
  } else {
    Term->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
  }
  updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration's last Aft
    // block with an unconditional one to this iteration's first Aft block
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  // Only the recorded "last" blocks are candidates. When Dest is merged into
  // BB, BB stays in the set and is retried against its new successor.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
  while (!MergeBlocks.empty()) {
    BasicBlock *BB = *MergeBlocks.begin();
    BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
    if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
      BasicBlock *Dest = Term->getSuccessor(0);
      BasicBlock *Fold = Dest->getUniquePredecessor();
      if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
        // Don't remove BB and add Fold as they are the same BB
        assert(Fold == BB);
        (void)Fold;
        MergeBlocks.erase(Dest);
      } else
        MergeBlocks.erase(BB);
    } else
      MergeBlocks.erase(BB);
  }
  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  Loop *OuterL = L->getParentLoop();
  Loop *OutestLoop = OuterL ? OuterL : (!CompletelyUnroll ? L : SubLoop);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  assert(DT->verify());
#endif

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
589 
590 static bool getLoadsAndStores(BasicBlockSet &Blocks,
591                               SmallVector<Value *, 4> &MemInstr) {
592   // Scan the BBs and collect legal loads and stores.
593   // Returns false if non-simple loads/stores are found.
594   for (BasicBlock *BB : Blocks) {
595     for (Instruction &I : *BB) {
596       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
597         if (!Ld->isSimple())
598           return false;
599         MemInstr.push_back(&I);
600       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
601         if (!St->isSimple())
602           return false;
603         MemInstr.push_back(&I);
604       } else if (I.mayReadOrWriteMemory()) {
605         return false;
606       }
607     }
608   }
609   return true;
610 }
611 
612 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
613                               SmallVector<Value *, 4> &Later,
614                               unsigned LoopDepth, bool InnerLoop,
615                               DependenceInfo &DI) {
616   // Use DA to check for dependencies between loads and stores that make unroll
617   // and jam invalid
618   for (Value *I : Earlier) {
619     for (Value *J : Later) {
620       Instruction *Src = cast<Instruction>(I);
621       Instruction *Dst = cast<Instruction>(J);
622       if (Src == Dst)
623         continue;
624       // Ignore Input dependencies.
625       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
626         continue;
627 
628       // Track dependencies, and if we find them take a conservative approach
629       // by allowing only = or < (not >), altough some > would be safe
630       // (depending upon unroll width).
631       // For the inner loop, we need to disallow any (> <) dependencies
632       // FIXME: Allow > so long as distance is less than unroll width
633       if (auto D = DI.depends(Src, Dst, true)) {
634         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
635 
636         if (D->isConfused()) {
637           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
638                             << "  " << *Src << "\n"
639                             << "  " << *Dst << "\n");
640           return false;
641         }
642         if (!InnerLoop) {
643           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
644             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
645                               << "  " << *Src << "\n"
646                               << "  " << *Dst << "\n");
647             return false;
648           }
649         } else {
650           assert(LoopDepth + 1 <= D->getLevels());
651           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
652               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
653             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
654                               << "  " << *Src << "\n"
655                               << "  " << *Dst << "\n");
656             return false;
657           }
658         }
659       }
660     }
661   }
662   return true;
663 }
664 
665 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
666                               BasicBlockSet &SubLoopBlocks,
667                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
668   // Get all loads/store pairs for each blocks
669   SmallVector<Value *, 4> ForeMemInstr;
670   SmallVector<Value *, 4> SubLoopMemInstr;
671   SmallVector<Value *, 4> AftMemInstr;
672   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
673       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
674       !getLoadsAndStores(AftBlocks, AftMemInstr))
675     return false;
676 
677   // Check for dependencies between any blocks that may change order
678   unsigned LoopDepth = L->getLoopDepth();
679   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
680                            DI) &&
681          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
682          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
683                            DI) &&
684          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
685                            DI);
686 }
687 
688 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
689                                 DependenceInfo &DI) {
690   /* We currently handle outer loops like this:
691         |
692     ForeFirst    <----\    }
693      Blocks           |    } ForeBlocks
694     ForeLast          |    }
695         |             |
696     SubLoopFirst  <\  |    }
697      Blocks        |  |    } SubLoopBlocks
698     SubLoopLast   -/  |    }
699         |             |
700     AftFirst          |    }
701      Blocks           |    } AftBlocks
702     AftLast     ------/    }
703         |
704 
705     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
706     and AftBlocks, providing that there is one edge from Fores to SubLoops,
707     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
708     In practice we currently limit Aft blocks to a single block, and limit
709     things further in the profitablility checks of the unroll and jam pass.
710 
711     Because of the way we rearrange basic blocks, we also require that
712     the Fore blocks on all unrolled iterations are safe to move before the
713     SubLoop blocks of all iterations. So we require that the phi node looping
714     operands of ForeHeader can be moved to at least the end of ForeEnd, so that
715     we can arrange cloned Fore Blocks before the subloop and match up Phi's
716     correctly.
717 
718     i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
719     It needs to be safe to tranform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
720 
721     There are then a number of checks along the lines of no calls, no
722     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
723     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
724     UnrollAndJamLoop if the trip count cannot be easily calculated.
725   */
726 
727   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
728     return false;
729   Loop *SubLoop = L->getSubLoops()[0];
730   if (!SubLoop->isLoopSimplifyForm())
731     return false;
732 
733   BasicBlock *Header = L->getHeader();
734   BasicBlock *Latch = L->getLoopLatch();
735   BasicBlock *Exit = L->getExitingBlock();
736   BasicBlock *SubLoopHeader = SubLoop->getHeader();
737   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
738   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
739 
740   if (Latch != Exit)
741     return false;
742   if (SubLoopLatch != SubLoopExit)
743     return false;
744 
745   if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
746     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
747     return false;
748   }
749 
750   // Split blocks into Fore/SubLoop/Aft based on dominators
751   BasicBlockSet SubLoopBlocks;
752   BasicBlockSet ForeBlocks;
753   BasicBlockSet AftBlocks;
754   if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
755                                 AftBlocks, &DT)) {
756     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
757     return false;
758   }
759 
760   // Aft blocks may need to move instructions to fore blocks, which becomes more
761   // difficult if there are multiple (potentially conditionally executed)
762   // blocks. For now we just exclude loops with multiple aft blocks.
763   if (AftBlocks.size() != 1) {
764     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
765                          "multiple blocks after the loop\n");
766     return false;
767   }
768 
769   // Check inner loop backedge count is consistent on all iterations of the
770   // outer loop
771   if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
772     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
773                          "not consistent on each iteration\n");
774     return false;
775   }
776 
777   // Check the loop safety info for exceptions.
778   SimpleLoopSafetyInfo LSI;
779   LSI.computeLoopSafetyInfo(L);
780   if (LSI.anyBlockMayThrow()) {
781     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
782     return false;
783   }
784 
785   // We've ruled out the easy stuff and now need to check that there are no
786   // interdependencies which may prevent us from moving the:
787   //  ForeBlocks before Subloop and AftBlocks.
788   //  Subloop before AftBlocks.
789   //  ForeBlock phi operands before the subloop
790 
791   // Make sure we can move all instructions we need to before the subloop
792   if (!processHeaderPhiOperands(
793           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
794             if (SubLoop->contains(I->getParent()))
795               return false;
796             if (AftBlocks.count(I->getParent())) {
797               // If we hit a phi node in afts we know we are done (probably
798               // LCSSA)
799               if (isa<PHINode>(I))
800                 return false;
801               // Can't move instructions with side effects or memory
802               // reads/writes
803               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
804                 return false;
805             }
806             // Keep going
807             return true;
808           })) {
809     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
810                          "instructions after subloop to before it\n");
811     return false;
812   }
813 
814   // Check for memory dependencies which prohibit the unrolling we are doing.
815   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
816   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
817   if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
818     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
819     return false;
820   }
821 
822   return true;
823 }
824